diff --git a/src/tool_execution.py b/src/tool_execution.py index e0a04d2..c4294a6 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -651,15 +651,15 @@ async def execute_tool_block( elif tool == "create_document": title = content.split("\n")[0].strip()[:60] desc = f"create_document: {title}" - result = await do_create_document(content, session_id=session_id) + result = await do_create_document(content, session_id=session_id, owner=owner) elif tool == "update_document": desc = f"update_document: {content.split(chr(10))[0][:60]}" - result = await do_update_document(content) + result = await do_update_document(content, owner=owner) elif tool == "edit_document": - result = await do_edit_document(content) + result = await do_edit_document(content, owner=owner) desc = f"edit_document: {result.get('title', '')}" elif tool == "suggest_document": - result = await do_suggest_document(content) + result = await do_suggest_document(content, owner=owner) desc = f"suggest_document: {result.get('count', 0)} suggestions" elif tool == "search_chats": query = content.split("\n")[0].strip() diff --git a/src/tool_implementations.py b/src/tool_implementations.py index 1e9032f..40d17be 100644 --- a/src/tool_implementations.py +++ b/src/tool_implementations.py @@ -88,6 +88,28 @@ def get_active_document(): return _active_document_id +def _owned_document_query(query, Document, owner: Optional[str]): + if owner is None: + return query.filter(False) + return query.filter(Document.owner == owner) + + +def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False): + q = db.query(Document).filter(Document.id == doc_id) + if active_only: + q = q.filter(Document.is_active == True) + q = _owned_document_query(q, Document, owner) + return q.first() + + +def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False): + q = db.query(Document) + if active_only: + q = q.filter(Document.is_active == True) + q = _owned_document_query(q, Document, owner) + return q.order_by(Document.updated_at.desc()).first() + + # --------------------------------------------------------------------------- # Document tools — create/update/edit/suggest living documents # --------------------------------------------------------------------------- @@ -171,7 +193,7 @@ def _coerce_email_document_content(existing: str, incoming: str) -> str: return header.rstrip() + "\n---\n" + body -async def do_create_document(content_block: str, session_id: Optional[str] = None) -> Dict: +async def do_create_document(content_block: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Create a new document. Supports two formats: 1) Line-based: line 1 = title, line 2 (optional) = language, rest = content 2) XML-like tags: ......... @@ -240,6 +262,8 @@ async def do_create_document(content_block: str, session_id: Optional[str] = Non # Inherit ownership from the chat session so the doc survives that # session later being deleted (session_id → NULL). _sess = db.query(DbSession).filter(DbSession.id == session_id).first() + if owner is not None and (not _sess or _sess.owner != owner): + return {"error": "Cannot create document in another user's session"} _owner = _sess.owner if _sess else None doc = Document( @@ -286,7 +310,7 @@ async def do_create_document(content_block: str, session_id: Optional[str] = Non db.close() -async def do_update_document(content: str, doc_id: Optional[str] = None) -> Dict: +async def do_update_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Update an existing document. Content = full new document text.""" import uuid from src.database import SessionLocal, Document, DocumentVersion @@ -297,9 +321,9 @@ async def do_update_document(content: str, doc_id: Optional[str] = None) -> Dict try: doc = None if target_id: - doc = db.query(Document).filter(Document.id == target_id).first() + doc = _get_owned_document(db, Document, target_id, owner) if not doc: - doc = db.query(Document).order_by(Document.updated_at.desc()).first() + doc = _most_recent_owned_document(db, Document, owner) if doc: target_id = doc.id set_active_document(target_id) @@ -350,7 +374,7 @@ def parse_edit_blocks(content: str) -> list: return edits -async def do_edit_document(content: str, doc_id: Optional[str] = None) -> Dict: +async def do_edit_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Apply targeted FIND/REPLACE edits to an existing document.""" import uuid from src.database import SessionLocal, Document, DocumentVersion @@ -365,11 +389,11 @@ async def do_edit_document(content: str, doc_id: Optional[str] = None) -> Dict: try: doc = None if target_id: - doc = db.query(Document).filter(Document.id == target_id).first() + doc = _get_owned_document(db, Document, target_id, owner) if not doc: # Fallback: most recently updated document. Avoids "no active doc" errors # after server restart or when the agent loses track of which doc to edit. - doc = db.query(Document).order_by(Document.updated_at.desc()).first() + doc = _most_recent_owned_document(db, Document, owner) if doc: target_id = doc.id set_active_document(target_id) @@ -458,7 +482,7 @@ def parse_suggest_blocks(content: str) -> list: return suggestions -async def do_suggest_document(content: str, doc_id: str = None) -> Dict: +async def do_suggest_document(content: str, doc_id: str = None, owner: Optional[str] = None) -> Dict: """Create inline suggestions for the active document WITHOUT modifying it.""" from src.database import SessionLocal, Document @@ -472,7 +496,7 @@ async def do_suggest_document(content: str, doc_id: str = None) -> Dict: db = SessionLocal() try: - doc = db.query(Document).filter(Document.id == target_id).first() + doc = _get_owned_document(db, Document, target_id, owner) if not doc: return {"error": f"Document {target_id} not found"} @@ -1368,6 +1392,7 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict try: if action == "list": q = db.query(Document).filter(Document.is_active == True) + q = _owned_document_query(q, Document, owner) if args.get("search"): q = q.filter(Document.title.ilike(f"%{args['search']}%")) if args.get("language"): @@ -1398,7 +1423,7 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict doc_id = args.get("document_id") or args.get("id") or args.get("uid") if not doc_id: return {"error": "Need document_id (use action=list to find one)", "exit_code": 1} - doc = db.query(Document).filter(Document.id == doc_id, Document.is_active == True).first() + doc = _get_owned_document(db, Document, doc_id, owner, active_only=True) if not doc: return {"error": f"Document '{doc_id}' not found", "exit_code": 1} body = doc.current_content or "" @@ -1423,10 +1448,10 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id doc = None if doc_id: - doc = db.query(Document).filter(Document.id == doc_id).first() + doc = _get_owned_document(db, Document, doc_id, owner) if not doc: # Fallback: most recently updated doc (likely what the user means) - doc = db.query(Document).filter(Document.is_active == True).order_by(Document.updated_at.desc()).first() + doc = _most_recent_owned_document(db, Document, owner, active_only=True) if not doc: return {"error": "No document to delete", "exit_code": 1} title = doc.title diff --git a/tests/test_document_tool_owner_scope.py b/tests/test_document_tool_owner_scope.py new file mode 100644 index 0000000..be5f3f0 --- /dev/null +++ b/tests/test_document_tool_owner_scope.py @@ -0,0 +1,150 @@ +import asyncio +import sys +import types + +from src import tool_implementations as tools + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, value): + return (self.name, "eq", value) + + def desc(self): + return (self.name, "desc") + + def ilike(self, value): + return (self.name, "ilike", value) + + +class _Document: + id = _Column("id") + owner = _Column("owner") + is_active = _Column("is_active") + title = _Column("title") + language = _Column("language") + updated_at = _Column("updated_at") + + +class _Query: + def __init__(self, docs=None, first_doc=None): + self.filters = [] + self.docs = docs or [] + self.first_doc = first_doc + + def filter(self, *clauses): + self.filters.extend(clauses) + return self + + def order_by(self, *args): + return self + + def limit(self, *args): + return self + + def all(self): + return self.docs + + def first(self): + return self.first_doc + + +class _Db: + def __init__(self, query): + self.query_obj = query + + def query(self, *args): + return self.query_obj + + def close(self): + pass + + +def _install_database_stub(monkeypatch, module_name, query): + db = _Db(query) + db_mod = types.ModuleType(module_name) + db_mod.SessionLocal = lambda: db + db_mod.Document = _Document + db_mod.DocumentVersion = object + db_mod.Session = object + monkeypatch.setitem(sys.modules, module_name, db_mod) + return db + + +def test_owned_document_query_rejects_missing_owner(): + query = _Query() + + assert tools._owned_document_query(query, _Document, None) is query + assert False in query.filters + + +def test_owned_document_query_filters_to_owner(): + query = _Query() + + assert tools._owned_document_query(query, _Document, "alice") is query + assert ("owner", "eq", "alice") in query.filters + + +def test_manage_documents_list_filters_to_calling_owner(monkeypatch): + query = _Query() + _install_database_stub(monkeypatch, "core.database", query) + + result = asyncio.run(tools.do_manage_documents('{"action":"list"}', owner="alice")) + + assert result["documents"] == [] + assert ("owner", "eq", "alice") in query.filters + + +def test_manage_documents_read_filters_to_calling_owner(monkeypatch): + query = _Query() + _install_database_stub(monkeypatch, "core.database", query) + + result = asyncio.run( + tools.do_manage_documents('{"action":"read","document_id":"doc-bob"}', owner="alice") + ) + + assert result["exit_code"] == 1 + assert ("id", "eq", "doc-bob") in query.filters + assert ("owner", "eq", "alice") in query.filters + + +def test_update_document_active_id_filters_to_calling_owner(monkeypatch): + query = _Query() + _install_database_stub(monkeypatch, "src.database", query) + tools.set_active_document("doc-bob") + try: + result = asyncio.run(tools.do_update_document("new content", owner="alice")) + finally: + tools.set_active_document(None) + + assert result["error"] == "No documents exist to update" + assert ("id", "eq", "doc-bob") in query.filters + assert ("owner", "eq", "alice") in query.filters + + +def test_suggest_document_active_id_filters_to_calling_owner(monkeypatch): + query = _Query() + _install_database_stub(monkeypatch, "src.database", query) + tools.set_active_document("doc-bob") + try: + result = asyncio.run(tools.do_suggest_document( + "<<>>\nold\n<<>>\nnew\n<<>>\nbetter\n<<>>", + owner="alice", + )) + finally: + tools.set_active_document(None) + + assert result["error"] == "Document doc-bob not found" + assert ("id", "eq", "doc-bob") in query.filters + assert ("owner", "eq", "alice") in query.filters + + +def test_document_tool_dispatch_forwards_owner(): + source = open("src/tool_execution.py", encoding="utf-8").read() + + assert "do_create_document(content, session_id=session_id, owner=owner)" in source + assert "do_update_document(content, owner=owner)" in source + assert "do_edit_document(content, owner=owner)" in source + assert "do_suggest_document(content, owner=owner)" in source