diff --git a/services/memory/service.py b/services/memory/service.py
index d07eb17..0a5b9b5 100644
--- a/services/memory/service.py
+++ b/services/memory/service.py
@@ -7,6 +7,7 @@ import os
 
 from .memory import MemoryManager
 from .memory_vector import MemoryVectorStore
+from src.memory_provider import MemoryRecord, NativeMemoryProvider
 
 
 @dataclass
@@ -42,6 +43,10 @@ class MemoryService:
         self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
             os.path.join(data_dir, "memory_vectors")
         ) else None
+        self.provider = NativeMemoryProvider(self.manager, self.vector_store)
+
+    def _sync_provider(self) -> None:
+        self.provider.memory_vector = self.vector_store
 
     @staticmethod
     def _to_memory(entry: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> Memory:
@@ -53,6 +58,19 @@ class MemoryService:
             metadata=metadata or {},
         )
 
+    @staticmethod
+    def _record_to_memory(record: MemoryRecord, metadata: Optional[Dict[str, Any]] = None) -> Memory:
+        merged_metadata = dict(record.metadata)
+        if metadata:
+            merged_metadata.update(metadata)
+        return Memory(
+            id=record.id,
+            text=record.text,
+            timestamp=record.timestamp,
+            session_id=record.session_id,
+            metadata=merged_metadata,
+        )
+
     async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
         """
         Store a new memory.
@@ -64,19 +82,9 @@ class MemoryService:
         Returns:
             Created Memory object
         """
-        entry = self.manager.add_entry(text)
-        if session_id:
-            entry["session_id"] = session_id
-
-        memories = self.manager.load_all()
-        memories.append(entry)
-        self.manager.save(memories)
-
-        # Also add to vector store if available
-        if self.vector_store and self.vector_store.healthy:
-            self.vector_store.add(entry["id"], entry["text"])
-
-        return self._to_memory(entry)
+        self._sync_provider()
+        record = await self.provider.remember(text, session_id=session_id)
+        return self._record_to_memory(record)
 
     async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
         """
@@ -89,28 +97,20 @@ class MemoryService:
         Returns:
             MemorySearchResult with matching memories
         """
-        # Try vector search first
-        all_memories = self.manager.load_all()
-        by_id = {m.get("id"): m for m in all_memories}
-        if self.vector_store and self.vector_store.healthy:
-            results = self.vector_store.search(query, k=top_k)
-            found = []
-            for result in results:
-                entry = by_id.get(result.get("memory_id"))
-                if entry:
-                    found.append(self._to_memory(entry, metadata={"score": result.get("score")}))
-            if found:
-                return MemorySearchResult(memories=found, query=query, total=len(found))
-
-        # Fallback to keyword search
-        results = self.manager.get_relevant_memories(query, all_memories, max_items=top_k)
-        memories = [self._to_memory(m) for m in results]
+        self._sync_provider()
+        results = await self.provider.recall(query, top_k=top_k)
+        memories = [
+            self._record_to_memory(hit.memory, metadata={"score": hit.score})
+            if hit.score is not None
+            else self._record_to_memory(hit.memory)
+            for hit in results
+        ]
         return MemorySearchResult(memories=memories, query=query, total=len(memories))
 
     def get_all(self, limit: int = 100) -> List[Memory]:
         """Get all memories."""
-        memories = self.manager.load_all()[:limit]
-        return [self._to_memory(m) for m in memories]
+        records = self.manager.load_all()[:limit]
+        return [self._to_memory(m) for m in records]
 
     def delete(self, memory_id: str) -> bool:
         """Delete a memory by ID."""
diff --git a/src/app_initializer.py b/src/app_initializer.py
index 1cfa308..7d6b8c2 100644
--- a/src/app_initializer.py
+++ b/src/app_initializer.py
@@ -9,6 +9,7 @@ from src.constants import (
     SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
 )
 from src.memory import MemoryManager
+from src.memory_provider import MemoryProviderRegistry, NativeMemoryProvider
 from services.memory.skills import SkillsManager
 from core.session_manager import SessionManager
 from core.models import set_session_manager
@@ -73,6 +74,10 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
         logger.warning(f"MemoryVectorStore DEGRADED: {e}")
         memory_vector = None
 
+    memory_provider_registry = MemoryProviderRegistry([
+        NativeMemoryProvider(memory_manager, memory_vector),
+    ])
+
     # Initialize processors
     chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
     research_handler = ResearchHandler()
