refactor(memory): canonicalize memory imports (#50)
This commit is contained in:
@@ -1,364 +1,10 @@
|
|||||||
|
"""Compatibility import for the canonical memory manager.
|
||||||
|
|
||||||
import json
|
Historically this package carried a second copy of ``MemoryManager``. The
|
||||||
import logging
|
application runtime instantiates ``src.memory.MemoryManager``, so keeping a
|
||||||
import os
|
parallel implementation here risks silent drift between import paths.
|
||||||
import time
|
"""
|
||||||
import uuid
|
|
||||||
import re
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from src.memory import MemoryManager, get_text_similarity, tokenize
|
||||||
|
|
||||||
def tokenize(text: str) -> List[str]:
|
__all__ = ["MemoryManager", "get_text_similarity", "tokenize"]
|
||||||
"""Simple tokenizer that splits on whitespace and removes punctuation."""
|
|
||||||
return [cleaned for word in text.split() if (cleaned := word.strip('.,!?";'))]
|
|
||||||
|
|
||||||
def get_text_similarity(text1: str, text2: str) -> float:
|
|
||||||
"""Calculate Jaccard similarity between two texts."""
|
|
||||||
if not text1 or not text2:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
tokens1 = set(tokenize(text1.lower()))
|
|
||||||
tokens2 = set(tokenize(text2.lower()))
|
|
||||||
|
|
||||||
if not tokens1 and not tokens2:
|
|
||||||
return 1.0
|
|
||||||
if not tokens1 or not tokens2:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
intersection = tokens1.intersection(tokens2)
|
|
||||||
union = tokens1.union(tokens2)
|
|
||||||
|
|
||||||
return len(intersection) / len(union)
|
|
||||||
|
|
||||||
class MemoryManager:
|
|
||||||
def __init__(self, data_dir: str):
|
|
||||||
self.memory_file = os.path.join(data_dir, "memory.json")
|
|
||||||
self.ensure_file_exists()
|
|
||||||
|
|
||||||
def extract_memory_from_chat(self, chat_history: List[Dict], session_id: str = None) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Extract memory entries from chat history as a fallback when LLM fails.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_history: List of chat messages with 'role' and 'content' keys
|
|
||||||
session_id: Optional session ID to associate with extracted memories
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of memory entries with text, timestamp, and optional session_id
|
|
||||||
"""
|
|
||||||
memories = []
|
|
||||||
|
|
||||||
for msg in chat_history:
|
|
||||||
if msg.get("role") == "assistant":
|
|
||||||
content = str(msg.get("content", ""))
|
|
||||||
lines = content.split('\n')
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
line = line.strip()
|
|
||||||
# Look for bullet points or numbered lists that might contain memories
|
|
||||||
if re.match(r'^[-*•]|\d+\.', line):
|
|
||||||
# Extract the text after the bullet/number. Group both
|
|
||||||
# markers so the capture applies to either. The previous
|
|
||||||
# `^[-*•]|\d+\.\s*(.*)` put the group on the numbered
|
|
||||||
# branch only, so a bullet line matched with group(1)=None
|
|
||||||
# and crashed on .strip().
|
|
||||||
text_match = re.match(r'^(?:[-*•]|\d+\.)\s*(.*)', line)
|
|
||||||
if text_match:
|
|
||||||
text = text_match.group(1).strip()
|
|
||||||
if text:
|
|
||||||
memories.append({
|
|
||||||
"text": text,
|
|
||||||
"timestamp": int(time.time()),
|
|
||||||
"session_id": session_id
|
|
||||||
})
|
|
||||||
# If we see a heading that suggests memories
|
|
||||||
elif re.search(r'memory|fact|note|remember', line, re.I):
|
|
||||||
pass
|
|
||||||
# If we see a clear separator or end
|
|
||||||
elif re.match(r'^={3,}|-{3,}|_{3,}', line):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return memories
|
|
||||||
|
|
||||||
def process_inline_memory_command(self, message: str) -> Tuple[bool, str]:
|
|
||||||
"""
|
|
||||||
Check if a message is an inline memory command (e.g. "remember: X").
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The user message to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_command, extracted_text) where is_command is True if
|
|
||||||
the message matches the memory command pattern
|
|
||||||
"""
|
|
||||||
# Pattern for memory commands: "remember: X", "memorize: X", "save: X", etc.
|
|
||||||
pattern = r'^(?:remember|memorize|save|note|store)[:\-]?\s+(.+)$'
|
|
||||||
match = re.match(pattern, message.strip(), re.IGNORECASE)
|
|
||||||
|
|
||||||
if match:
|
|
||||||
return True, match.group(1).strip()
|
|
||||||
else:
|
|
||||||
return False, ""
|
|
||||||
|
|
||||||
def ensure_file_exists(self):
|
|
||||||
"""Create memory file if it doesn't exist."""
|
|
||||||
if not os.path.exists(self.memory_file):
|
|
||||||
os.makedirs(os.path.dirname(self.memory_file), exist_ok=True)
|
|
||||||
with open(self.memory_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump([], f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
def load_all(self) -> List[Dict]:
|
|
||||||
"""Load all memory entries from JSON file (unfiltered)."""
|
|
||||||
if not os.path.exists(self.memory_file):
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(self.memory_file, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if isinstance(data, list):
|
|
||||||
return self._validate_entries(data)
|
|
||||||
except (json.JSONDecodeError, PermissionError) as e:
|
|
||||||
logger.error("Error loading memory.json: %s", e)
|
|
||||||
return self._migrate_from_legacy()
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
def load(self, owner: str = None) -> List[Dict]:
|
|
||||||
"""Load memory entries, filtered by owner."""
|
|
||||||
entries = self.load_all()
|
|
||||||
if owner is None:
|
|
||||||
return entries
|
|
||||||
return [e for e in entries if e.get("owner") == owner]
|
|
||||||
|
|
||||||
def claim_ownerless(self, owner: str):
|
|
||||||
"""Assign all ownerless memory entries to the given owner. Run once to migrate."""
|
|
||||||
entries = self.load_all()
|
|
||||||
changed = False
|
|
||||||
for e in entries:
|
|
||||||
if not e.get("owner"):
|
|
||||||
e["owner"] = owner
|
|
||||||
changed = True
|
|
||||||
if changed:
|
|
||||||
self.save(entries)
|
|
||||||
logger.info("Claimed %d ownerless memories for %s", sum(1 for e in entries if e.get("owner") == owner), owner)
|
|
||||||
|
|
||||||
def _validate_entries(self, entries: List[Dict]) -> List[Dict]:
|
|
||||||
"""Ensure all entries have required fields."""
|
|
||||||
validated = []
|
|
||||||
for entry in entries:
|
|
||||||
if "id" not in entry:
|
|
||||||
entry["id"] = str(uuid.uuid4())
|
|
||||||
if "timestamp" not in entry:
|
|
||||||
entry["timestamp"] = int(time.time())
|
|
||||||
if "source" not in entry:
|
|
||||||
entry["source"] = "unknown"
|
|
||||||
if "category" not in entry:
|
|
||||||
entry["category"] = "fact"
|
|
||||||
validated.append(entry)
|
|
||||||
return validated
|
|
||||||
|
|
||||||
def _migrate_from_legacy(self) -> List[Dict]:
|
|
||||||
"""Migrate from old text format to JSON if needed."""
|
|
||||||
legacy_path = os.path.join(os.path.dirname(self.memory_file), "memory.txt")
|
|
||||||
if not os.path.exists(legacy_path):
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info("Converting legacy memory.txt to new JSON format")
|
|
||||||
try:
|
|
||||||
with open(legacy_path, "r", encoding="utf-8") as f:
|
|
||||||
lines = [ln.strip() for ln in f.readlines() if ln.strip()]
|
|
||||||
|
|
||||||
entries = []
|
|
||||||
for line in lines:
|
|
||||||
entries.append({
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"text": line,
|
|
||||||
"timestamp": int(time.time()),
|
|
||||||
"source": "user",
|
|
||||||
"category": "fact"
|
|
||||||
})
|
|
||||||
|
|
||||||
self.save(entries)
|
|
||||||
return entries
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to convert legacy memory: %s", e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
def save(self, entries: List[Dict]):
|
|
||||||
"""Save memory entries to JSON file."""
|
|
||||||
# Validate entries before saving
|
|
||||||
for entry in entries:
|
|
||||||
if "id" not in entry:
|
|
||||||
entry["id"] = str(uuid.uuid4())
|
|
||||||
if "timestamp" not in entry:
|
|
||||||
entry["timestamp"] = int(time.time())
|
|
||||||
if "source" not in entry:
|
|
||||||
entry["source"] = "user"
|
|
||||||
if "category" not in entry:
|
|
||||||
entry["category"] = "fact"
|
|
||||||
|
|
||||||
# Use atomic write
|
|
||||||
tmp_file = self.memory_file + ".tmp"
|
|
||||||
with open(tmp_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(entries, f, ensure_ascii=False, indent=2)
|
|
||||||
os.replace(tmp_file, self.memory_file)
|
|
||||||
|
|
||||||
def add_entry(self, text: str, source: str = "user", category: str = "fact", owner: str = None) -> Dict:
|
|
||||||
"""Add a new memory entry."""
|
|
||||||
if not text.strip():
|
|
||||||
raise ValueError("Memory text cannot be empty")
|
|
||||||
|
|
||||||
entry = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"text": text.strip(),
|
|
||||||
"timestamp": int(time.time()),
|
|
||||||
"source": source,
|
|
||||||
"category": category
|
|
||||||
}
|
|
||||||
if owner:
|
|
||||||
entry["owner"] = owner
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def find_duplicates(self, text: str, entries: List[Dict] = None) -> List[Dict]:
|
|
||||||
"""Find duplicate memory entries based on text content."""
|
|
||||||
if entries is None:
|
|
||||||
entries = self.load()
|
|
||||||
|
|
||||||
text_lower = text.strip().lower()
|
|
||||||
return [entry for entry in entries if entry["text"].lower() == text_lower]
|
|
||||||
|
|
||||||
def categorize_memory_by_relevance(self, message: str, memories: list):
|
|
||||||
"""Categorize memories by type and relevance"""
|
|
||||||
categories = {
|
|
||||||
"contacts": [],
|
|
||||||
"preferences": [],
|
|
||||||
"facts": [],
|
|
||||||
"tasks": []
|
|
||||||
}
|
|
||||||
|
|
||||||
msg_lower = message.lower()
|
|
||||||
|
|
||||||
for mem in memories:
|
|
||||||
text_lower = mem["text"].lower()
|
|
||||||
|
|
||||||
# Contact info
|
|
||||||
if any(word in text_lower for word in ["phone", "email", "address", "lives", "works"]):
|
|
||||||
if any(word in msg_lower for word in ["contact", "phone", "address", "email"]):
|
|
||||||
categories["contacts"].append(mem)
|
|
||||||
|
|
||||||
# Personal preferences
|
|
||||||
elif any(word in text_lower for word in ["likes", "dislikes", "prefers", "favorite"]):
|
|
||||||
if any(word in msg_lower for word in ["like", "prefer", "favorite", "want"]):
|
|
||||||
categories["preferences"].append(mem)
|
|
||||||
|
|
||||||
# Tasks and todos
|
|
||||||
elif any(word in text_lower for word in ["todo", "task", "remind", "meeting"]):
|
|
||||||
if any(word in msg_lower for word in ["todo", "task", "schedule", "remind"]):
|
|
||||||
categories["tasks"].append(mem)
|
|
||||||
|
|
||||||
# General facts - only if very relevant
|
|
||||||
else:
|
|
||||||
if get_text_similarity(message, mem["text"]) > 0.4:
|
|
||||||
categories["facts"].append(mem)
|
|
||||||
|
|
||||||
return categories
|
|
||||||
|
|
||||||
def get_relevant_memories(self, query: str, memories: list, threshold: float = 0.05, max_items: int = 8):
|
|
||||||
"""Get memories that are relevant to the query based on text similarity and semantic keyword matching."""
|
|
||||||
if not memories or not query.strip():
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Define keyword categories for semantic matching
|
|
||||||
identity_words = ["name", "who", "i", "am", "called", "identity", "myself", "me", "my"]
|
|
||||||
contact_words = ["phone", "email", "address", "contact", "number", "where", "located", "reach"]
|
|
||||||
preference_words = ["like", "prefer", "favorite", "want", "love", "hate", "dislike", "enjoy", "interested"]
|
|
||||||
task_words = ["todo", "task", "remind", "meeting", "appointment", "schedule", "deadline"]
|
|
||||||
fact_words = ["what", "when", "where", "how", "why", "explain", "describe", "information", "know"]
|
|
||||||
|
|
||||||
query_lower = query.lower()
|
|
||||||
|
|
||||||
# Determine query type based on keywords
|
|
||||||
query_type = None
|
|
||||||
if any(word in query_lower for word in identity_words):
|
|
||||||
query_type = "identity"
|
|
||||||
elif any(word in query_lower for word in contact_words):
|
|
||||||
query_type = "contact"
|
|
||||||
elif any(word in query_lower for word in preference_words):
|
|
||||||
query_type = "preference"
|
|
||||||
elif any(word in query_lower for word in task_words):
|
|
||||||
query_type = "task"
|
|
||||||
elif any(word in query_lower for word in fact_words):
|
|
||||||
query_type = "fact"
|
|
||||||
|
|
||||||
relevant = []
|
|
||||||
identity_memories = []
|
|
||||||
other_memories = []
|
|
||||||
|
|
||||||
# Separate identity memories from others
|
|
||||||
for memory in memories:
|
|
||||||
memory_text = memory["text"].lower()
|
|
||||||
# Check if this is an identity memory (contains name patterns or identity indicators)
|
|
||||||
is_identity = any([
|
|
||||||
re.search(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', memory["text"]),
|
|
||||||
any(word in memory_text for word in ["name is", "i'm", "i am", "called", "my name", "named", "call me"])
|
|
||||||
])
|
|
||||||
if is_identity:
|
|
||||||
identity_memories.append(memory)
|
|
||||||
else:
|
|
||||||
other_memories.append(memory)
|
|
||||||
|
|
||||||
# For identity queries, include all identity memories regardless of similarity
|
|
||||||
if query_type == "identity" and identity_memories:
|
|
||||||
# Give them high scores to ensure they're included first
|
|
||||||
for memory in identity_memories:
|
|
||||||
relevant.append((0.9, memory)) # High score for identity memories in identity queries
|
|
||||||
|
|
||||||
# Process other memories with similarity scoring
|
|
||||||
for memory in other_memories:
|
|
||||||
memory_text = memory["text"].lower()
|
|
||||||
memory_tokens = set(tokenize(memory_text))
|
|
||||||
query_tokens = set(tokenize(query_lower))
|
|
||||||
|
|
||||||
# Calculate base Jaccard similarity
|
|
||||||
if not query_tokens or not memory_tokens:
|
|
||||||
continue
|
|
||||||
|
|
||||||
base_similarity = len(query_tokens & memory_tokens) / len(query_tokens | memory_tokens)
|
|
||||||
final_score = base_similarity
|
|
||||||
|
|
||||||
# Apply boosts based on semantic matching
|
|
||||||
if query_type == "contact":
|
|
||||||
# Boost memories with contact information
|
|
||||||
has_contact_info = any(word in memory_text for word in ["@gmail.com", "@", ".com",
|
|
||||||
"phone", "number", "address",
|
|
||||||
"http", "www", "tel:"])
|
|
||||||
if has_contact_info:
|
|
||||||
final_score *= 1.4 # 40% boost for contact-related memories
|
|
||||||
|
|
||||||
elif query_type == "preference":
|
|
||||||
# Boost memories with preference indicators
|
|
||||||
has_preference = any(word in memory_text for word in ["like", "love", "hate", "dislike",
|
|
||||||
"prefer", "favorite", "enjoy", "interested"])
|
|
||||||
if has_preference:
|
|
||||||
final_score *= 1.3 # 30% boost for preference-related memories
|
|
||||||
|
|
||||||
elif query_type == "task":
|
|
||||||
# Boost memories with task indicators
|
|
||||||
has_task = any(word in memory_text for word in ["todo", "task", "remind", "meeting",
|
|
||||||
"appointment", "schedule", "deadline", "need to"])
|
|
||||||
if has_task:
|
|
||||||
final_score *= 1.3 # 30% boost for task-related memories
|
|
||||||
|
|
||||||
# Always consider exact phrase matches as highly relevant
|
|
||||||
if query.lower() in memory["text"].lower():
|
|
||||||
final_score = max(final_score, 0.8) # Ensure high relevance for exact matches
|
|
||||||
|
|
||||||
# Include memory if it meets threshold after boosts
|
|
||||||
if final_score >= threshold:
|
|
||||||
relevant.append((final_score, memory))
|
|
||||||
|
|
||||||
# Sort by final score (descending) and return top matches
|
|
||||||
relevant.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
return [mem for _, mem in relevant[:max_items]]
|
|
||||||
|
|||||||
@@ -1,175 +1,5 @@
|
|||||||
"""
|
"""Compatibility import for the canonical memory vector store."""
|
||||||
memory_vector.py
|
|
||||||
|
|
||||||
ChromaDB-backed vector store for memory entries.
|
from src.memory_vector import MemoryVectorStore
|
||||||
Shares the EmbeddingClient with RAG to save memory.
|
|
||||||
Stores pre-computed embeddings (ChromaDB does not manage embedding).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
__all__ = ["MemoryVectorStore"]
|
||||||
from typing import List, Dict, Optional
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryVectorStore:
|
|
||||||
"""Vector index over memory entries for semantic retrieval."""
|
|
||||||
|
|
||||||
COLLECTION_NAME = "odysseus_memories"
|
|
||||||
|
|
||||||
def __init__(self, data_dir: str, embedding_model=None):
|
|
||||||
self._model = embedding_model
|
|
||||||
self._collection = None
|
|
||||||
self._healthy = False
|
|
||||||
|
|
||||||
self._initialize()
|
|
||||||
|
|
||||||
def _initialize(self):
|
|
||||||
try:
|
|
||||||
from src.chroma_client import get_chroma_client
|
|
||||||
|
|
||||||
if self._model is None:
|
|
||||||
from src.embeddings import get_embedding_client
|
|
||||||
self._model = get_embedding_client()
|
|
||||||
if self._model is None:
|
|
||||||
raise RuntimeError("No embedding backend available")
|
|
||||||
logger.info(f"MemoryVectorStore using embeddings: {self._model.url}")
|
|
||||||
|
|
||||||
client = get_chroma_client()
|
|
||||||
self._collection = client.get_or_create_collection(
|
|
||||||
name=self.COLLECTION_NAME,
|
|
||||||
metadata={"hnsw:space": "cosine"},
|
|
||||||
)
|
|
||||||
|
|
||||||
self._healthy = True
|
|
||||||
count = self._collection.count()
|
|
||||||
logger.info(f"MemoryVectorStore ready (entries={count})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"MemoryVectorStore init failed: {e}")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def healthy(self) -> bool:
|
|
||||||
return self._healthy
|
|
||||||
|
|
||||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
vecs = self._model.encode(texts, normalize_embeddings=True)
|
|
||||||
return vecs.tolist()
|
|
||||||
|
|
||||||
def count(self) -> int:
|
|
||||||
"""Return the number of stored vectors."""
|
|
||||||
if not self._healthy:
|
|
||||||
return 0
|
|
||||||
return self._collection.count()
|
|
||||||
|
|
||||||
def add(self, memory_id: str, text: str):
|
|
||||||
"""Add a single memory entry to the vector index."""
|
|
||||||
if not self._healthy:
|
|
||||||
return
|
|
||||||
# Skip if already exists
|
|
||||||
existing = self._collection.get(ids=[memory_id])
|
|
||||||
if existing["ids"]:
|
|
||||||
return
|
|
||||||
embeddings = self._embed([text])
|
|
||||||
self._collection.add(
|
|
||||||
ids=[memory_id],
|
|
||||||
embeddings=embeddings,
|
|
||||||
documents=[text],
|
|
||||||
metadatas=[{"source": "memory"}],
|
|
||||||
)
|
|
||||||
|
|
||||||
def remove(self, memory_id: str):
|
|
||||||
"""Remove a memory entry. O(1) — no rebuild needed."""
|
|
||||||
if not self._healthy:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
self._collection.delete(ids=[memory_id])
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"memory remove {memory_id}: {e}")
|
|
||||||
|
|
||||||
def search(self, query: str, k: int = 8) -> List[Dict]:
|
|
||||||
"""Search for the most relevant memory IDs by semantic similarity.
|
|
||||||
Returns list of {"memory_id": str, "score": float}.
|
|
||||||
|
|
||||||
ChromaDB cosine distance = 1 - cosine_similarity.
|
|
||||||
We convert back: similarity = 1.0 - distance.
|
|
||||||
"""
|
|
||||||
if not self._healthy or self._collection.count() == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
embeddings = self._embed([query])
|
|
||||||
actual_k = min(k, self._collection.count())
|
|
||||||
results = self._collection.query(
|
|
||||||
query_embeddings=embeddings,
|
|
||||||
n_results=actual_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = []
|
|
||||||
for idx, mid in enumerate(results["ids"][0]):
|
|
||||||
distance = results["distances"][0][idx]
|
|
||||||
out.append({
|
|
||||||
"memory_id": mid,
|
|
||||||
"score": round(1.0 - distance, 4),
|
|
||||||
})
|
|
||||||
return out
|
|
||||||
|
|
||||||
def find_similar(self, text: str, threshold: float = 0.92) -> Optional[str]:
|
|
||||||
"""Check if a near-duplicate exists. Returns memory_id if found, else None."""
|
|
||||||
if not self._healthy or self._collection.count() == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
embeddings = self._embed([text])
|
|
||||||
results = self._collection.query(
|
|
||||||
query_embeddings=embeddings,
|
|
||||||
n_results=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if results["ids"][0]:
|
|
||||||
distance = results["distances"][0][0]
|
|
||||||
similarity = 1.0 - distance
|
|
||||||
if similarity >= threshold:
|
|
||||||
return results["ids"][0][0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def rebuild(self, memories: List[Dict]):
|
|
||||||
"""Rebuild the entire index from a list of memory entries.
|
|
||||||
Each entry must have 'id' and 'text' keys."""
|
|
||||||
if not self._healthy:
|
|
||||||
return
|
|
||||||
|
|
||||||
from src.chroma_client import get_chroma_client
|
|
||||||
|
|
||||||
# Delete and recreate collection for a clean rebuild
|
|
||||||
client = get_chroma_client()
|
|
||||||
try:
|
|
||||||
client.delete_collection(self.COLLECTION_NAME)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._collection = client.get_or_create_collection(
|
|
||||||
name=self.COLLECTION_NAME,
|
|
||||||
metadata={"hnsw:space": "cosine"},
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = []
|
|
||||||
ids = []
|
|
||||||
for mem in memories:
|
|
||||||
text = mem.get("text", "").strip()
|
|
||||||
mid = mem.get("id", "")
|
|
||||||
if text and mid:
|
|
||||||
texts.append(text)
|
|
||||||
ids.append(mid)
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
# Batch in chunks of 100 to avoid oversized requests
|
|
||||||
for i in range(0, len(texts), 100):
|
|
||||||
batch_texts = texts[i:i + 100]
|
|
||||||
batch_ids = ids[i:i + 100]
|
|
||||||
embeddings = self._embed(batch_texts)
|
|
||||||
self._collection.add(
|
|
||||||
ids=batch_ids,
|
|
||||||
embeddings=embeddings,
|
|
||||||
documents=batch_texts,
|
|
||||||
metadatas=[{"source": "memory"}] * len(batch_ids),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"MemoryVectorStore rebuilt with {len(ids)} entries")
|
|
||||||
|
|||||||
@@ -43,6 +43,16 @@ class MemoryService:
|
|||||||
os.path.join(data_dir, "memory_vectors")
|
os.path.join(data_dir, "memory_vectors")
|
||||||
) else None
|
) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_memory(entry: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> Memory:
|
||||||
|
return Memory(
|
||||||
|
id=entry.get("id", ""),
|
||||||
|
text=entry.get("text", ""),
|
||||||
|
timestamp=entry.get("timestamp", 0),
|
||||||
|
session_id=entry.get("session_id"),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
|
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
|
||||||
"""
|
"""
|
||||||
Store a new memory.
|
Store a new memory.
|
||||||
@@ -54,31 +64,19 @@ class MemoryService:
|
|||||||
Returns:
|
Returns:
|
||||||
Created Memory object
|
Created Memory object
|
||||||
"""
|
"""
|
||||||
import uuid
|
entry = self.manager.add_entry(text)
|
||||||
import time
|
if session_id:
|
||||||
|
entry["session_id"] = session_id
|
||||||
|
|
||||||
memory_id = str(uuid.uuid4())[:8]
|
memories = self.manager.load_all()
|
||||||
timestamp = int(time.time())
|
memories.append(entry)
|
||||||
|
self.manager.save(memories)
|
||||||
entry = {
|
|
||||||
"id": memory_id,
|
|
||||||
"text": text,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"session_id": session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.manager.add_memory(entry)
|
|
||||||
|
|
||||||
# Also add to vector store if available
|
# Also add to vector store if available
|
||||||
if self.vector_store:
|
if self.vector_store and self.vector_store.healthy:
|
||||||
self.vector_store.add(text, {"id": memory_id, "session_id": session_id})
|
self.vector_store.add(entry["id"], entry["text"])
|
||||||
|
|
||||||
return Memory(
|
return self._to_memory(entry)
|
||||||
id=memory_id,
|
|
||||||
text=text,
|
|
||||||
timestamp=timestamp,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
|
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
|
||||||
"""
|
"""
|
||||||
@@ -92,47 +90,36 @@ class MemoryService:
|
|||||||
MemorySearchResult with matching memories
|
MemorySearchResult with matching memories
|
||||||
"""
|
"""
|
||||||
# Try vector search first
|
# Try vector search first
|
||||||
if self.vector_store:
|
all_memories = self.manager.load_all()
|
||||||
|
by_id = {m.get("id"): m for m in all_memories}
|
||||||
|
if self.vector_store and self.vector_store.healthy:
|
||||||
results = self.vector_store.search(query, k=top_k)
|
results = self.vector_store.search(query, k=top_k)
|
||||||
memories = [
|
found = []
|
||||||
Memory(
|
for result in results:
|
||||||
id=r.get("id", ""),
|
entry = by_id.get(result.get("memory_id"))
|
||||||
text=r.get("text", ""),
|
if entry:
|
||||||
timestamp=r.get("timestamp", 0),
|
found.append(self._to_memory(entry, metadata={"score": result.get("score")}))
|
||||||
session_id=r.get("session_id"),
|
if found:
|
||||||
metadata=r.get("metadata", {}),
|
return MemorySearchResult(memories=found, query=query, total=len(found))
|
||||||
)
|
|
||||||
for r in results
|
|
||||||
if isinstance(r, dict)
|
|
||||||
]
|
|
||||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
|
||||||
|
|
||||||
# Fallback to keyword search
|
# Fallback to keyword search
|
||||||
results = self.manager.search_memories(query, limit=top_k)
|
results = self.manager.get_relevant_memories(query, all_memories, max_items=top_k)
|
||||||
memories = [
|
memories = [self._to_memory(m) for m in results]
|
||||||
Memory(
|
|
||||||
id=m.get("id", ""),
|
|
||||||
text=m.get("text", ""),
|
|
||||||
timestamp=m.get("timestamp", 0),
|
|
||||||
session_id=m.get("session_id"),
|
|
||||||
)
|
|
||||||
for m in results
|
|
||||||
]
|
|
||||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
||||||
|
|
||||||
def get_all(self, limit: int = 100) -> List[Memory]:
|
def get_all(self, limit: int = 100) -> List[Memory]:
|
||||||
"""Get all memories."""
|
"""Get all memories."""
|
||||||
memories = self.manager.get_memories(limit=limit)
|
memories = self.manager.load_all()[:limit]
|
||||||
return [
|
return [self._to_memory(m) for m in memories]
|
||||||
Memory(
|
|
||||||
id=m.get("id", ""),
|
|
||||||
text=m.get("text", ""),
|
|
||||||
timestamp=m.get("timestamp", 0),
|
|
||||||
session_id=m.get("session_id"),
|
|
||||||
)
|
|
||||||
for m in memories
|
|
||||||
]
|
|
||||||
|
|
||||||
def delete(self, memory_id: str) -> bool:
|
def delete(self, memory_id: str) -> bool:
|
||||||
"""Delete a memory by ID."""
|
"""Delete a memory by ID."""
|
||||||
return self.manager.delete_memory(memory_id)
|
memories = self.manager.load_all()
|
||||||
|
remaining = [m for m in memories if m.get("id") != memory_id]
|
||||||
|
if len(remaining) == len(memories):
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.manager.save(remaining)
|
||||||
|
if self.vector_store and self.vector_store.healthy:
|
||||||
|
self.vector_store.remove(memory_id)
|
||||||
|
return True
|
||||||
|
|||||||
@@ -132,6 +132,20 @@ class MemoryManager:
|
|||||||
if owner is None:
|
if owner is None:
|
||||||
return entries
|
return entries
|
||||||
return [e for e in entries if e.get("owner") == owner]
|
return [e for e in entries if e.get("owner") == owner]
|
||||||
|
|
||||||
|
def claim_ownerless(self, owner: str):
|
||||||
|
"""Assign all ownerless memory entries to the given owner."""
|
||||||
|
entries = self.load_all()
|
||||||
|
changed = False
|
||||||
|
claimed = 0
|
||||||
|
for entry in entries:
|
||||||
|
if not entry.get("owner"):
|
||||||
|
entry["owner"] = owner
|
||||||
|
changed = True
|
||||||
|
claimed += 1
|
||||||
|
if changed:
|
||||||
|
self.save(entries)
|
||||||
|
logger.info("Claimed %d ownerless memories for %s", claimed, owner)
|
||||||
|
|
||||||
def _validate_entries(self, entries: List[Dict]) -> List[Dict]:
|
def _validate_entries(self, entries: List[Dict]) -> List[Dict]:
|
||||||
"""Ensure all entries have required fields."""
|
"""Ensure all entries have required fields."""
|
||||||
|
|||||||
56
tests/test_memory_imports.py
Normal file
56
tests/test_memory_imports.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Regression tests for memory import-path compatibility."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_services_memory_manager_is_canonical_src_class():
|
||||||
|
from services.memory import MemoryManager as package_manager
|
||||||
|
from services.memory.memory import MemoryManager as module_manager
|
||||||
|
from src.memory import MemoryManager as canonical_manager
|
||||||
|
|
||||||
|
assert module_manager is canonical_manager
|
||||||
|
assert package_manager is canonical_manager
|
||||||
|
assert hasattr(package_manager, "increment_uses")
|
||||||
|
assert hasattr(package_manager, "claim_ownerless")
|
||||||
|
|
||||||
|
|
||||||
|
def test_services_memory_vector_is_canonical_src_class():
|
||||||
|
from services.memory import MemoryVectorStore as package_vector_store
|
||||||
|
from services.memory.memory_vector import MemoryVectorStore as module_vector_store
|
||||||
|
from src.memory_vector import MemoryVectorStore as canonical_vector_store
|
||||||
|
|
||||||
|
assert module_vector_store is canonical_vector_store
|
||||||
|
assert package_vector_store is canonical_vector_store
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_service_uses_canonical_manager_api(tmp_path):
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from services.memory import MemoryService
|
||||||
|
|
||||||
|
service = MemoryService(str(tmp_path))
|
||||||
|
|
||||||
|
remembered = asyncio.run(service.remember("User prefers dark mode", session_id="sess-1"))
|
||||||
|
assert remembered.text == "User prefers dark mode"
|
||||||
|
assert remembered.session_id == "sess-1"
|
||||||
|
|
||||||
|
all_memories = service.get_all()
|
||||||
|
assert [m.id for m in all_memories] == [remembered.id]
|
||||||
|
|
||||||
|
recalled = asyncio.run(service.recall("dark mode", top_k=5))
|
||||||
|
assert [m.id for m in recalled.memories] == [remembered.id]
|
||||||
|
|
||||||
|
assert service.delete(remembered.id) is True
|
||||||
|
assert service.delete(remembered.id) is False
|
||||||
|
assert service.get_all() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_canonical_manager_keeps_ownerless_claim_helper(tmp_path):
|
||||||
|
from src.memory import MemoryManager
|
||||||
|
|
||||||
|
manager = MemoryManager(str(tmp_path))
|
||||||
|
entry = manager.add_entry("User likes compact code reviews")
|
||||||
|
manager.save([entry])
|
||||||
|
|
||||||
|
manager.claim_ownerless("alice")
|
||||||
|
|
||||||
|
memories = manager.load_all()
|
||||||
|
assert memories[0]["owner"] == "alice"
|
||||||
Reference in New Issue
Block a user