diff --git a/src/builtin_actions.py b/src/builtin_actions.py index 6fe8101..b8eed3f 100644 --- a/src/builtin_actions.py +++ b/src/builtin_actions.py @@ -78,32 +78,46 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]: manager = MemoryManager(DATA_DIR) all_memories = manager.load_all() - # Empty owner means "all owners" for built-in housekeeping, but never - # mix owners in the same AI prompt/apply step. A specific owner is - # scoped strictly to that owner; unowned rows are their own group. _owner_clean = (owner or "").strip() - if _owner_clean: - def _belongs_to_owner(mem: dict) -> bool: - mem_owner = (mem.get("owner") or "").strip() - return mem_owner == _owner_clean - else: - def _belongs_to_owner(mem: dict) -> bool: - return True + text_limit = 2000 - owner_memories = [m for m in all_memories if _belongs_to_owner(m)] - if not owner_memories: + def _memory_owner(mem: dict) -> str: + return (mem.get("owner") or "").strip() + + # Built-in housekeeping can run without an owner. In that case scan all + # memories, but keep every AI prompt/apply step owner-local. + if _owner_clean: + memory_groups = { + _owner_clean: [m for m in all_memories if _memory_owner(m) == _owner_clean] + } + else: + memory_groups = {} + for mem in all_memories: + memory_groups.setdefault(_memory_owner(mem), []).append(mem) + + memory_groups = {group_owner: group for group_owner, group in memory_groups.items() if group} + if not memory_groups: raise TaskNoop("no memories to consolidate") - memory_owners = {(m.get("owner") or "").strip() for m in owner_memories} - allow_ai_tidy = len(memory_owners) <= 1 + total_removed = 0 + total_cleaned = 0 + total_scanned = 0 + removed_examples = [] + ai_reasons = [] + ai_used = False - url, model, headers = resolve_endpoint("utility", owner=owner) - if not url or not model: - url, model, headers = resolve_endpoint("default", owner=owner) + async def _try_ai_tidy_group(group_owner: str, group_memories: list) -> bool: + nonlocal all_memories, total_removed, total_cleaned, total_scanned, ai_used + if len(group_memories) < 2: + return False + + url, model, headers = resolve_endpoint("utility", owner=group_owner or None) + if not url or not model: + url, model, headers = resolve_endpoint("default", owner=group_owner or None) + if not url or not model: + return False - if url and model and allow_ai_tidy and len(owner_memories) >= 2: try: - text_limit = 2000 items = [ { "id": m.get("id"), @@ -111,9 +125,11 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]: "text": (m.get("text") or "").strip()[:text_limit], "truncated": len((m.get("text") or "").strip()) > text_limit, } - for m in owner_memories + for m in group_memories if m.get("id") and (m.get("text") or "").strip() ] + if len(items) < 2: + return False truncated_ids = {item["id"] for item in items if item.get("truncated")} prompt = ( "You are tidying a user's saved personal memories. Return ONLY raw JSON, no markdown.\n" @@ -146,7 +162,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]: keep_items = decision.get("keep") if isinstance(decision, dict) else None drop_items = decision.get("drop") if isinstance(decision, dict) else None if isinstance(keep_items, list) and isinstance(drop_items, list): - by_id = {m.get("id"): m for m in owner_memories} + by_id = {m.get("id"): m for m in group_memories if m.get("id")} keep_ids = set() cleaned_by_id = {} for item in keep_items: @@ -159,19 +175,24 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]: if not text: continue keep_ids.add(mid) - cleaned_by_id[mid] = { - "text": text, + cleaned = { "category": (item.get("category") or by_id[mid].get("category") or "fact").strip(), } + original_text = (by_id[mid].get("text") or "").strip() + if len(original_text) <= text_limit: + cleaned["text"] = text + cleaned_by_id[mid] = cleaned + # If the model only saw a truncated memory, do not let # that partial view delete or rewrite the full memory. keep_ids.update(mid for mid in truncated_ids if mid in by_id) if keep_ids: changed_text = 0 + group_ref_ids = {id(m) for m in group_memories} kept_all = [] for mem in all_memories: - if not _belongs_to_owner(mem): + if id(mem) not in group_ref_ids: kept_all.append(mem) continue mid = mem.get("id") @@ -185,65 +206,72 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]: changed_text += 1 if cleaned.get("category"): mem["category"] = cleaned["category"] - if owner and not mem.get("owner"): - mem["owner"] = owner kept_all.append(mem) - removed = len(owner_memories) - len(keep_ids) + removed = len(group_memories) - len(keep_ids) + total_scanned += len(group_memories) if removed or changed_text: - manager.save(kept_all) - reasons = [ + all_memories = kept_all + total_removed += removed + total_cleaned += changed_text + ai_used = True + ai_reasons.extend([ (d.get("reason") or "").strip() for d in drop_items if isinstance(d, dict) and (d.get("reason") or "").strip() - ][:3] - reason_text = f": {'; '.join(reasons)}" if reasons else "" - return ( - f"AI tidied {len(owner_memories)} memories: " - f"removed {removed}, cleaned {changed_text}{reason_text}", - True, - ) - - raise TaskNoop(f"AI scanned {len(owner_memories)} memories, no changes") - except TaskNoop: - raise + ]) + return True except Exception as ai_err: logger.warning("AI memory tidy failed; falling back to duplicate cleanup: %s", ai_err) + return False - seen = {} - keep_ids = set() - removed_examples = [] - for mem in owner_memories: - text = (mem.get("text") or "").strip() - normalized = " ".join(text.lower().split()) - if not normalized: - removed_examples.append("(empty)") + for group_owner, group_memories in memory_groups.items(): + if await _try_ai_tidy_group(group_owner, group_memories): continue - mem_owner = (mem.get("owner") or "").strip() - key = (mem_owner, normalized) - if key in seen: - if len(removed_examples) < 3: - removed_examples.append(text[:60] + ("..." if len(text) > 60 else "")) + + seen = {} + keep_refs = set() + total_scanned += len(group_memories) + for mem in group_memories: + text = (mem.get("text") or "").strip() + key = " ".join(text.lower().split()) + if not key: + if len(removed_examples) < 3: + removed_examples.append("(empty)") + continue + if key in seen: + if len(removed_examples) < 3: + removed_examples.append(text[:60] + ("..." if len(text) > 60 else "")) + continue + seen[key] = mem + keep_refs.add(id(mem)) + + group_removed = len(group_memories) - len(keep_refs) + if group_removed == 0: continue - seen[key] = mem - keep_ids.add(mem.get("id")) - removed = len(owner_memories) - len(keep_ids) - if removed == 0: - raise TaskNoop(f"scanned {len(owner_memories)} memories, no duplicates") + group_ref_ids = {id(m) for m in group_memories} + all_memories = [ + m for m in all_memories + if id(m) not in group_ref_ids or id(m) in keep_refs + ] + total_removed += group_removed - kept_all = [ - m for m in all_memories - if not _belongs_to_owner(m) or m.get("id") in keep_ids - ] - if owner: - for mem in kept_all: - if mem.get("id") in keep_ids and not mem.get("owner"): - mem["owner"] = owner - manager.save(kept_all) - preview = "; ".join(removed_examples) - extra = f" (+{removed - len(removed_examples)} more)" if removed > len(removed_examples) else "" - return f"Removed {removed} duplicate(s) of {len(owner_memories)}: {preview}{extra}", True + if total_removed or total_cleaned: + manager.save(all_memories) + if ai_used: + reasons = ai_reasons[:3] + reason_text = f": {'; '.join(reasons)}" if reasons else "" + return ( + f"AI tidied {total_scanned} memories: " + f"removed {total_removed}, cleaned {total_cleaned}{reason_text}", + True, + ) + preview = "; ".join(removed_examples) + extra = f" (+{total_removed - len(removed_examples)} more)" if total_removed > len(removed_examples) else "" + return f"Removed {total_removed} duplicate(s) of {total_scanned}: {preview}{extra}", True + + raise TaskNoop(f"scanned {total_scanned} memories, no duplicates") except Exception as e: logger.error(f"consolidate_memory action failed: {e}") return str(e), False diff --git a/tests/test_builtin_memory_consolidation.py b/tests/test_builtin_memory_consolidation.py new file mode 100644 index 0000000..bebd435 --- /dev/null +++ b/tests/test_builtin_memory_consolidation.py @@ -0,0 +1,112 @@ +import json +import sys + +import pytest + + +def _import_consolidate_action(): + mod = sys.modules.get("src.builtin_actions") + if mod is not None and not hasattr(mod, "action_consolidate_memory"): + sys.modules.pop("src.builtin_actions", None) + if "src" in sys.modules and hasattr(sys.modules["src"], "builtin_actions"): + delattr(sys.modules["src"], "builtin_actions") + from src.builtin_actions import action_consolidate_memory + + return action_consolidate_memory + + +def _write_memories(tmp_path, memories): + data_dir = tmp_path / "data" + data_dir.mkdir() + (data_dir / "memory.json").write_text(json.dumps(memories), encoding="utf-8") + return data_dir + + +def _read_memories(data_dir): + return json.loads((data_dir / "memory.json").read_text(encoding="utf-8")) + + +@pytest.mark.asyncio +async def test_consolidate_memory_empty_owner_treats_each_owner_separately(monkeypatch, tmp_path): + from src import constants + from src import endpoint_resolver + from src import llm_core + action_consolidate_memory = _import_consolidate_action() + + long_alice_text = "Alice private project context. " + ("A" * 2200) + data_dir = _write_memories( + tmp_path, + [ + {"id": "alice-long", "owner": "alice", "text": long_alice_text, "category": "project"}, + {"id": "alice-short", "owner": "alice", "text": "Alice likes quiet summaries.", "category": "preference"}, + {"id": "bob-keep", "owner": "bob", "text": "Bob secret deployment note.", "category": "project"}, + {"id": "bob-drop", "owner": "bob", "text": "Bob secret deployment note duplicate.", "category": "project"}, + ], + ) + monkeypatch.setattr(constants, "DATA_DIR", str(data_dir)) + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda *args, **kwargs: ("http://llm", "model", {})) + + prompts = [] + + async def fake_llm_call_async(**kwargs): + prompt = kwargs["messages"][0]["content"] + prompts.append(prompt) + if "alice-long" in prompt: + assert "bob-keep" not in prompt + return json.dumps( + { + "keep": [ + {"id": "alice-long", "text": "TRUNCATED REWRITE", "category": "project"}, + {"id": "alice-short", "text": "Alice likes concise summaries.", "category": "preference"}, + ], + "drop": [], + } + ) + assert "bob-keep" in prompt + assert "alice-long" not in prompt + return json.dumps( + { + "keep": [{"id": "bob-keep", "text": "Bob secret deployment note.", "category": "project"}], + "drop": [{"id": "bob-drop", "reason": "duplicate"}], + } + ) + + monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async) + + message, ok = await action_consolidate_memory("") + + assert ok is True + assert "removed 1" in message + assert len(prompts) == 2 + saved = {m["id"]: m for m in _read_memories(data_dir)} + assert set(saved) == {"alice-long", "alice-short", "bob-keep"} + assert saved["alice-long"]["text"] == long_alice_text + assert saved["alice-short"]["text"] == "Alice likes concise summaries." + + +@pytest.mark.asyncio +async def test_consolidate_memory_specific_owner_does_not_absorb_ownerless_rows(monkeypatch, tmp_path): + from src import constants + from src import endpoint_resolver + action_consolidate_memory = _import_consolidate_action() + + data_dir = _write_memories( + tmp_path, + [ + {"id": "alice-1", "owner": "alice", "text": "Alice likes local models.", "category": "preference"}, + {"id": "alice-2", "owner": "alice", "text": "Alice likes local models.", "category": "preference"}, + {"id": "legacy", "text": "Alice likes local models.", "category": "preference"}, + {"id": "bob-1", "owner": "bob", "text": "Bob likes hosted models.", "category": "preference"}, + ], + ) + monkeypatch.setattr(constants, "DATA_DIR", str(data_dir)) + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda *args, **kwargs: ("", "", {})) + + message, ok = await action_consolidate_memory("alice") + + assert ok is True + assert "Removed 1 duplicate" in message + saved = {m["id"]: m for m in _read_memories(data_dir)} + assert set(saved) == {"alice-1", "legacy", "bob-1"} + assert "owner" not in saved["legacy"] + assert saved["bob-1"]["owner"] == "bob"