From 5f58f9a45fd6d6f88f0f57dcdf57fa5d7b2806f0 Mon Sep 17 00:00:00 2001 From: Vykos Date: Thu, 4 Jun 2026 01:37:28 +0200 Subject: [PATCH] fix(ai): scope tool model resolution by owner * Stabilize full test collection * Scope AI tool model resolution by owner --- src/ai_interaction.py | 57 ++++++++++-------- tests/test_ai_interaction_owner_scope.py | 75 ++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 23 deletions(-) create mode 100644 tests/test_ai_interaction_owner_scope.py diff --git a/src/ai_interaction.py b/src/ai_interaction.py index 5d36507..383560e 100644 --- a/src/ai_interaction.py +++ b/src/ai_interaction.py @@ -58,7 +58,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None): from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url -def _resolve_model(spec: str) -> Tuple[str, str, Dict]: +def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]: """Resolve a model specifier to (endpoint_url, model_id, headers). Accepts: @@ -70,6 +70,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]: import httpx from src.database import SessionLocal, ModelEndpoint from src.llm_core import _detect_provider, ANTHROPIC_MODELS + from src.auth_helpers import owner_filter spec = spec.strip() target_endpoint_name = None @@ -86,6 +87,8 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]: query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) if target_endpoint_name: query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%")) + if owner: + query = owner_filter(query, ModelEndpoint, owner) endpoints = query.all() if not endpoints: @@ -141,7 +144,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]: # Tool implementations # --------------------------------------------------------------------------- -async def do_chat_with_model(content: str, session_id: Optional[str] = None) -> Dict: +async def do_chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Send a message to a specific model and return its response. Content format: @@ -160,7 +163,7 @@ async def do_chat_with_model(content: str, session_id: Optional[str] = None) -> return {"error": "No message provided (line 2+ is the message)"} try: - url, model, headers = _resolve_model(model_spec) + url, model, headers = _resolve_model(model_spec, owner=owner) except ValueError as e: return {"error": str(e)} @@ -190,7 +193,7 @@ _TEACHER_SYSTEM_PROMPT = ( ) -async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict: +async def do_ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Ask a more capable model for help. Content format: @@ -213,7 +216,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."} try: - url, model, headers = _resolve_model(model_spec) + url, model, headers = _resolve_model(model_spec, owner=owner) except ValueError as e: return {"error": str(e)} @@ -235,7 +238,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict return {"error": f"Teacher call failed ({model_spec}): {e}"} -async def do_second_opinion(content: str, session_id: Optional[str] = None) -> Dict: +async def do_second_opinion(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Get a second opinion from another model, then have the original model evaluate the feedback and produce a unified version. @@ -259,7 +262,7 @@ async def do_second_opinion(content: str, session_id: Optional[str] = None) -> D focus = lines[1].strip() if len(lines) > 1 else "" try: - reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec) + reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec, owner=owner) except ValueError as e: return {"error": str(e)} @@ -400,7 +403,7 @@ async def do_create_session(content: str, session_id: Optional[str] = None, owne return {"error": "Session name cannot be empty"} try: - url, model, headers = _resolve_model(model_spec) + url, model, headers = _resolve_model(model_spec, owner=owner) except ValueError as e: return {"error": str(e)} @@ -584,7 +587,7 @@ async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = No yield {"_final": True, "desc": desc, "result": result} -async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict: +async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Execute a multi-step pipeline where each model's output feeds the next. Content format (JSON): @@ -638,7 +641,7 @@ async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict: if not model_spec or not instruction: return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"} try: - url, model, headers = _resolve_model(model_spec) + url, model, headers = _resolve_model(model_spec, owner=owner) resolved.append((url, model, headers, instruction)) except ValueError as e: return {"error": f"Step {i + 1}: {e}"} @@ -1091,7 +1094,7 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner # List models tool # --------------------------------------------------------------------------- -async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict: +async def do_list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """List all available models across configured endpoints. Content = optional filter keyword. @@ -1099,12 +1102,16 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict import httpx from src.database import SessionLocal, ModelEndpoint from src.llm_core import _detect_provider, ANTHROPIC_MODELS + from src.auth_helpers import owner_filter keyword = content.strip().lower() if content.strip() else None db = SessionLocal() try: - endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + if owner: + query = owner_filter(query, ModelEndpoint, owner) + endpoints = query.all() if not endpoints: return {"results": "No enabled model endpoints configured."} @@ -1250,7 +1257,7 @@ async def do_manage_rag(content: str, session_id: Optional[str] = None) -> Dict: # UI control tool (returns events for frontend to apply) # --------------------------------------------------------------------------- -async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict: +async def do_ui_control(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: """Control frontend UI: toggle settings, switch model, change theme. Content format: @@ -1325,7 +1332,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict: # Resolve the model to validate it exists try: - url, model_id, headers = _resolve_model(model_spec) + url, model_id, headers = _resolve_model(model_spec, owner=owner) except ValueError as e: return {"error": str(e)} @@ -1580,7 +1587,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne if not model_spec: for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"): try: - _resolve_model(candidate) + _resolve_model(candidate, owner=owner) model_spec = candidate break except ValueError: @@ -1589,13 +1596,17 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne if not model_spec: try: from src.database import SessionLocal, ModelEndpoint + from src.auth_helpers import owner_filter import httpx as _req _idb = SessionLocal() try: - _img_eps = _idb.query(ModelEndpoint).filter( + _img_q = _idb.query(ModelEndpoint).filter( ModelEndpoint.is_enabled == True, ModelEndpoint.model_type == "image", - ).all() + ) + if owner: + _img_q = owner_filter(_img_q, ModelEndpoint, owner) + _img_eps = _img_q.all() for _iep in _img_eps: _ibase = _iep.base_url.rstrip("/") if not _ibase.endswith("/v1"): @@ -1618,7 +1629,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne # Resolve the model to find the right endpoint try: - url, model_id, headers = _resolve_model(model_spec) + url, model_id, headers = _resolve_model(model_spec, owner=owner) except ValueError: return {"error": f"No endpoint found with image model '{model_spec}'. " "Configure an OpenAI-compatible endpoint with image generation support."} @@ -1760,7 +1771,7 @@ async def dispatch_ai_tool( if tool == "chat_with_model": model_spec = content.split("\n")[0].strip()[:60] desc = f"chat_with_model: {model_spec}" - result = await do_chat_with_model(content, session_id) + result = await do_chat_with_model(content, session_id, owner=owner) elif tool == "create_session": name = content.split("\n")[0].strip()[:60] @@ -1779,7 +1790,7 @@ async def dispatch_ai_tool( elif tool == "pipeline": desc = "pipeline: running steps" - result = await do_pipeline(content, session_id) + result = await do_pipeline(content, session_id, owner=owner) elif tool == "manage_session": action = content.split("\n")[0].strip()[:40] @@ -1794,17 +1805,17 @@ async def dispatch_ai_tool( elif tool == "list_models": keyword = content.strip()[:40] desc = f"list_models{': ' + keyword if keyword else ''}" - result = await do_list_models(content, session_id) + result = await do_list_models(content, session_id, owner=owner) elif tool == "ui_control": action = content.split("\n")[0].strip()[:60] desc = f"ui_control: {action}" - result = await do_ui_control(content, session_id) + result = await do_ui_control(content, session_id, owner=owner) elif tool == "ask_teacher": problem = content.split("\n", 1)[-1].strip()[:60] desc = f"ask_teacher: {problem}" - result = await do_ask_teacher(content, session_id) + result = await do_ask_teacher(content, session_id, owner=owner) else: desc = f"unknown ai tool: {tool}" diff --git a/tests/test_ai_interaction_owner_scope.py b/tests/test_ai_interaction_owner_scope.py new file mode 100644 index 0000000..7b2ac63 --- /dev/null +++ b/tests/test_ai_interaction_owner_scope.py @@ -0,0 +1,75 @@ +import inspect + +import pytest + +from src import ai_interaction + + +def _source(fn) -> str: + return inspect.getsource(fn) + + +def test_model_resolver_applies_owner_filter(): + body = _source(ai_interaction._resolve_model) + + assert "owner: Optional[str] = None" in body + assert "from src.auth_helpers import owner_filter" in body + assert "owner_filter(query, ModelEndpoint, owner)" in body + + +def test_model_listing_and_image_fallback_are_owner_scoped(): + list_body = _source(ai_interaction.do_list_models) + image_body = _source(ai_interaction.do_generate_image) + + assert "owner: Optional[str] = None" in list_body + assert "owner_filter(query, ModelEndpoint, owner)" in list_body + assert "_resolve_model(candidate, owner=owner)" in image_body + assert "owner_filter(_img_q, ModelEndpoint, owner)" in image_body + assert "_resolve_model(model_spec, owner=owner)" in image_body + + +@pytest.mark.parametrize("tool,content", [ + ("chat_with_model", "gpt-test\nhello"), + ("pipeline", "gpt-test | summarize this"), + ("list_models", ""), + ("ui_control", "switch_model gpt-test"), + ("ask_teacher", "gpt-test\nhelp me"), +]) +async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content): + seen = {} + + async def capture(name, content, session_id=None, owner=None): + seen[name] = {"content": content, "session_id": session_id, "owner": owner} + return {"ok": True} + + monkeypatch.setattr( + ai_interaction, + "do_chat_with_model", + lambda content, session_id=None, owner=None: capture("chat_with_model", content, session_id, owner), + ) + monkeypatch.setattr( + ai_interaction, + "do_pipeline", + lambda content, session_id=None, owner=None: capture("pipeline", content, session_id, owner), + ) + monkeypatch.setattr( + ai_interaction, + "do_list_models", + lambda content, session_id=None, owner=None: capture("list_models", content, session_id, owner), + ) + monkeypatch.setattr( + ai_interaction, + "do_ui_control", + lambda content, session_id=None, owner=None: capture("ui_control", content, session_id, owner), + ) + monkeypatch.setattr( + ai_interaction, + "do_ask_teacher", + lambda content, session_id=None, owner=None: capture("ask_teacher", content, session_id, owner), + ) + + _desc, result = await ai_interaction.dispatch_ai_tool(tool, content, session_id="sid1", owner="alice") + + assert result == {"ok": True} + assert seen[tool]["owner"] == "alice" + assert seen[tool]["session_id"] == "sid1"