fix(ai): scope tool model resolution by owner

* Stabilize full test collection

* Scope AI tool model resolution by owner
This commit is contained in:
Vykos
2026-06-04 01:37:28 +02:00
committed by GitHub
parent aaef6b1c49
commit 5f58f9a45f
2 changed files with 109 additions and 23 deletions

View File

@@ -58,7 +58,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
def _resolve_model(spec: str) -> Tuple[str, str, Dict]: def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
"""Resolve a model specifier to (endpoint_url, model_id, headers). """Resolve a model specifier to (endpoint_url, model_id, headers).
Accepts: Accepts:
@@ -70,6 +70,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
import httpx import httpx
from src.database import SessionLocal, ModelEndpoint from src.database import SessionLocal, ModelEndpoint
from src.llm_core import _detect_provider, ANTHROPIC_MODELS from src.llm_core import _detect_provider, ANTHROPIC_MODELS
from src.auth_helpers import owner_filter
spec = spec.strip() spec = spec.strip()
target_endpoint_name = None target_endpoint_name = None
@@ -86,6 +87,8 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if target_endpoint_name: if target_endpoint_name:
query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%")) query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%"))
if owner:
query = owner_filter(query, ModelEndpoint, owner)
endpoints = query.all() endpoints = query.all()
if not endpoints: if not endpoints:
@@ -141,7 +144,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
# Tool implementations # Tool implementations
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def do_chat_with_model(content: str, session_id: Optional[str] = None) -> Dict: async def do_chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""Send a message to a specific model and return its response. """Send a message to a specific model and return its response.
Content format: Content format:
@@ -160,7 +163,7 @@ async def do_chat_with_model(content: str, session_id: Optional[str] = None) ->
return {"error": "No message provided (line 2+ is the message)"} return {"error": "No message provided (line 2+ is the message)"}
try: try:
url, model, headers = _resolve_model(model_spec) url, model, headers = _resolve_model(model_spec, owner=owner)
except ValueError as e: except ValueError as e:
return {"error": str(e)} return {"error": str(e)}
@@ -190,7 +193,7 @@ _TEACHER_SYSTEM_PROMPT = (
) )
async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict: async def do_ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""Ask a more capable model for help. """Ask a more capable model for help.
Content format: Content format:
@@ -213,7 +216,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."} return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."}
try: try:
url, model, headers = _resolve_model(model_spec) url, model, headers = _resolve_model(model_spec, owner=owner)
except ValueError as e: except ValueError as e:
return {"error": str(e)} return {"error": str(e)}
@@ -235,7 +238,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
return {"error": f"Teacher call failed ({model_spec}): {e}"} return {"error": f"Teacher call failed ({model_spec}): {e}"}
async def do_second_opinion(content: str, session_id: Optional[str] = None) -> Dict: async def do_second_opinion(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""Get a second opinion from another model, then have the original model """Get a second opinion from another model, then have the original model
evaluate the feedback and produce a unified version. evaluate the feedback and produce a unified version.
@@ -259,7 +262,7 @@ async def do_second_opinion(content: str, session_id: Optional[str] = None) -> D
focus = lines[1].strip() if len(lines) > 1 else "" focus = lines[1].strip() if len(lines) > 1 else ""
try: try:
reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec) reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec, owner=owner)
except ValueError as e: except ValueError as e:
return {"error": str(e)} return {"error": str(e)}
@@ -400,7 +403,7 @@ async def do_create_session(content: str, session_id: Optional[str] = None, owne
return {"error": "Session name cannot be empty"} return {"error": "Session name cannot be empty"}
try: try:
url, model, headers = _resolve_model(model_spec) url, model, headers = _resolve_model(model_spec, owner=owner)
except ValueError as e: except ValueError as e:
return {"error": str(e)} return {"error": str(e)}
@@ -584,7 +587,7 @@ async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = No
yield {"_final": True, "desc": desc, "result": result} yield {"_final": True, "desc": desc, "result": result}
async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict: async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""Execute a multi-step pipeline where each model's output feeds the next. """Execute a multi-step pipeline where each model's output feeds the next.
Content format (JSON): Content format (JSON):
@@ -638,7 +641,7 @@ async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict:
if not model_spec or not instruction: if not model_spec or not instruction:
return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"} return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"}
try: try:
url, model, headers = _resolve_model(model_spec) url, model, headers = _resolve_model(model_spec, owner=owner)
resolved.append((url, model, headers, instruction)) resolved.append((url, model, headers, instruction))
except ValueError as e: except ValueError as e:
return {"error": f"Step {i + 1}: {e}"} return {"error": f"Step {i + 1}: {e}"}
@@ -1091,7 +1094,7 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner
# List models tool # List models tool
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict: async def do_list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""List all available models across configured endpoints. """List all available models across configured endpoints.
Content = optional filter keyword. Content = optional filter keyword.
@@ -1099,12 +1102,16 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict
import httpx import httpx
from src.database import SessionLocal, ModelEndpoint from src.database import SessionLocal, ModelEndpoint
from src.llm_core import _detect_provider, ANTHROPIC_MODELS from src.llm_core import _detect_provider, ANTHROPIC_MODELS
from src.auth_helpers import owner_filter
keyword = content.strip().lower() if content.strip() else None keyword = content.strip().lower() if content.strip() else None
db = SessionLocal() db = SessionLocal()
try: try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
query = owner_filter(query, ModelEndpoint, owner)
endpoints = query.all()
if not endpoints: if not endpoints:
return {"results": "No enabled model endpoints configured."} return {"results": "No enabled model endpoints configured."}
@@ -1250,7 +1257,7 @@ async def do_manage_rag(content: str, session_id: Optional[str] = None) -> Dict:
# UI control tool (returns events for frontend to apply) # UI control tool (returns events for frontend to apply)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict: async def do_ui_control(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
"""Control frontend UI: toggle settings, switch model, change theme. """Control frontend UI: toggle settings, switch model, change theme.
Content format: Content format:
@@ -1325,7 +1332,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict:
# Resolve the model to validate it exists # Resolve the model to validate it exists
try: try:
url, model_id, headers = _resolve_model(model_spec) url, model_id, headers = _resolve_model(model_spec, owner=owner)
except ValueError as e: except ValueError as e:
return {"error": str(e)} return {"error": str(e)}
@@ -1580,7 +1587,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
if not model_spec: if not model_spec:
for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"): for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"):
try: try:
_resolve_model(candidate) _resolve_model(candidate, owner=owner)
model_spec = candidate model_spec = candidate
break break
except ValueError: except ValueError:
@@ -1589,13 +1596,17 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
if not model_spec: if not model_spec:
try: try:
from src.database import SessionLocal, ModelEndpoint from src.database import SessionLocal, ModelEndpoint
from src.auth_helpers import owner_filter
import httpx as _req import httpx as _req
_idb = SessionLocal() _idb = SessionLocal()
try: try:
_img_eps = _idb.query(ModelEndpoint).filter( _img_q = _idb.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True, ModelEndpoint.is_enabled == True,
ModelEndpoint.model_type == "image", ModelEndpoint.model_type == "image",
).all() )
if owner:
_img_q = owner_filter(_img_q, ModelEndpoint, owner)
_img_eps = _img_q.all()
for _iep in _img_eps: for _iep in _img_eps:
_ibase = _iep.base_url.rstrip("/") _ibase = _iep.base_url.rstrip("/")
if not _ibase.endswith("/v1"): if not _ibase.endswith("/v1"):
@@ -1618,7 +1629,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
# Resolve the model to find the right endpoint # Resolve the model to find the right endpoint
try: try:
url, model_id, headers = _resolve_model(model_spec) url, model_id, headers = _resolve_model(model_spec, owner=owner)
except ValueError: except ValueError:
return {"error": f"No endpoint found with image model '{model_spec}'. " return {"error": f"No endpoint found with image model '{model_spec}'. "
"Configure an OpenAI-compatible endpoint with image generation support."} "Configure an OpenAI-compatible endpoint with image generation support."}
@@ -1760,7 +1771,7 @@ async def dispatch_ai_tool(
if tool == "chat_with_model": if tool == "chat_with_model":
model_spec = content.split("\n")[0].strip()[:60] model_spec = content.split("\n")[0].strip()[:60]
desc = f"chat_with_model: {model_spec}" desc = f"chat_with_model: {model_spec}"
result = await do_chat_with_model(content, session_id) result = await do_chat_with_model(content, session_id, owner=owner)
elif tool == "create_session": elif tool == "create_session":
name = content.split("\n")[0].strip()[:60] name = content.split("\n")[0].strip()[:60]
@@ -1779,7 +1790,7 @@ async def dispatch_ai_tool(
elif tool == "pipeline": elif tool == "pipeline":
desc = "pipeline: running steps" desc = "pipeline: running steps"
result = await do_pipeline(content, session_id) result = await do_pipeline(content, session_id, owner=owner)
elif tool == "manage_session": elif tool == "manage_session":
action = content.split("\n")[0].strip()[:40] action = content.split("\n")[0].strip()[:40]
@@ -1794,17 +1805,17 @@ async def dispatch_ai_tool(
elif tool == "list_models": elif tool == "list_models":
keyword = content.strip()[:40] keyword = content.strip()[:40]
desc = f"list_models{': ' + keyword if keyword else ''}" desc = f"list_models{': ' + keyword if keyword else ''}"
result = await do_list_models(content, session_id) result = await do_list_models(content, session_id, owner=owner)
elif tool == "ui_control": elif tool == "ui_control":
action = content.split("\n")[0].strip()[:60] action = content.split("\n")[0].strip()[:60]
desc = f"ui_control: {action}" desc = f"ui_control: {action}"
result = await do_ui_control(content, session_id) result = await do_ui_control(content, session_id, owner=owner)
elif tool == "ask_teacher": elif tool == "ask_teacher":
problem = content.split("\n", 1)[-1].strip()[:60] problem = content.split("\n", 1)[-1].strip()[:60]
desc = f"ask_teacher: {problem}" desc = f"ask_teacher: {problem}"
result = await do_ask_teacher(content, session_id) result = await do_ask_teacher(content, session_id, owner=owner)
else: else:
desc = f"unknown ai tool: {tool}" desc = f"unknown ai tool: {tool}"

View File

@@ -0,0 +1,75 @@
import inspect
import pytest
from src import ai_interaction
def _source(fn) -> str:
return inspect.getsource(fn)
def test_model_resolver_applies_owner_filter():
body = _source(ai_interaction._resolve_model)
assert "owner: Optional[str] = None" in body
assert "from src.auth_helpers import owner_filter" in body
assert "owner_filter(query, ModelEndpoint, owner)" in body
def test_model_listing_and_image_fallback_are_owner_scoped():
list_body = _source(ai_interaction.do_list_models)
image_body = _source(ai_interaction.do_generate_image)
assert "owner: Optional[str] = None" in list_body
assert "owner_filter(query, ModelEndpoint, owner)" in list_body
assert "_resolve_model(candidate, owner=owner)" in image_body
assert "owner_filter(_img_q, ModelEndpoint, owner)" in image_body
assert "_resolve_model(model_spec, owner=owner)" in image_body
@pytest.mark.parametrize("tool,content", [
("chat_with_model", "gpt-test\nhello"),
("pipeline", "gpt-test | summarize this"),
("list_models", ""),
("ui_control", "switch_model gpt-test"),
("ask_teacher", "gpt-test\nhelp me"),
])
async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content):
seen = {}
async def capture(name, content, session_id=None, owner=None):
seen[name] = {"content": content, "session_id": session_id, "owner": owner}
return {"ok": True}
monkeypatch.setattr(
ai_interaction,
"do_chat_with_model",
lambda content, session_id=None, owner=None: capture("chat_with_model", content, session_id, owner),
)
monkeypatch.setattr(
ai_interaction,
"do_pipeline",
lambda content, session_id=None, owner=None: capture("pipeline", content, session_id, owner),
)
monkeypatch.setattr(
ai_interaction,
"do_list_models",
lambda content, session_id=None, owner=None: capture("list_models", content, session_id, owner),
)
monkeypatch.setattr(
ai_interaction,
"do_ui_control",
lambda content, session_id=None, owner=None: capture("ui_control", content, session_id, owner),
)
monkeypatch.setattr(
ai_interaction,
"do_ask_teacher",
lambda content, session_id=None, owner=None: capture("ask_teacher", content, session_id, owner),
)
_desc, result = await ai_interaction.dispatch_ai_tool(tool, content, session_id="sid1", owner="alice")
assert result == {"ok": True}
assert seen[tool]["owner"] == "alice"
assert seen[tool]["session_id"] == "sid1"