@@ -99,6 +104,7 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
     return {
         "memory_manager": memory_manager,
         "memory_vector": memory_vector,
+        "memory_provider_registry": memory_provider_registry,
         "skills_manager": skills_manager,
         "session_manager": session_manager,
         "upload_handler": upload_handler,
diff --git a/src/memory_provider.py b/src/memory_provider.py
new file mode 100644
--- /dev/null
+++ b/src/memory_provider.py
@@ -0,0 +1,357 @@
+"""Memory provider interfaces for native and external memory systems."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
+
+
+@dataclass
+class MemoryRecord:
+    """Provider-neutral memory entry."""
+
+    id: str
+    text: str
+    timestamp: int = 0
+    category: str = "fact"
+    source: str = "unknown"
+    owner: Optional[str] = None
+    session_id: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MemorySearchHit:
+    """A memory returned by provider recall."""
+
+    memory: MemoryRecord
+    provider_id: str
+    score: Optional[float] = None
+
+
+class MemoryProvider(ABC):
+    """Base contract for Odysseus memory providers.
+
+    The native memory provider should always be available. External providers
+    can add recall/write behavior and their own tools without replacing the
+    built-in local memory baseline.
+    """
+
+    provider_id = "unknown"
+    display_name = "Unknown"
+    enabled = True
+
+    async def initialize(self) -> None:
+        """Prepare provider resources before use."""
+
+    async def shutdown(self) -> None:
+        """Release provider resources."""
+
+    @abstractmethod
+    async def remember(
+        self,
+        text: str,
+        *,
+        owner: Optional[str] = None,
+        session_id: Optional[str] = None,
+        category: str = "fact",
+        source: str = "user",
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> MemoryRecord:
+        """Store a memory and return the stored record."""
+
+    @abstractmethod
+    async def recall(
+        self,
+        query: str,
+        *,
+        owner: Optional[str] = None,
+        top_k: int = 5,
+    ) -> List[MemorySearchHit]:
+        """Return provider memories relevant to the query."""
+
+    @abstractmethod
+    async def list_memories(
+        self,
+        *,
+        owner: Optional[str] = None,
+        limit: int = 100,
+    ) -> List[MemoryRecord]:
+        """List memories visible to the owner."""
+
+    @abstractmethod
+    async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
+        """Delete a memory by ID when allowed by the provider."""
+
+    def get_tool_schemas(self) -> List[Dict[str, Any]]:
+        """Return provider-defined tool schemas when this provider is enabled."""
+        return []
+
+    async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
+        """Handle a provider-defined tool call."""
+        raise KeyError(f"Provider {self.provider_id} does not expose tool {name}")
+
+
+class NativeMemoryProvider(MemoryProvider):
+    """Provider adapter for Odysseus' built-in memory manager and vector store."""
+
+    provider_id = "native"
+    display_name = "Odysseus native memory"
+
+    # Entry keys that ``_to_record`` does not copy into ``MemoryRecord.metadata``.
+    _CORE_FIELDS = {
+        "id",
+        "text",
+        "timestamp",
+        "source",
+        "category",
+        "uses",
+        "owner",
+        "session_id",
+        "metadata",
+    }
+
+    def __init__(self, memory_manager, memory_vector=None):
+        self.memory_manager = memory_manager
+        self.memory_vector = memory_vector
+
+    def _to_record(self, entry: Dict[str, Any]) -> MemoryRecord:
+        """Convert a stored entry dict into a provider-neutral record.
+
+        Keys outside ``_CORE_FIELDS`` are copied into the record metadata; a
+        dict stored under ``metadata`` is merged last and wins on conflicts.
+        """
+        metadata = {
+            key: value
+            for key, value in entry.items()
+            if key not in self._CORE_FIELDS
+        }
+        stored_metadata = entry.get("metadata")
+        if isinstance(stored_metadata, dict):
+            metadata.update(stored_metadata)
+
+        return MemoryRecord(
+            id=entry.get("id", ""),
+            text=entry.get("text", ""),
+            timestamp=entry.get("timestamp", 0),
+            category=entry.get("category", "fact"),
+            source=entry.get("source", "unknown"),
+            owner=entry.get("owner"),
+            session_id=entry.get("session_id"),
+            metadata=metadata,
+        )
+
+    async def remember(
+        self,
+        text: str,
+        *,
+        owner: Optional[str] = None,
+        session_id: Optional[str] = None,
+        category: str = "fact",
+        source: str = "user",
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> MemoryRecord:
+        """Create an entry, persist it, and index it when vectors are available."""
+        entry = self.memory_manager.add_entry(
+            text,
+            source=source,
+            category=category,
+            owner=owner,
+        )
+        if session_id:
+            entry["session_id"] = session_id
+        if metadata:
+            entry["metadata"] = dict(metadata)
+
+        memories = self.memory_manager.load_all()
+        memories.append(entry)
+        self.memory_manager.save(memories)
+
+        if self._vector_available():
+            self.memory_vector.add(entry["id"], entry["text"])
+
+        return self._to_record(entry)
+
+    async def recall(
+        self,
+        query: str,
+        *,
+        owner: Optional[str] = None,
+        top_k: int = 5,
+    ) -> List[MemorySearchHit]:
+        """Search the vector store first, then fall back to keyword matching.
+
+        Vector rows that are not dicts are skipped. A row with a ``memory_id``
+        is resolved against the memories loaded for ``owner``; a row without
+        one is treated as the memory entry itself. Keyword fallback hits carry
+        ``score=None``.
+        """
+        memories = self.memory_manager.load(owner=owner)
+        by_id = {m.get("id"): m for m in memories}
+
+        if self._vector_available():
+            hits: List[MemorySearchHit] = []
+            for result in self.memory_vector.search(query, k=top_k):
+                if not isinstance(result, dict):
+                    continue
+                memory_id = result.get("memory_id")
+                entry = by_id.get(memory_id) if memory_id else result
+                if not entry:
+                    continue
+                if owner is not None and entry.get("owner") != owner:
+                    continue
+                hits.append(
+                    MemorySearchHit(
+                        memory=self._to_record(entry),
+                        provider_id=self.provider_id,
+                        score=result.get("score"),
+                    )
+                )
+            if hits:
+                return hits
+
+        fallback = self.memory_manager.get_relevant_memories(
+            query,
+            memories,
+            max_items=top_k,
+        )
+        return [
+            MemorySearchHit(
+                memory=self._to_record(entry),
+                provider_id=self.provider_id,
+                score=None,
+            )
+            for entry in fallback
+        ]
+
+    async def list_memories(
+        self,
+        *,
+        owner: Optional[str] = None,
+        limit: int = 100,
+    ) -> List[MemoryRecord]:
+        """Return at most ``limit`` records from the memories loaded for ``owner``."""
+        return [
+            self._to_record(entry)
+            for entry in self.memory_manager.load(owner=owner)[:limit]
+        ]
+
+    async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
+        """Delete the entry with ``memory_id``; return False when nothing matched.
+
+        When ``owner`` is given, an entry owned by someone else is kept.
+        """
+        memories = self.memory_manager.load_all()
+        remaining = []
+        deleted_id = None
+
+        for entry in memories:
+            if entry.get("id") != memory_id:
+                remaining.append(entry)
+                continue
+            if owner is not None and entry.get("owner") != owner:
+                remaining.append(entry)
+                continue
+            deleted_id = entry.get("id")
+
+        if deleted_id is None:
+            return False
+
+        self.memory_manager.save(remaining)
+        if self._vector_available():
+            self.memory_vector.remove(deleted_id)
+        return True
+
+    def _vector_available(self) -> bool:
+        """Return True when a vector store is attached and not marked unhealthy."""
+        return bool(self.memory_vector and getattr(self.memory_vector, "healthy", True))
+
+
+class MemoryProviderRegistry:
+    """Container for native and optional external memory providers."""
+
+    def __init__(self, providers: Optional[Iterable[MemoryProvider]] = None):
+        self._providers: Dict[str, MemoryProvider] = {}
+        for provider in providers or []:
+            self.register(provider)
+
+    def register(self, provider: MemoryProvider) -> None:
+        """Add a provider; raise ValueError if its ``provider_id`` is taken."""
+        if provider.provider_id in self._providers:
+            raise ValueError(f"Memory provider already registered: {provider.provider_id}")
+        self._providers[provider.provider_id] = provider
+
+    def get(self, provider_id: str) -> MemoryProvider:
+        """Return the provider registered under ``provider_id`` (KeyError if none)."""
+        return self._providers[provider_id]
+
+    def all(self) -> List[MemoryProvider]:
+        """Return every registered provider, enabled or not."""
+        return list(self._providers.values())
+
+    def active(self) -> List[MemoryProvider]:
+        """Return the registered providers whose ``enabled`` flag is truthy."""
+        return [provider for provider in self._providers.values() if provider.enabled]
+
+    def get_tool_schemas(self) -> List[Dict[str, Any]]:
+        """Return the tool schemas exposed by all active providers.
+
+        Raises:
+            ValueError: If a tool name is exposed twice or a schema has no usable name.
+        """
+        return [schema for _, schema, _ in self._iter_active_tools()]
+
+    async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
+        """Dispatch a tool call to the active provider that exposes ``name``.
+
+        Raises:
+            ValueError: If a tool name is exposed twice or a schema has no usable name.
+            KeyError: If no active provider exposes ``name``.
+        """
+        # Build the whole mapping first so a name conflict is always reported,
+        # even when the requested tool is listed before the conflicting one.
+        provider_by_tool = {
+            tool_name: provider
+            for tool_name, _, provider in self._iter_active_tools()
+        }
+        provider = provider_by_tool.get(name)
+        # Compare with None: a provider class may define its own truthiness.
+        if provider is not None:
+            return await provider.handle_tool_call(name, arguments)
+        raise KeyError(f"No active memory provider exposes tool {name}")
+
+    def _iter_active_tools(
+        self,
+    ) -> Iterator[Tuple[str, Dict[str, Any], MemoryProvider]]:
+        """Yield ``(tool_name, schema, provider)`` for each active provider tool.
+
+        Raises:
+            ValueError: If a tool name is exposed twice or a schema has no usable name.
+        """
+        owners: Dict[str, str] = {}
+        for provider in self.active():
+            for schema in provider.get_tool_schemas():
+                name = self._tool_name(schema)
+                if name in owners:
+                    raise ValueError(
+                        f"Memory tool name conflict: {name} from "
+                        f"{provider.provider_id} already exposed by {owners[name]}"
+                    )
+                owners[name] = provider.provider_id
+                yield name, schema, provider
+
+    @staticmethod
+    def _tool_name(schema: Dict[str, Any]) -> str:
+        """Return the tool name from a flat or ``{"function": {...}}`` schema."""
+        if not isinstance(schema, dict):
+            raise ValueError("Memory provider tool schema must be a dict")
+        name = schema.get("name")
+        if isinstance(name, str) and name:
+            return name
+        function = schema.get("function")
+        if isinstance(function, dict):
+            function_name = function.get("name")
+            if isinstance(function_name, str) and function_name:
+                return function_name
+        raise ValueError("Memory provider tool schema is missing a tool name")
diff --git a/tests/test_memory_provider.py b/tests/test_memory_provider.py
new file mode 100644
--- /dev/null
+++ b/tests/test_memory_provider.py
@@ -0,0 +1,223 @@
+"""Tests for the memory provider interface and native adapter."""
+
+import asyncio
+
+
+class FakeVectorStore:
+    healthy = True
+
+    def __init__(self):
+        self.added = []
+        self.removed = []
+        self.results = []
+
+    def add(self, memory_id, text):
+        self.added.append((memory_id, text))
+
+    def remove(self, memory_id):
+        self.removed.append(memory_id)
+
+    def search(self, query, k=5):
+        return self.results[:k]
+
+
+def run(coro):
+    return asyncio.run(coro)
+
+
+def test_native_provider_remember_writes_native_memory_and_vector(tmp_path):
+    from src.memory import MemoryManager
+    from src.memory_provider import NativeMemoryProvider
+
+    manager = MemoryManager(str(tmp_path))
+    vector = FakeVectorStore()
+    provider = NativeMemoryProvider(manager, vector)
+
+    record = run(provider.remember(
+        "User prefers concise responses",
+        owner="alice",
+        session_id="session-1",
+        category="preference",
+        metadata={"confidence": 0.9},
+    ))
+
+    stored = manager.load(owner="alice")
+    assert len(stored) == 1
+    assert stored[0]["id"] == record.id
+    assert stored[0]["text"] == "User prefers concise responses"
+    assert stored[0]["category"] == "preference"
+    assert stored[0]["session_id"] == "session-1"
+    assert record.metadata["confidence"] == 0.9
+    assert vector.added == [(record.id, "User prefers concise responses")]
+
+
+def test_native_provider_recall_filters_vector_hits_by_owner(tmp_path):
+    from src.memory import MemoryManager
+    from src.memory_provider import NativeMemoryProvider
+
+    manager = MemoryManager(str(tmp_path))
+    vector = FakeVectorStore()
+    provider = NativeMemoryProvider(manager, vector)
+
+    alice = run(provider.remember("Alice likes green tea", owner="alice"))
+    bob = run(provider.remember("Bob likes espresso", owner="bob"))
+    vector.results = [
+        {"memory_id": bob.id, "score": 0.99},
+        {"memory_id": alice.id, "score": 0.75},
+    ]
+
+    hits = run(provider.recall("what does Alice like?", owner="alice", top_k=5))
+
+    assert [hit.memory.id for hit in hits] == [alice.id]
+    assert hits[0].provider_id == "native"
+    assert hits[0].score == 0.75
+
+
+def test_native_provider_recall_accepts_legacy_vector_rows(tmp_path):
+    from src.memory import MemoryManager
+    from src.memory_provider import NativeMemoryProvider
+
+    manager = MemoryManager(str(tmp_path))
+    vector = FakeVectorStore()
+    provider = NativeMemoryProvider(manager, vector)
+
+    vector.results = [
+        {"id": "legacy-1", "text": "real memory", "timestamp": 5},
+        "corrupt-row",
+        None,
+    ]
+
+    hits = run(provider.recall("anything", top_k=5))
+
+    assert [hit.memory.id for hit in hits] == ["legacy-1"]
+    assert hits[0].memory.text == "real memory"
+
+
+def test_native_provider_recall_falls_back_to_keyword_search(tmp_path):
+    from src.memory import MemoryManager
+    from src.memory_provider import NativeMemoryProvider
+
+    manager = MemoryManager(str(tmp_path))
+    provider = NativeMemoryProvider(manager)
+    saved = run(provider.remember(
+        "Alice prefers markdown notes",
+        owner="alice",
+        category="preference",
+    ))
+
+    hits = run(provider.recall("markdown preference", owner="alice", top_k=3))
+
+    assert [hit.memory.id for hit in hits] == [saved.id]
+    assert hits[0].score is None
+
+
+def test_memory_provider_registry_exposes_only_active_provider_tools():
+    from src.memory_provider import MemoryProvider, MemoryProviderRegistry
+
+    class DummyProvider(MemoryProvider):
+        def __init__(self, provider_id, enabled=True):
+            self.provider_id = provider_id
+            self.display_name = provider_id
+            self.enabled = enabled
+
+        async def remember(self, text, **kwargs):
+            raise NotImplementedError
+
+        async def recall(self, query, **kwargs):
+            return []
+
+        async def list_memories(self, **kwargs):
+            return []
+
+        async def delete(self, memory_id, **kwargs):
+            return False
+
+        def get_tool_schemas(self):
+            return [{"name": f"{self.provider_id}_search", "description": "Search memory"}]
+
+    registry = MemoryProviderRegistry([
+        DummyProvider("active"),
+        DummyProvider("disabled", enabled=False),
+    ])
+
+    assert registry.get_tool_schemas() == [
+        {"name": "active_search", "description": "Search memory"}
+    ]
+
+
+def test_memory_provider_registry_rejects_tool_name_conflicts():
+    from src.memory_provider import MemoryProvider, MemoryProviderRegistry
+
+    class ConflictingProvider(MemoryProvider):
+        def __init__(self, provider_id):
+            self.provider_id = provider_id
+            self.display_name = provider_id
+
+        async def remember(self, text, **kwargs):
+            raise NotImplementedError
+
+        async def recall(self, query, **kwargs):
+            return []
+
+        async def list_memories(self, **kwargs):
+            return []
+
+        async def delete(self, memory_id, **kwargs):
+            return False
+
+        def get_tool_schemas(self):
+            return [{"name": "memory_search"}]
+
+    registry = MemoryProviderRegistry([
+        ConflictingProvider("first"),
+        ConflictingProvider("second"),
+    ])
+
+    try:
+        registry.get_tool_schemas()
+    except ValueError as exc:
+        assert "memory_search" in str(exc)
+    else:
+        raise AssertionError("Expected duplicate memory tool names to be rejected")
+
+
+def test_memory_provider_registry_routes_tool_calls_to_exposing_provider():
+    from src.memory_provider import MemoryProvider, MemoryProviderRegistry
+
+    class EchoProvider(MemoryProvider):
+        def __init__(self, provider_id):
+            self.provider_id = provider_id
+            self.display_name = provider_id
+
+        async def remember(self, text, **kwargs):
+            raise NotImplementedError
+
+        async def recall(self, query, **kwargs):
+            return []
+
+        async def list_memories(self, **kwargs):
+            return []
+
+        async def delete(self, memory_id, **kwargs):
+            return False
+
+        def get_tool_schemas(self):
+            return [{"function": {"name": f"{self.provider_id}_echo"}}]
+
+        async def handle_tool_call(self, name, arguments):
+            return (self.provider_id, name, arguments)
+
+    registry = MemoryProviderRegistry([EchoProvider("first"), EchoProvider("second")])
+
+    assert run(registry.handle_tool_call("second_echo", {"q": 1})) == (
+        "second",
+        "second_echo",
+        {"q": 1},
+    )
+
+    try:
+        run(registry.handle_tool_call("missing_tool", {}))
+    except KeyError as exc:
+        assert "missing_tool" in str(exc)
+    else:
+        raise AssertionError("Expected unknown memory tool to raise KeyError")