fix(ai): scope tool model resolution by owner
* Stabilize full test collection * Scope AI tool model resolution by owner
This commit is contained in:
@@ -58,7 +58,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
|||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||||
"""Resolve a model specifier to (endpoint_url, model_id, headers).
|
"""Resolve a model specifier to (endpoint_url, model_id, headers).
|
||||||
|
|
||||||
Accepts:
|
Accepts:
|
||||||
@@ -70,6 +70,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
|||||||
import httpx
|
import httpx
|
||||||
from src.database import SessionLocal, ModelEndpoint
|
from src.database import SessionLocal, ModelEndpoint
|
||||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
|
||||||
spec = spec.strip()
|
spec = spec.strip()
|
||||||
target_endpoint_name = None
|
target_endpoint_name = None
|
||||||
@@ -86,6 +87,8 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
|||||||
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
if target_endpoint_name:
|
if target_endpoint_name:
|
||||||
query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%"))
|
query = query.filter(ModelEndpoint.name.ilike(f"%{target_endpoint_name}%"))
|
||||||
|
if owner:
|
||||||
|
query = owner_filter(query, ModelEndpoint, owner)
|
||||||
endpoints = query.all()
|
endpoints = query.all()
|
||||||
|
|
||||||
if not endpoints:
|
if not endpoints:
|
||||||
@@ -141,7 +144,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
|||||||
# Tool implementations
|
# Tool implementations
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def do_chat_with_model(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Send a message to a specific model and return its response.
|
"""Send a message to a specific model and return its response.
|
||||||
|
|
||||||
Content format:
|
Content format:
|
||||||
@@ -160,7 +163,7 @@ async def do_chat_with_model(content: str, session_id: Optional[str] = None) ->
|
|||||||
return {"error": "No message provided (line 2+ is the message)"}
|
return {"error": "No message provided (line 2+ is the message)"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url, model, headers = _resolve_model(model_spec)
|
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
@@ -190,7 +193,7 @@ _TEACHER_SYSTEM_PROMPT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Ask a more capable model for help.
|
"""Ask a more capable model for help.
|
||||||
|
|
||||||
Content format:
|
Content format:
|
||||||
@@ -213,7 +216,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
|
|||||||
return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."}
|
return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url, model, headers = _resolve_model(model_spec)
|
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
@@ -235,7 +238,7 @@ async def do_ask_teacher(content: str, session_id: Optional[str] = None) -> Dict
|
|||||||
return {"error": f"Teacher call failed ({model_spec}): {e}"}
|
return {"error": f"Teacher call failed ({model_spec}): {e}"}
|
||||||
|
|
||||||
|
|
||||||
async def do_second_opinion(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_second_opinion(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Get a second opinion from another model, then have the original model
|
"""Get a second opinion from another model, then have the original model
|
||||||
evaluate the feedback and produce a unified version.
|
evaluate the feedback and produce a unified version.
|
||||||
|
|
||||||
@@ -259,7 +262,7 @@ async def do_second_opinion(content: str, session_id: Optional[str] = None) -> D
|
|||||||
focus = lines[1].strip() if len(lines) > 1 else ""
|
focus = lines[1].strip() if len(lines) > 1 else ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec)
|
reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
@@ -400,7 +403,7 @@ async def do_create_session(content: str, session_id: Optional[str] = None, owne
|
|||||||
return {"error": "Session name cannot be empty"}
|
return {"error": "Session name cannot be empty"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url, model, headers = _resolve_model(model_spec)
|
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
@@ -584,7 +587,7 @@ async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = No
|
|||||||
yield {"_final": True, "desc": desc, "result": result}
|
yield {"_final": True, "desc": desc, "result": result}
|
||||||
|
|
||||||
|
|
||||||
async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Execute a multi-step pipeline where each model's output feeds the next.
|
"""Execute a multi-step pipeline where each model's output feeds the next.
|
||||||
|
|
||||||
Content format (JSON):
|
Content format (JSON):
|
||||||
@@ -638,7 +641,7 @@ async def do_pipeline(content: str, session_id: Optional[str] = None) -> Dict:
|
|||||||
if not model_spec or not instruction:
|
if not model_spec or not instruction:
|
||||||
return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"}
|
return {"error": f"Step {i + 1}: both 'model' and 'instruction' are required"}
|
||||||
try:
|
try:
|
||||||
url, model, headers = _resolve_model(model_spec)
|
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||||
resolved.append((url, model, headers, instruction))
|
resolved.append((url, model, headers, instruction))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": f"Step {i + 1}: {e}"}
|
return {"error": f"Step {i + 1}: {e}"}
|
||||||
@@ -1091,7 +1094,7 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner
|
|||||||
# List models tool
|
# List models tool
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""List all available models across configured endpoints.
|
"""List all available models across configured endpoints.
|
||||||
|
|
||||||
Content = optional filter keyword.
|
Content = optional filter keyword.
|
||||||
@@ -1099,12 +1102,16 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict
|
|||||||
import httpx
|
import httpx
|
||||||
from src.database import SessionLocal, ModelEndpoint
|
from src.database import SessionLocal, ModelEndpoint
|
||||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
|
||||||
keyword = content.strip().lower() if content.strip() else None
|
keyword = content.strip().lower() if content.strip() else None
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
if owner:
|
||||||
|
query = owner_filter(query, ModelEndpoint, owner)
|
||||||
|
endpoints = query.all()
|
||||||
if not endpoints:
|
if not endpoints:
|
||||||
return {"results": "No enabled model endpoints configured."}
|
return {"results": "No enabled model endpoints configured."}
|
||||||
|
|
||||||
@@ -1250,7 +1257,7 @@ async def do_manage_rag(content: str, session_id: Optional[str] = None) -> Dict:
|
|||||||
# UI control tool (returns events for frontend to apply)
|
# UI control tool (returns events for frontend to apply)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict:
|
async def do_ui_control(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||||
"""Control frontend UI: toggle settings, switch model, change theme.
|
"""Control frontend UI: toggle settings, switch model, change theme.
|
||||||
|
|
||||||
Content format:
|
Content format:
|
||||||
@@ -1325,7 +1332,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None) -> Dict:
|
|||||||
|
|
||||||
# Resolve the model to validate it exists
|
# Resolve the model to validate it exists
|
||||||
try:
|
try:
|
||||||
url, model_id, headers = _resolve_model(model_spec)
|
url, model_id, headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
@@ -1580,7 +1587,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
if not model_spec:
|
if not model_spec:
|
||||||
for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"):
|
for candidate in ("gpt-image-1.5", "gpt-image-1", "dall-e-3"):
|
||||||
try:
|
try:
|
||||||
_resolve_model(candidate)
|
_resolve_model(candidate, owner=owner)
|
||||||
model_spec = candidate
|
model_spec = candidate
|
||||||
break
|
break
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -1589,13 +1596,17 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
if not model_spec:
|
if not model_spec:
|
||||||
try:
|
try:
|
||||||
from src.database import SessionLocal, ModelEndpoint
|
from src.database import SessionLocal, ModelEndpoint
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
import httpx as _req
|
import httpx as _req
|
||||||
_idb = SessionLocal()
|
_idb = SessionLocal()
|
||||||
try:
|
try:
|
||||||
_img_eps = _idb.query(ModelEndpoint).filter(
|
_img_q = _idb.query(ModelEndpoint).filter(
|
||||||
ModelEndpoint.is_enabled == True,
|
ModelEndpoint.is_enabled == True,
|
||||||
ModelEndpoint.model_type == "image",
|
ModelEndpoint.model_type == "image",
|
||||||
).all()
|
)
|
||||||
|
if owner:
|
||||||
|
_img_q = owner_filter(_img_q, ModelEndpoint, owner)
|
||||||
|
_img_eps = _img_q.all()
|
||||||
for _iep in _img_eps:
|
for _iep in _img_eps:
|
||||||
_ibase = _iep.base_url.rstrip("/")
|
_ibase = _iep.base_url.rstrip("/")
|
||||||
if not _ibase.endswith("/v1"):
|
if not _ibase.endswith("/v1"):
|
||||||
@@ -1618,7 +1629,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
|
|
||||||
# Resolve the model to find the right endpoint
|
# Resolve the model to find the right endpoint
|
||||||
try:
|
try:
|
||||||
url, model_id, headers = _resolve_model(model_spec)
|
url, model_id, headers = _resolve_model(model_spec, owner=owner)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return {"error": f"No endpoint found with image model '{model_spec}'. "
|
return {"error": f"No endpoint found with image model '{model_spec}'. "
|
||||||
"Configure an OpenAI-compatible endpoint with image generation support."}
|
"Configure an OpenAI-compatible endpoint with image generation support."}
|
||||||
@@ -1760,7 +1771,7 @@ async def dispatch_ai_tool(
|
|||||||
if tool == "chat_with_model":
|
if tool == "chat_with_model":
|
||||||
model_spec = content.split("\n")[0].strip()[:60]
|
model_spec = content.split("\n")[0].strip()[:60]
|
||||||
desc = f"chat_with_model: {model_spec}"
|
desc = f"chat_with_model: {model_spec}"
|
||||||
result = await do_chat_with_model(content, session_id)
|
result = await do_chat_with_model(content, session_id, owner=owner)
|
||||||
|
|
||||||
elif tool == "create_session":
|
elif tool == "create_session":
|
||||||
name = content.split("\n")[0].strip()[:60]
|
name = content.split("\n")[0].strip()[:60]
|
||||||
@@ -1779,7 +1790,7 @@ async def dispatch_ai_tool(
|
|||||||
|
|
||||||
elif tool == "pipeline":
|
elif tool == "pipeline":
|
||||||
desc = "pipeline: running steps"
|
desc = "pipeline: running steps"
|
||||||
result = await do_pipeline(content, session_id)
|
result = await do_pipeline(content, session_id, owner=owner)
|
||||||
|
|
||||||
elif tool == "manage_session":
|
elif tool == "manage_session":
|
||||||
action = content.split("\n")[0].strip()[:40]
|
action = content.split("\n")[0].strip()[:40]
|
||||||
@@ -1794,17 +1805,17 @@ async def dispatch_ai_tool(
|
|||||||
elif tool == "list_models":
|
elif tool == "list_models":
|
||||||
keyword = content.strip()[:40]
|
keyword = content.strip()[:40]
|
||||||
desc = f"list_models{': ' + keyword if keyword else ''}"
|
desc = f"list_models{': ' + keyword if keyword else ''}"
|
||||||
result = await do_list_models(content, session_id)
|
result = await do_list_models(content, session_id, owner=owner)
|
||||||
|
|
||||||
elif tool == "ui_control":
|
elif tool == "ui_control":
|
||||||
action = content.split("\n")[0].strip()[:60]
|
action = content.split("\n")[0].strip()[:60]
|
||||||
desc = f"ui_control: {action}"
|
desc = f"ui_control: {action}"
|
||||||
result = await do_ui_control(content, session_id)
|
result = await do_ui_control(content, session_id, owner=owner)
|
||||||
|
|
||||||
elif tool == "ask_teacher":
|
elif tool == "ask_teacher":
|
||||||
problem = content.split("\n", 1)[-1].strip()[:60]
|
problem = content.split("\n", 1)[-1].strip()[:60]
|
||||||
desc = f"ask_teacher: {problem}"
|
desc = f"ask_teacher: {problem}"
|
||||||
result = await do_ask_teacher(content, session_id)
|
result = await do_ask_teacher(content, session_id, owner=owner)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
desc = f"unknown ai tool: {tool}"
|
desc = f"unknown ai tool: {tool}"
|
||||||
|
|||||||
75
tests/test_ai_interaction_owner_scope.py
Normal file
75
tests/test_ai_interaction_owner_scope.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src import ai_interaction
|
||||||
|
|
||||||
|
|
||||||
|
def _source(fn) -> str:
|
||||||
|
return inspect.getsource(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_resolver_applies_owner_filter():
|
||||||
|
body = _source(ai_interaction._resolve_model)
|
||||||
|
|
||||||
|
assert "owner: Optional[str] = None" in body
|
||||||
|
assert "from src.auth_helpers import owner_filter" in body
|
||||||
|
assert "owner_filter(query, ModelEndpoint, owner)" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_listing_and_image_fallback_are_owner_scoped():
|
||||||
|
list_body = _source(ai_interaction.do_list_models)
|
||||||
|
image_body = _source(ai_interaction.do_generate_image)
|
||||||
|
|
||||||
|
assert "owner: Optional[str] = None" in list_body
|
||||||
|
assert "owner_filter(query, ModelEndpoint, owner)" in list_body
|
||||||
|
assert "_resolve_model(candidate, owner=owner)" in image_body
|
||||||
|
assert "owner_filter(_img_q, ModelEndpoint, owner)" in image_body
|
||||||
|
assert "_resolve_model(model_spec, owner=owner)" in image_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tool,content", [
|
||||||
|
("chat_with_model", "gpt-test\nhello"),
|
||||||
|
("pipeline", "gpt-test | summarize this"),
|
||||||
|
("list_models", ""),
|
||||||
|
("ui_control", "switch_model gpt-test"),
|
||||||
|
("ask_teacher", "gpt-test\nhelp me"),
|
||||||
|
])
|
||||||
|
async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content):
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
async def capture(name, content, session_id=None, owner=None):
|
||||||
|
seen[name] = {"content": content, "session_id": session_id, "owner": owner}
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ai_interaction,
|
||||||
|
"do_chat_with_model",
|
||||||
|
lambda content, session_id=None, owner=None: capture("chat_with_model", content, session_id, owner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ai_interaction,
|
||||||
|
"do_pipeline",
|
||||||
|
lambda content, session_id=None, owner=None: capture("pipeline", content, session_id, owner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ai_interaction,
|
||||||
|
"do_list_models",
|
||||||
|
lambda content, session_id=None, owner=None: capture("list_models", content, session_id, owner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ai_interaction,
|
||||||
|
"do_ui_control",
|
||||||
|
lambda content, session_id=None, owner=None: capture("ui_control", content, session_id, owner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ai_interaction,
|
||||||
|
"do_ask_teacher",
|
||||||
|
lambda content, session_id=None, owner=None: capture("ask_teacher", content, session_id, owner),
|
||||||
|
)
|
||||||
|
|
||||||
|
_desc, result = await ai_interaction.dispatch_ai_tool(tool, content, session_id="sid1", owner="alice")
|
||||||
|
|
||||||
|
assert result == {"ok": True}
|
||||||
|
assert seen[tool]["owner"] == "alice"
|
||||||
|
assert seen[tool]["session_id"] == "sid1"
|
||||||
Reference in New Issue
Block a user