fix(ai): scope tool model resolution by owner
* Stabilize full test collection * Scope AI tool model resolution by owner
This commit is contained in:
@@ -58,7 +58,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||
|
||||
|
||||
def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||
"""Resolve a model specifier to (endpoint_url, model_id, headers).
|
||||
|
||||
Accepts:
|
||||
@@ -70,6 +70,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
import httpx
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||
from src.auth_helpers import owner_filter
|
||||
|
||||
spec = spec.strip()
|
||||
target_endpoint_name = None
|
||||
@@ -86,6 +87,8 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if target_endpoint_name:
|
||||
query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%"))
|
||||
if owner:
|
||||
query = owner_filter(query, ModelEndpoint, owner)
|
||||
endpoints = query.all()
|
||||
|
||||
if not endpoints:
|
||||
@@ -141,7 +144,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
# Tool implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_chat_with_model(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Send a message to a specific model and return its response.
|
||||
|
||||
Content format:
|
||||
@@ -160,7 +163,7 @@ async def do_chat_with_model(content: str, session_id: Optional[str] = None) ->
|
||||
return {"error": "No message provided (line 2+ is the message)"}
|
||||
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -190,7 +193,7 @@ _TEACHER_SYSTEM_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Ask a more capable model for help.
|
||||
|
||||
Content format:
|
||||
@@ -213,7 +216,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
|
||||
return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."}
|
||||
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -235,7 +238,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
|
||||
return {"error": f"Teacher call failed ({model_spec}): {e}"}
|
||||
|
||||
|
||||
async def do_second_opinion(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_second_opinion(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Get a second opinion from another model, then have the original model
|
||||
evaluate the feedback and produce a unified version.
|
||||
|
||||
@@ -259,7 +262,7 @@ async def do_second_opinion(content: str, session_id: Optional[str] = None) -> D
|
||||
focus = lines[1].strip() if len(lines) > 1 else ""
|
||||
|
||||
try:
|
||||
reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec)
|
||||
reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -400,7 +403,7 @@ async def do_create_session(content: str, session_id: Optional[str] = None, owne
|
||||
return {"error": "Session name cannot be empty"}
|
||||
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -584,7 +587,7 @@ async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = No
|
||||
yield {"_final": True, "desc": desc, "result": result}
|
||||
|
||||
|
||||
async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Execute a multi-step pipeline where each model's output feeds the next.
|
||||
|
||||
Content format (JSON):
|
||||
@@ -638,7 +641,7 @@ async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
if not model_spec or not instruction:
|
||||
return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"}
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
resolved.append((url, model, headers, instruction))
|
||||
except ValueError as e:
|
||||
return {"error": f"Step {i + 1}: {e}"}
|
||||
@@ -1091,7 +1094,7 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner
|
||||
# List models tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""List all available models across configured endpoints.
|
||||
|
||||
Content = optional filter keyword.
|
||||
@@ -1099,12 +1102,16 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict
|
||||
import httpx
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||
from src.auth_helpers import owner_filter
|
||||
|
||||
keyword = content.strip().lower() if content.strip() else None
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
query = owner_filter(query, ModelEndpoint, owner)
|
||||
endpoints = query.all()
|
||||
if not endpoints:
|
||||
return {"results": "No enabled model endpoints configured."}
|
||||
|
||||
@@ -1250,7 +1257,7 @@ async def do_manage_rag(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
# UI control tool (returns events for frontend to apply)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
async def do_ui_control(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Control frontend UI: toggle settings, switch model, change theme.
|
||||
|
||||
Content format:
|
||||
@@ -1325,7 +1332,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict:
|
||||
|
||||
# Resolve the model to validate it exists
|
||||
try:
|
||||
url, model_id, headers = _resolve_model(model_spec)
|
||||
url, model_id, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -1580,7 +1587,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
||||
if not model_spec:
|
||||
for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"):
|
||||
try:
|
||||
_resolve_model(candidate)
|
||||
_resolve_model(candidate, owner=owner)
|
||||
model_spec = candidate
|
||||
break
|
||||
except ValueError:
|
||||
@@ -1589,13 +1596,17 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
||||
if not model_spec:
|
||||
try:
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
import httpx as _req
|
||||
_idb = SessionLocal()
|
||||
try:
|
||||
_img_eps = _idb.query(ModelEndpoint).filter(
|
||||
_img_q = _idb.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.model_type == "image",
|
||||
).all()
|
||||
)
|
||||
if owner:
|
||||
_img_q = owner_filter(_img_q, ModelEndpoint, owner)
|
||||
_img_eps = _img_q.all()
|
||||
for _iep in _img_eps:
|
||||
_ibase = _iep.base_url.rstrip("/")
|
||||
if not _ibase.endswith("/v1"):
|
||||
@@ -1618,7 +1629,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
||||
|
||||
# Resolve the model to find the right endpoint
|
||||
try:
|
||||
url, model_id, headers = _resolve_model(model_spec)
|
||||
url, model_id, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError:
|
||||
return {"error": f"No endpoint found with image model '{model_spec}'. "
|
||||
"Configure an OpenAI-compatible endpoint with image generation support."}
|
||||
@@ -1760,7 +1771,7 @@ async def dispatch_ai_tool(
|
||||
if tool == "chat_with_model":
|
||||
model_spec = content.split("\n")[0].strip()[:60]
|
||||
desc = f"chat_with_model: {model_spec}"
|
||||
result = await do_chat_with_model(content, session_id)
|
||||
result = await do_chat_with_model(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "create_session":
|
||||
name = content.split("\n")[0].strip()[:60]
|
||||
@@ -1779,7 +1790,7 @@ async def dispatch_ai_tool(
|
||||
|
||||
elif tool == "pipeline":
|
||||
desc = "pipeline: running steps"
|
||||
result = await do_pipeline(content, session_id)
|
||||
result = await do_pipeline(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "manage_session":
|
||||
action = content.split("\n")[0].strip()[:40]
|
||||
@@ -1794,17 +1805,17 @@ async def dispatch_ai_tool(
|
||||
elif tool == "list_models":
|
||||
keyword = content.strip()[:40]
|
||||
desc = f"list_models{': ' + keyword if keyword else ''}"
|
||||
result = await do_list_models(content, session_id)
|
||||
result = await do_list_models(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "ui_control":
|
||||
action = content.split("\n")[0].strip()[:60]
|
||||
desc = f"ui_control: {action}"
|
||||
result = await do_ui_control(content, session_id)
|
||||
result = await do_ui_control(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "ask_teacher":
|
||||
problem = content.split("\n", 1)[-1].strip()[:60]
|
||||
desc = f"ask_teacher: {problem}"
|
||||
result = await do_ask_teacher(content, session_id)
|
||||
result = await do_ask_teacher(content, session_id, owner=owner)
|
||||
|
||||
else:
|
||||
desc = f"unknown ai tool: {tool}"
|
||||
|
||||
75
tests/test_ai_interaction_owner_scope.py
Normal file
75
tests/test_ai_interaction_owner_scope.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from src import ai_interaction
|
||||
|
||||
|
||||
def _source(fn) -> str:
|
||||
return inspect.getsource(fn)
|
||||
|
||||
|
||||
def test_model_resolver_applies_owner_filter():
|
||||
body = _source(ai_interaction._resolve_model)
|
||||
|
||||
assert "owner: Optional[str] = None" in body
|
||||
assert "from src.auth_helpers import owner_filter" in body
|
||||
assert "owner_filter(query, ModelEndpoint, owner)" in body
|
||||
|
||||
|
||||
def test_model_listing_and_image_fallback_are_owner_scoped():
|
||||
list_body = _source(ai_interaction.do_list_models)
|
||||
image_body = _source(ai_interaction.do_generate_image)
|
||||
|
||||
assert "owner: Optional[str] = None" in list_body
|
||||
assert "owner_filter(query, ModelEndpoint, owner)" in list_body
|
||||
assert "_resolve_model(candidate, owner=owner)" in image_body
|
||||
assert "owner_filter(_img_q, ModelEndpoint, owner)" in image_body
|
||||
assert "_resolve_model(model_spec, owner=owner)" in image_body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool,content", [
|
||||
("chat_with_model", "gpt-test\nhello"),
|
||||
("pipeline", "gpt-test | summarize this"),
|
||||
("list_models", ""),
|
||||
("ui_control", "switch_model gpt-test"),
|
||||
("ask_teacher", "gpt-test\nhelp me"),
|
||||
])
|
||||
async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content):
|
||||
seen = {}
|
||||
|
||||
async def capture(name, content, session_id=None, owner=None):
|
||||
seen[name] = {"content": content, "session_id": session_id, "owner": owner}
|
||||
return {"ok": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
ai_interaction,
|
||||
"do_chat_with_model",
|
||||
lambda content, session_id=None, owner=None: capture("chat_with_model", content, session_id, owner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ai_interaction,
|
||||
"do_pipeline",
|
||||
lambda content, session_id=None, owner=None: capture("pipeline", content, session_id, owner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ai_interaction,
|
||||
"do_list_models",
|
||||
lambda content, session_id=None, owner=None: capture("list_models", content, session_id, owner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ai_interaction,
|
||||
"do_ui_control",
|
||||
lambda content, session_id=None, owner=None: capture("ui_control", content, session_id, owner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ai_interaction,
|
||||
"do_ask_teacher",
|
||||
lambda content, session_id=None, owner=None: capture("ask_teacher", content, session_id, owner),
|
||||
)
|
||||
|
||||
_desc, result = await ai_interaction.dispatch_ai_tool(tool, content, session_id="sid1", owner="alice")
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert seen[tool]["owner"] == "alice"
|
||||
assert seen[tool]["session_id"] == "sid1"
|
||||
Reference in New Issue
Block a user