feat(memory): add provider interface (#72)
This commit is contained in:
@@ -7,6 +7,7 @@ import os
|
|||||||
|
|
||||||
from .memory import MemoryManager
|
from .memory import MemoryManager
|
||||||
from .memory_vector import MemoryVectorStore
|
from .memory_vector import MemoryVectorStore
|
||||||
|
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -42,6 +43,10 @@ class MemoryService:
|
|||||||
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
||||||
os.path.join(data_dir, "memory_vectors")
|
os.path.join(data_dir, "memory_vectors")
|
||||||
) else None
|
) else None
|
||||||
|
self.provider = NativeMemoryProvider(self.manager, self.vector_store)
|
||||||
|
|
||||||
|
def _sync_provider(self) -> None:
|
||||||
|
self.provider.memory_vector = self.vector_store
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_memory(entry: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> Memory:
|
def _to_memory(entry: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> Memory:
|
||||||
@@ -53,6 +58,19 @@ class MemoryService:
|
|||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _record_to_memory(record: MemoryRecord, metadata: Optional[Dict[str, Any]] = None) -> Memory:
|
||||||
|
merged_metadata = dict(record.metadata)
|
||||||
|
if metadata:
|
||||||
|
merged_metadata.update(metadata)
|
||||||
|
return Memory(
|
||||||
|
id=record.id,
|
||||||
|
text=record.text,
|
||||||
|
timestamp=record.timestamp,
|
||||||
|
session_id=record.session_id,
|
||||||
|
metadata=merged_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
|
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
|
||||||
"""
|
"""
|
||||||
Store a new memory.
|
Store a new memory.
|
||||||
@@ -64,19 +82,9 @@ class MemoryService:
|
|||||||
Returns:
|
Returns:
|
||||||
Created Memory object
|
Created Memory object
|
||||||
"""
|
"""
|
||||||
entry = self.manager.add_entry(text)
|
self._sync_provider()
|
||||||
if session_id:
|
record = await self.provider.remember(text, session_id=session_id)
|
||||||
entry["session_id"] = session_id
|
return self._record_to_memory(record)
|
||||||
|
|
||||||
memories = self.manager.load_all()
|
|
||||||
memories.append(entry)
|
|
||||||
self.manager.save(memories)
|
|
||||||
|
|
||||||
# Also add to vector store if available
|
|
||||||
if self.vector_store and self.vector_store.healthy:
|
|
||||||
self.vector_store.add(entry["id"], entry["text"])
|
|
||||||
|
|
||||||
return self._to_memory(entry)
|
|
||||||
|
|
||||||
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
|
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
|
||||||
"""
|
"""
|
||||||
@@ -89,28 +97,20 @@ class MemoryService:
|
|||||||
Returns:
|
Returns:
|
||||||
MemorySearchResult with matching memories
|
MemorySearchResult with matching memories
|
||||||
"""
|
"""
|
||||||
# Try vector search first
|
self._sync_provider()
|
||||||
all_memories = self.manager.load_all()
|
results = await self.provider.recall(query, top_k=top_k)
|
||||||
by_id = {m.get("id"): m for m in all_memories}
|
memories = [
|
||||||
if self.vector_store and self.vector_store.healthy:
|
self._record_to_memory(hit.memory, metadata={"score": hit.score})
|
||||||
results = self.vector_store.search(query, k=top_k)
|
if hit.score is not None
|
||||||
found = []
|
else self._record_to_memory(hit.memory)
|
||||||
for result in results:
|
for hit in results
|
||||||
entry = by_id.get(result.get("memory_id"))
|
]
|
||||||
if entry:
|
|
||||||
found.append(self._to_memory(entry, metadata={"score": result.get("score")}))
|
|
||||||
if found:
|
|
||||||
return MemorySearchResult(memories=found, query=query, total=len(found))
|
|
||||||
|
|
||||||
# Fallback to keyword search
|
|
||||||
results = self.manager.get_relevant_memories(query, all_memories, max_items=top_k)
|
|
||||||
memories = [self._to_memory(m) for m in results]
|
|
||||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
||||||
|
|
||||||
def get_all(self, limit: int = 100) -> List[Memory]:
|
def get_all(self, limit: int = 100) -> List[Memory]:
|
||||||
"""Get all memories."""
|
"""Get all memories."""
|
||||||
memories = self.manager.load_all()[:limit]
|
records = self.manager.load_all()[:limit]
|
||||||
return [self._to_memory(m) for m in memories]
|
return [self._to_memory(m) for m in records]
|
||||||
|
|
||||||
def delete(self, memory_id: str) -> bool:
|
def delete(self, memory_id: str) -> bool:
|
||||||
"""Delete a memory by ID."""
|
"""Delete a memory by ID."""
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from src.constants import (
|
|||||||
SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
|
SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
|
||||||
)
|
)
|
||||||
from src.memory import MemoryManager
|
from src.memory import MemoryManager
|
||||||
|
from src.memory_provider import MemoryProviderRegistry, NativeMemoryProvider
|
||||||
from services.memory.skills import SkillsManager
|
from services.memory.skills import SkillsManager
|
||||||
from core.session_manager import SessionManager
|
from core.session_manager import SessionManager
|
||||||
from core.models import set_session_manager
|
from core.models import set_session_manager
|
||||||
@@ -73,6 +74,10 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
|
|||||||
logger.warning(f"MemoryVectorStore DEGRADED: {e}")
|
logger.warning(f"MemoryVectorStore DEGRADED: {e}")
|
||||||
memory_vector = None
|
memory_vector = None
|
||||||
|
|
||||||
|
memory_provider_registry = MemoryProviderRegistry([
|
||||||
|
NativeMemoryProvider(memory_manager, memory_vector),
|
||||||
|
])
|
||||||
|
|
||||||
# Initialize processors
|
# Initialize processors
|
||||||
chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
|
chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
|
||||||
research_handler = ResearchHandler()
|
research_handler = ResearchHandler()
|
||||||
@@ -99,6 +104,7 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
|
|||||||
return {
|
return {
|
||||||
"memory_manager": memory_manager,
|
"memory_manager": memory_manager,
|
||||||
"memory_vector": memory_vector,
|
"memory_vector": memory_vector,
|
||||||
|
"memory_provider_registry": memory_provider_registry,
|
||||||
"skills_manager": skills_manager,
|
"skills_manager": skills_manager,
|
||||||
"session_manager": session_manager,
|
"session_manager": session_manager,
|
||||||
"upload_handler": upload_handler,
|
"upload_handler": upload_handler,
|
||||||
|
|||||||
320
src/memory_provider.py
Normal file
320
src/memory_provider.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""Memory provider interfaces for native and external memory systems."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemoryRecord:
|
||||||
|
"""Provider-neutral memory entry."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
timestamp: int = 0
|
||||||
|
category: str = "fact"
|
||||||
|
source: str = "unknown"
|
||||||
|
owner: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemorySearchHit:
|
||||||
|
"""A memory returned by provider recall."""
|
||||||
|
|
||||||
|
memory: MemoryRecord
|
||||||
|
provider_id: str
|
||||||
|
score: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProvider(ABC):
|
||||||
|
"""Base contract for Odysseus memory providers.
|
||||||
|
|
||||||
|
The native memory provider should always be available. External providers
|
||||||
|
can add recall/write behavior and their own tools without replacing the
|
||||||
|
built-in local memory baseline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider_id = "unknown"
|
||||||
|
display_name = "Unknown"
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Prepare provider resources before use."""
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Release provider resources."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def remember(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
category: str = "fact",
|
||||||
|
source: str = "user",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> MemoryRecord:
|
||||||
|
"""Store a memory and return the stored record."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def recall(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> List[MemorySearchHit]:
|
||||||
|
"""Return provider memories relevant to the query."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def list_memories(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> List[MemoryRecord]:
|
||||||
|
"""List memories visible to the owner."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
|
||||||
|
"""Delete a memory by ID when allowed by the provider."""
|
||||||
|
|
||||||
|
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Return provider-defined tool schemas when this provider is enabled."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
|
"""Handle a provider-defined tool call."""
|
||||||
|
raise KeyError(f"Provider {self.provider_id} does not expose tool {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class NativeMemoryProvider(MemoryProvider):
|
||||||
|
"""Provider adapter for Odysseus' built-in memory manager and vector store."""
|
||||||
|
|
||||||
|
provider_id = "native"
|
||||||
|
display_name = "Odysseus native memory"
|
||||||
|
|
||||||
|
_CORE_FIELDS = {
|
||||||
|
"id",
|
||||||
|
"text",
|
||||||
|
"timestamp",
|
||||||
|
"source",
|
||||||
|
"category",
|
||||||
|
"uses",
|
||||||
|
"owner",
|
||||||
|
"session_id",
|
||||||
|
"metadata",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, memory_manager, memory_vector=None):
|
||||||
|
self.memory_manager = memory_manager
|
||||||
|
self.memory_vector = memory_vector
|
||||||
|
|
||||||
|
def _to_record(self, entry: Dict[str, Any]) -> MemoryRecord:
|
||||||
|
metadata = {
|
||||||
|
key: value
|
||||||
|
for key, value in entry.items()
|
||||||
|
if key not in self._CORE_FIELDS
|
||||||
|
}
|
||||||
|
stored_metadata = entry.get("metadata")
|
||||||
|
if isinstance(stored_metadata, dict):
|
||||||
|
metadata.update(stored_metadata)
|
||||||
|
|
||||||
|
return MemoryRecord(
|
||||||
|
id=entry.get("id", ""),
|
||||||
|
text=entry.get("text", ""),
|
||||||
|
timestamp=entry.get("timestamp", 0),
|
||||||
|
category=entry.get("category", "fact"),
|
||||||
|
source=entry.get("source", "unknown"),
|
||||||
|
owner=entry.get("owner"),
|
||||||
|
session_id=entry.get("session_id"),
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remember(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
category: str = "fact",
|
||||||
|
source: str = "user",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> MemoryRecord:
|
||||||
|
entry = self.memory_manager.add_entry(
|
||||||
|
text,
|
||||||
|
source=source,
|
||||||
|
category=category,
|
||||||
|
owner=owner,
|
||||||
|
)
|
||||||
|
if session_id:
|
||||||
|
entry["session_id"] = session_id
|
||||||
|
if metadata:
|
||||||
|
entry["metadata"] = dict(metadata)
|
||||||
|
|
||||||
|
memories = self.memory_manager.load_all()
|
||||||
|
memories.append(entry)
|
||||||
|
self.memory_manager.save(memories)
|
||||||
|
|
||||||
|
if self._vector_available():
|
||||||
|
self.memory_vector.add(entry["id"], entry["text"])
|
||||||
|
|
||||||
|
return self._to_record(entry)
|
||||||
|
|
||||||
|
async def recall(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> List[MemorySearchHit]:
|
||||||
|
memories = self.memory_manager.load(owner=owner)
|
||||||
|
by_id = {m.get("id"): m for m in memories}
|
||||||
|
|
||||||
|
if self._vector_available():
|
||||||
|
hits: List[MemorySearchHit] = []
|
||||||
|
for result in self.memory_vector.search(query, k=top_k):
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
continue
|
||||||
|
memory_id = result.get("memory_id")
|
||||||
|
entry = by_id.get(memory_id) if memory_id else result
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
if owner is not None and entry.get("owner") != owner:
|
||||||
|
continue
|
||||||
|
hits.append(
|
||||||
|
MemorySearchHit(
|
||||||
|
memory=self._to_record(entry),
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
score=result.get("score"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if hits:
|
||||||
|
return hits
|
||||||
|
|
||||||
|
fallback = self.memory_manager.get_relevant_memories(
|
||||||
|
query,
|
||||||
|
memories,
|
||||||
|
max_items=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
MemorySearchHit(
|
||||||
|
memory=self._to_record(entry),
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
score=None,
|
||||||
|
)
|
||||||
|
for entry in fallback
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_memories(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> List[MemoryRecord]:
|
||||||
|
return [
|
||||||
|
self._to_record(entry)
|
||||||
|
for entry in self.memory_manager.load(owner=owner)[:limit]
|
||||||
|
]
|
||||||
|
|
||||||
|
async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
|
||||||
|
memories = self.memory_manager.load_all()
|
||||||
|
remaining = []
|
||||||
|
deleted_id = None
|
||||||
|
|
||||||
|
for entry in memories:
|
||||||
|
if entry.get("id") != memory_id:
|
||||||
|
remaining.append(entry)
|
||||||
|
continue
|
||||||
|
if owner is not None and entry.get("owner") != owner:
|
||||||
|
remaining.append(entry)
|
||||||
|
continue
|
||||||
|
deleted_id = entry.get("id")
|
||||||
|
|
||||||
|
if deleted_id is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.memory_manager.save(remaining)
|
||||||
|
if self._vector_available():
|
||||||
|
self.memory_vector.remove(deleted_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _vector_available(self) -> bool:
|
||||||
|
return bool(self.memory_vector and getattr(self.memory_vector, "healthy", True))
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProviderRegistry:
|
||||||
|
"""Container for native and optional external memory providers."""
|
||||||
|
|
||||||
|
def __init__(self, providers: Optional[Iterable[MemoryProvider]] = None):
|
||||||
|
self._providers: Dict[str, MemoryProvider] = {}
|
||||||
|
for provider in providers or []:
|
||||||
|
self.register(provider)
|
||||||
|
|
||||||
|
def register(self, provider: MemoryProvider) -> None:
|
||||||
|
if provider.provider_id in self._providers:
|
||||||
|
raise ValueError(f"Memory provider already registered: {provider.provider_id}")
|
||||||
|
self._providers[provider.provider_id] = provider
|
||||||
|
|
||||||
|
def get(self, provider_id: str) -> MemoryProvider:
|
||||||
|
return self._providers[provider_id]
|
||||||
|
|
||||||
|
def all(self) -> List[MemoryProvider]:
|
||||||
|
return list(self._providers.values())
|
||||||
|
|
||||||
|
def active(self) -> List[MemoryProvider]:
|
||||||
|
return [provider for provider in self._providers.values() if provider.enabled]
|
||||||
|
|
||||||
|
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||||
|
schemas: List[Dict[str, Any]] = []
|
||||||
|
seen: Dict[str, str] = {}
|
||||||
|
|
||||||
|
for provider in self.active():
|
||||||
|
for schema in provider.get_tool_schemas():
|
||||||
|
name = self._tool_name(schema)
|
||||||
|
if name in seen:
|
||||||
|
raise ValueError(
|
||||||
|
f"Memory tool name conflict: {name} from "
|
||||||
|
f"{provider.provider_id} already exposed by {seen[name]}"
|
||||||
|
)
|
||||||
|
seen[name] = provider.provider_id
|
||||||
|
schemas.append(schema)
|
||||||
|
|
||||||
|
return schemas
|
||||||
|
|
||||||
|
async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
|
provider_by_tool: Dict[str, MemoryProvider] = {}
|
||||||
|
for provider in self.active():
|
||||||
|
for schema in provider.get_tool_schemas():
|
||||||
|
tool_name = self._tool_name(schema)
|
||||||
|
if tool_name in provider_by_tool:
|
||||||
|
raise ValueError(
|
||||||
|
f"Memory tool name conflict: {tool_name} from "
|
||||||
|
f"{provider.provider_id} already exposed by "
|
||||||
|
f"{provider_by_tool[tool_name].provider_id}"
|
||||||
|
)
|
||||||
|
provider_by_tool[tool_name] = provider
|
||||||
|
|
||||||
|
provider = provider_by_tool.get(name)
|
||||||
|
if provider:
|
||||||
|
return await provider.handle_tool_call(name, arguments)
|
||||||
|
raise KeyError(f"No active memory provider exposes tool {name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_name(schema: Dict[str, Any]) -> str:
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
raise ValueError("Memory provider tool schema must be a dict")
|
||||||
|
name = schema.get("name")
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
return name
|
||||||
|
function = schema.get("function")
|
||||||
|
if isinstance(function, dict):
|
||||||
|
function_name = function.get("name")
|
||||||
|
if isinstance(function_name, str) and function_name:
|
||||||
|
return function_name
|
||||||
|
raise ValueError("Memory provider tool schema is missing a tool name")
|
||||||
181
tests/test_memory_provider.py
Normal file
181
tests/test_memory_provider.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""Tests for the memory provider interface and native adapter."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class FakeVectorStore:
|
||||||
|
healthy = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.added = []
|
||||||
|
self.removed = []
|
||||||
|
self.results = []
|
||||||
|
|
||||||
|
def add(self, memory_id, text):
|
||||||
|
self.added.append((memory_id, text))
|
||||||
|
|
||||||
|
def remove(self, memory_id):
|
||||||
|
self.removed.append(memory_id)
|
||||||
|
|
||||||
|
def search(self, query, k=5):
|
||||||
|
return self.results[:k]
|
||||||
|
|
||||||
|
|
||||||
|
def run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_provider_remember_writes_native_memory_and_vector(tmp_path):
|
||||||
|
from src.memory import MemoryManager
|
||||||
|
from src.memory_provider import NativeMemoryProvider
|
||||||
|
|
||||||
|
manager = MemoryManager(str(tmp_path))
|
||||||
|
vector = FakeVectorStore()
|
||||||
|
provider = NativeMemoryProvider(manager, vector)
|
||||||
|
|
||||||
|
record = run(provider.remember(
|
||||||
|
"User prefers concise responses",
|
||||||
|
owner="alice",
|
||||||
|
session_id="session-1",
|
||||||
|
category="preference",
|
||||||
|
metadata={"confidence": 0.9},
|
||||||
|
))
|
||||||
|
|
||||||
|
stored = manager.load(owner="alice")
|
||||||
|
assert len(stored) == 1
|
||||||
|
assert stored[0]["id"] == record.id
|
||||||
|
assert stored[0]["text"] == "User prefers concise responses"
|
||||||
|
assert stored[0]["category"] == "preference"
|
||||||
|
assert stored[0]["session_id"] == "session-1"
|
||||||
|
assert record.metadata["confidence"] == 0.9
|
||||||
|
assert vector.added == [(record.id, "User prefers concise responses")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_provider_recall_filters_vector_hits_by_owner(tmp_path):
|
||||||
|
from src.memory import MemoryManager
|
||||||
|
from src.memory_provider import NativeMemoryProvider
|
||||||
|
|
||||||
|
manager = MemoryManager(str(tmp_path))
|
||||||
|
vector = FakeVectorStore()
|
||||||
|
provider = NativeMemoryProvider(manager, vector)
|
||||||
|
|
||||||
|
alice = run(provider.remember("Alice likes green tea", owner="alice"))
|
||||||
|
bob = run(provider.remember("Bob likes espresso", owner="bob"))
|
||||||
|
vector.results = [
|
||||||
|
{"memory_id": bob.id, "score": 0.99},
|
||||||
|
{"memory_id": alice.id, "score": 0.75},
|
||||||
|
]
|
||||||
|
|
||||||
|
hits = run(provider.recall("what does Alice like?", owner="alice", top_k=5))
|
||||||
|
|
||||||
|
assert [hit.memory.id for hit in hits] == [alice.id]
|
||||||
|
assert hits[0].provider_id == "native"
|
||||||
|
assert hits[0].score == 0.75
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_provider_recall_accepts_legacy_vector_rows(tmp_path):
|
||||||
|
from src.memory import MemoryManager
|
||||||
|
from src.memory_provider import NativeMemoryProvider
|
||||||
|
|
||||||
|
manager = MemoryManager(str(tmp_path))
|
||||||
|
vector = FakeVectorStore()
|
||||||
|
provider = NativeMemoryProvider(manager, vector)
|
||||||
|
|
||||||
|
vector.results = [
|
||||||
|
{"id": "legacy-1", "text": "real memory", "timestamp": 5},
|
||||||
|
"corrupt-row",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
hits = run(provider.recall("anything", top_k=5))
|
||||||
|
|
||||||
|
assert [hit.memory.id for hit in hits] == ["legacy-1"]
|
||||||
|
assert hits[0].memory.text == "real memory"
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_provider_recall_falls_back_to_keyword_search(tmp_path):
|
||||||
|
from src.memory import MemoryManager
|
||||||
|
from src.memory_provider import NativeMemoryProvider
|
||||||
|
|
||||||
|
manager = MemoryManager(str(tmp_path))
|
||||||
|
provider = NativeMemoryProvider(manager)
|
||||||
|
saved = run(provider.remember(
|
||||||
|
"Alice prefers markdown notes",
|
||||||
|
owner="alice",
|
||||||
|
category="preference",
|
||||||
|
))
|
||||||
|
|
||||||
|
hits = run(provider.recall("markdown preference", owner="alice", top_k=3))
|
||||||
|
|
||||||
|
assert [hit.memory.id for hit in hits] == [saved.id]
|
||||||
|
assert hits[0].score is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_provider_registry_exposes_only_active_provider_tools():
|
||||||
|
from src.memory_provider import MemoryProvider, MemoryProviderRegistry
|
||||||
|
|
||||||
|
class DummyProvider(MemoryProvider):
|
||||||
|
def __init__(self, provider_id, enabled=True):
|
||||||
|
self.provider_id = provider_id
|
||||||
|
self.display_name = provider_id
|
||||||
|
self.enabled = enabled
|
||||||
|
|
||||||
|
async def remember(self, text, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def recall(self, query, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def list_memories(self, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def delete(self, memory_id, **kwargs):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_tool_schemas(self):
|
||||||
|
return [{"name": f"{self.provider_id}_search", "description": "Search memory"}]
|
||||||
|
|
||||||
|
registry = MemoryProviderRegistry([
|
||||||
|
DummyProvider("active"),
|
||||||
|
DummyProvider("disabled", enabled=False),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert registry.get_tool_schemas() == [
|
||||||
|
{"name": "active_search", "description": "Search memory"}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_provider_registry_rejects_tool_name_conflicts():
|
||||||
|
from src.memory_provider import MemoryProvider, MemoryProviderRegistry
|
||||||
|
|
||||||
|
class ConflictingProvider(MemoryProvider):
|
||||||
|
def __init__(self, provider_id):
|
||||||
|
self.provider_id = provider_id
|
||||||
|
self.display_name = provider_id
|
||||||
|
|
||||||
|
async def remember(self, text, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def recall(self, query, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def list_memories(self, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def delete(self, memory_id, **kwargs):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_tool_schemas(self):
|
||||||
|
return [{"name": "memory_search"}]
|
||||||
|
|
||||||
|
registry = MemoryProviderRegistry([
|
||||||
|
ConflictingProvider("first"),
|
||||||
|
ConflictingProvider("second"),
|
||||||
|
])
|
||||||
|
|
||||||
|
try:
|
||||||
|
registry.get_tool_schemas()
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "memory_search" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected duplicate memory tool names to be rejected")
|
||||||
Reference in New Issue
Block a user