Scope document tools to caller owner
Co-authored-by: Lohinth <lohinth25@proton.me>
This commit is contained in:
@@ -651,15 +651,15 @@ async def execute_tool_block(
|
|||||||
elif tool == "create_document":
|
elif tool == "create_document":
|
||||||
title = content.split("\n")[0].strip()[:60]
|
title = content.split("\n")[0].strip()[:60]
|
||||||
desc = f"create_document: {title}"
|
desc = f"create_document: {title}"
|
||||||
result = await do_create_document(content, session_id=session_id)
|
result = await do_create_document(content, session_id=session_id, owner=owner)
|
||||||
elif tool == "update_document":
|
elif tool == "update_document":
|
||||||
desc = f"update_document: {content.split(chr(10))[0][:60]}"
|
desc = f"update_document: {content.split(chr(10))[0][:60]}"
|
||||||
result = await do_update_document(content)
|
result = await do_update_document(content, owner=owner)
|
||||||
elif tool == "edit_document":
|
elif tool == "edit_document":
|
||||||
result = await do_edit_document(content)
|
result = await do_edit_document(content, owner=owner)
|
||||||
desc = f"edit_document: {result.get('title', '')}"
|
desc = f"edit_document: {result.get('title', '')}"
|
||||||
elif tool == "suggest_document":
|
elif tool == "suggest_document":
|
||||||
result = await do_suggest_document(content)
|
result = await do_suggest_document(content, owner=owner)
|
||||||
desc = f"suggest_document: {result.get('count', 0)} suggestions"
|
desc = f"suggest_document: {result.get('count', 0)} suggestions"
|
||||||
elif tool == "search_chats":
|
elif tool == "search_chats":
|
||||||
query = content.split("\n")[0].strip()
|
query = content.split("\n")[0].strip()
|
||||||
|
|||||||
@@ -88,6 +88,28 @@ def get_active_document():
|
|||||||
return _active_document_id
|
return _active_document_id
|
||||||
|
|
||||||
|
|
||||||
|
def _owned_document_query(query, Document, owner: Optional[str]):
|
||||||
|
if owner is None:
|
||||||
|
return query.filter(False)
|
||||||
|
return query.filter(Document.owner == owner)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False):
|
||||||
|
q = db.query(Document).filter(Document.id == doc_id)
|
||||||
|
if active_only:
|
||||||
|
q = q.filter(Document.is_active == True)
|
||||||
|
q = _owned_document_query(q, Document, owner)
|
||||||
|
return q.first()
|
||||||
|
|
||||||
|
|
||||||
|
def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False):
|
||||||
|
q = db.query(Document)
|
||||||
|
if active_only:
|
||||||
|
q = q.filter(Document.is_active == True)
|
||||||
|
q = _owned_document_query(q, Document, owner)
|
||||||
|
return q.order_by(Document.updated_at.desc()).first()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Document tools — create/update/edit/suggest living documents
|
# Document tools — create/update/edit/suggest living documents
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -171,7 +193,7 @@ def _coerce_email_document_content(existing: str, incoming: str) -> str:
|
|||||||
return header.rstrip() + "\n---\n" + body
|
return header.rstrip() + "\n---\n" + body
|
||||||
|
|
||||||
|
|
||||||
async def do_create_document(content_block: str, session_id: Optional[str] = None) -> Dict:
|
async def do_create_document(content_block: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Create a new document. Supports two formats:
|
"""Create a new document. Supports two formats:
|
||||||
1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
|
1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
|
||||||
2) XML-like tags: <title>...</title><language>...</language><content>...</content>
|
2) XML-like tags: <title>...</title><language>...</language><content>...</content>
|
||||||
@@ -240,6 +262,8 @@ async def do_create_document(content_block: str, session_id: Optional[str] = Non
|
|||||||
# Inherit ownership from the chat session so the doc survives that
|
# Inherit ownership from the chat session so the doc survives that
|
||||||
# session later being deleted (session_id → NULL).
|
# session later being deleted (session_id → NULL).
|
||||||
_sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
_sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||||
|
if owner is not None and (not _sess or _sess.owner != owner):
|
||||||
|
return {"error": "Cannot create document in another user's session"}
|
||||||
_owner = _sess.owner if _sess else None
|
_owner = _sess.owner if _sess else None
|
||||||
|
|
||||||
doc = Document(
|
doc = Document(
|
||||||
@@ -286,7 +310,7 @@ async def do_create_document(content_block: str, session_id: Optional[str] = Non
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
async def do_update_document(content: str, doc_id: Optional[str] = None) -> Dict:
|
async def do_update_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Update an existing document. Content = full new document text."""
|
"""Update an existing document. Content = full new document text."""
|
||||||
import uuid
|
import uuid
|
||||||
from src.database import SessionLocal, Document, DocumentVersion
|
from src.database import SessionLocal, Document, DocumentVersion
|
||||||
@@ -297,9 +321,9 @@ async def do_update_document(content: str, doc_id: Optional[str] = None) -> Dict
|
|||||||
try:
|
try:
|
||||||
doc = None
|
doc = None
|
||||||
if target_id:
|
if target_id:
|
||||||
doc = db.query(Document).filter(Document.id == target_id).first()
|
doc = _get_owned_document(db, Document, target_id, owner)
|
||||||
if not doc:
|
if not doc:
|
||||||
doc = db.query(Document).order_by(Document.updated_at.desc()).first()
|
doc = _most_recent_owned_document(db, Document, owner)
|
||||||
if doc:
|
if doc:
|
||||||
target_id = doc.id
|
target_id = doc.id
|
||||||
set_active_document(target_id)
|
set_active_document(target_id)
|
||||||
@@ -350,7 +374,7 @@ def parse_edit_blocks(content: str) -> list:
|
|||||||
return edits
|
return edits
|
||||||
|
|
||||||
|
|
||||||
async def do_edit_document(content: str, doc_id: Optional[str] = None) -> Dict:
|
async def do_edit_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Apply targeted FIND/REPLACE edits to an existing document."""
|
"""Apply targeted FIND/REPLACE edits to an existing document."""
|
||||||
import uuid
|
import uuid
|
||||||
from src.database import SessionLocal, Document, DocumentVersion
|
from src.database import SessionLocal, Document, DocumentVersion
|
||||||
@@ -365,11 +389,11 @@ async def do_edit_document(content: str, doc_id: Optional[str] = None) -> Dict:
|
|||||||
try:
|
try:
|
||||||
doc = None
|
doc = None
|
||||||
if target_id:
|
if target_id:
|
||||||
doc = db.query(Document).filter(Document.id == target_id).first()
|
doc = _get_owned_document(db, Document, target_id, owner)
|
||||||
if not doc:
|
if not doc:
|
||||||
# Fallback: most recently updated document. Avoids "no active doc" errors
|
# Fallback: most recently updated document. Avoids "no active doc" errors
|
||||||
# after server restart or when the agent loses track of which doc to edit.
|
# after server restart or when the agent loses track of which doc to edit.
|
||||||
doc = db.query(Document).order_by(Document.updated_at.desc()).first()
|
doc = _most_recent_owned_document(db, Document, owner)
|
||||||
if doc:
|
if doc:
|
||||||
target_id = doc.id
|
target_id = doc.id
|
||||||
set_active_document(target_id)
|
set_active_document(target_id)
|
||||||
@@ -458,7 +482,7 @@ def parse_suggest_blocks(content: str) -> list:
|
|||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
|
|
||||||
async def do_suggest_document(content: str, doc_id: str = None) -> Dict:
|
async def do_suggest_document(content: str, doc_id: str = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Create inline suggestions for the active document WITHOUT modifying it."""
|
"""Create inline suggestions for the active document WITHOUT modifying it."""
|
||||||
from src.database import SessionLocal, Document
|
from src.database import SessionLocal, Document
|
||||||
|
|
||||||
@@ -472,7 +496,7 @@ async def do_suggest_document(content: str, doc_id: str = None) -> Dict:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
doc = db.query(Document).filter(Document.id == target_id).first()
|
doc = _get_owned_document(db, Document, target_id, owner)
|
||||||
if not doc:
|
if not doc:
|
||||||
return {"error": f"Document {target_id} not found"}
|
return {"error": f"Document {target_id} not found"}
|
||||||
|
|
||||||
@@ -1368,6 +1392,7 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict
|
|||||||
try:
|
try:
|
||||||
if action == "list":
|
if action == "list":
|
||||||
q = db.query(Document).filter(Document.is_active == True)
|
q = db.query(Document).filter(Document.is_active == True)
|
||||||
|
q = _owned_document_query(q, Document, owner)
|
||||||
if args.get("search"):
|
if args.get("search"):
|
||||||
q = q.filter(Document.title.ilike(f"%{args['search']}%"))
|
q = q.filter(Document.title.ilike(f"%{args['search']}%"))
|
||||||
if args.get("language"):
|
if args.get("language"):
|
||||||
@@ -1398,7 +1423,7 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict
|
|||||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid")
|
doc_id = args.get("document_id") or args.get("id") or args.get("uid")
|
||||||
if not doc_id:
|
if not doc_id:
|
||||||
return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
|
return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
|
||||||
doc = db.query(Document).filter(Document.id == doc_id, Document.is_active == True).first()
|
doc = _get_owned_document(db, Document, doc_id, owner, active_only=True)
|
||||||
if not doc:
|
if not doc:
|
||||||
return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
|
return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
|
||||||
body = doc.current_content or ""
|
body = doc.current_content or ""
|
||||||
@@ -1423,10 +1448,10 @@ async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict
|
|||||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
|
doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
|
||||||
doc = None
|
doc = None
|
||||||
if doc_id:
|
if doc_id:
|
||||||
doc = db.query(Document).filter(Document.id == doc_id).first()
|
doc = _get_owned_document(db, Document, doc_id, owner)
|
||||||
if not doc:
|
if not doc:
|
||||||
# Fallback: most recently updated doc (likely what the user means)
|
# Fallback: most recently updated doc (likely what the user means)
|
||||||
doc = db.query(Document).filter(Document.is_active == True).order_by(Document.updated_at.desc()).first()
|
doc = _most_recent_owned_document(db, Document, owner, active_only=True)
|
||||||
if not doc:
|
if not doc:
|
||||||
return {"error": "No document to delete", "exit_code": 1}
|
return {"error": "No document to delete", "exit_code": 1}
|
||||||
title = doc.title
|
title = doc.title
|
||||||
|
|||||||
150
tests/test_document_tool_owner_scope.py
Normal file
150
tests/test_document_tool_owner_scope.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
from src import tool_implementations as tools
|
||||||
|
|
||||||
|
|
||||||
|
class _Column:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __eq__(self, value):
|
||||||
|
return (self.name, "eq", value)
|
||||||
|
|
||||||
|
def desc(self):
|
||||||
|
return (self.name, "desc")
|
||||||
|
|
||||||
|
def ilike(self, value):
|
||||||
|
return (self.name, "ilike", value)
|
||||||
|
|
||||||
|
|
||||||
|
class _Document:
|
||||||
|
id = _Column("id")
|
||||||
|
owner = _Column("owner")
|
||||||
|
is_active = _Column("is_active")
|
||||||
|
title = _Column("title")
|
||||||
|
language = _Column("language")
|
||||||
|
updated_at = _Column("updated_at")
|
||||||
|
|
||||||
|
|
||||||
|
class _Query:
|
||||||
|
def __init__(self, docs=None, first_doc=None):
|
||||||
|
self.filters = []
|
||||||
|
self.docs = docs or []
|
||||||
|
self.first_doc = first_doc
|
||||||
|
|
||||||
|
def filter(self, *clauses):
|
||||||
|
self.filters.extend(clauses)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def order_by(self, *args):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, *args):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
return self.docs
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return self.first_doc
|
||||||
|
|
||||||
|
|
||||||
|
class _Db:
|
||||||
|
def __init__(self, query):
|
||||||
|
self.query_obj = query
|
||||||
|
|
||||||
|
def query(self, *args):
|
||||||
|
return self.query_obj
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _install_database_stub(monkeypatch, module_name, query):
|
||||||
|
db = _Db(query)
|
||||||
|
db_mod = types.ModuleType(module_name)
|
||||||
|
db_mod.SessionLocal = lambda: db
|
||||||
|
db_mod.Document = _Document
|
||||||
|
db_mod.DocumentVersion = object
|
||||||
|
db_mod.Session = object
|
||||||
|
monkeypatch.setitem(sys.modules, module_name, db_mod)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def test_owned_document_query_rejects_missing_owner():
|
||||||
|
query = _Query()
|
||||||
|
|
||||||
|
assert tools._owned_document_query(query, _Document, None) is query
|
||||||
|
assert False in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_owned_document_query_filters_to_owner():
|
||||||
|
query = _Query()
|
||||||
|
|
||||||
|
assert tools._owned_document_query(query, _Document, "alice") is query
|
||||||
|
assert ("owner", "eq", "alice") in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_manage_documents_list_filters_to_calling_owner(monkeypatch):
|
||||||
|
query = _Query()
|
||||||
|
_install_database_stub(monkeypatch, "core.database", query)
|
||||||
|
|
||||||
|
result = asyncio.run(tools.do_manage_documents('{"action":"list"}', owner="alice"))
|
||||||
|
|
||||||
|
assert result["documents"] == []
|
||||||
|
assert ("owner", "eq", "alice") in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_manage_documents_read_filters_to_calling_owner(monkeypatch):
|
||||||
|
query = _Query()
|
||||||
|
_install_database_stub(monkeypatch, "core.database", query)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tools.do_manage_documents('{"action":"read","document_id":"doc-bob"}', owner="alice")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["exit_code"] == 1
|
||||||
|
assert ("id", "eq", "doc-bob") in query.filters
|
||||||
|
assert ("owner", "eq", "alice") in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_document_active_id_filters_to_calling_owner(monkeypatch):
|
||||||
|
query = _Query()
|
||||||
|
_install_database_stub(monkeypatch, "src.database", query)
|
||||||
|
tools.set_active_document("doc-bob")
|
||||||
|
try:
|
||||||
|
result = asyncio.run(tools.do_update_document("new content", owner="alice"))
|
||||||
|
finally:
|
||||||
|
tools.set_active_document(None)
|
||||||
|
|
||||||
|
assert result["error"] == "No documents exist to update"
|
||||||
|
assert ("id", "eq", "doc-bob") in query.filters
|
||||||
|
assert ("owner", "eq", "alice") in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggest_document_active_id_filters_to_calling_owner(monkeypatch):
|
||||||
|
query = _Query()
|
||||||
|
_install_database_stub(monkeypatch, "src.database", query)
|
||||||
|
tools.set_active_document("doc-bob")
|
||||||
|
try:
|
||||||
|
result = asyncio.run(tools.do_suggest_document(
|
||||||
|
"<<<FIND>>>\nold\n<<<SUGGEST>>>\nnew\n<<<REASON>>>\nbetter\n<<<END>>>",
|
||||||
|
owner="alice",
|
||||||
|
))
|
||||||
|
finally:
|
||||||
|
tools.set_active_document(None)
|
||||||
|
|
||||||
|
assert result["error"] == "Document doc-bob not found"
|
||||||
|
assert ("id", "eq", "doc-bob") in query.filters
|
||||||
|
assert ("owner", "eq", "alice") in query.filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_tool_dispatch_forwards_owner():
|
||||||
|
source = open("src/tool_execution.py", encoding="utf-8").read()
|
||||||
|
|
||||||
|
assert "do_create_document(content, session_id=session_id, owner=owner)" in source
|
||||||
|
assert "do_update_document(content, owner=owner)" in source
|
||||||
|
assert "do_edit_document(content, owner=owner)" in source
|
||||||
|
assert "do_suggest_document(content, owner=owner)" in source
|
||||||
Reference in New Issue
Block a user