diff --git a/routes/research_routes.py b/routes/research_routes.py index c075002..267ab50 100644 --- a/routes/research_routes.py +++ b/routes/research_routes.py @@ -48,6 +48,30 @@ def _resolve_research_endpoint(sess) -> tuple: return url, model, headers +def _owned_enabled_endpoint(db, owner, endpoint_id=None): + """An enabled ModelEndpoint VISIBLE to `owner` (their own rows + legacy + null-owner "shared" rows), optionally narrowed to a specific endpoint_id; + None if nothing visible matches. + + Owner-scoped on purpose. ModelEndpoint is per-user (core/database.py: non-null + owner = private, "the model picker only shows the endpoint to that user") and + holds a decrypted `api_key`. /api/research/start feeds the resolved row's + api_key + base_url into research_handler.start_research(llm_endpoint=, + llm_headers=), so an UNSCOPED lookup — by the caller-supplied endpoint_id, or + via the bare first-enabled fallback — would let a research-privileged user + spend ANOTHER user's API key/quota and reach whatever internal base_url they + configured. Mirrors webhook_routes._first_enabled_endpoint and + session_routes._owned_endpoint. A null/empty owner is a no-op (single-user / + legacy mode). + """ + from src.database import ModelEndpoint + from src.auth_helpers import owner_filter + q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712 + if endpoint_id: + q = q.filter(ModelEndpoint.id == endpoint_id) + return owner_filter(q, ModelEndpoint, owner).first() + + def setup_research_routes(research_handler, session_manager=None) -> APIRouter: router = APIRouter(tags=["research"]) @@ -344,14 +368,13 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: if body.endpoint_id: from src.database import SessionLocal - from src.database import ModelEndpoint from src.endpoint_resolver import normalize_base, build_chat_url, build_headers db = SessionLocal() try: - ep = db.query(ModelEndpoint).filter( - ModelEndpoint.id == body.endpoint_id, - ModelEndpoint.is_enabled == True, - ).first() + # Owner-scoped: never resolve another user's private endpoint + # (and its decrypted api_key / internal base_url). A scoped miss + # reads as 404 so the endpoint's existence isn't revealed. + ep = _owned_enabled_endpoint(db, user, body.endpoint_id) if not ep: raise HTTPException(404, "Endpoint not found or disabled") base = normalize_base(ep.base_url) @@ -382,13 +405,14 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: ep_url, ep_model, ep_headers = resolve_endpoint("chat") if not ep_url: from src.database import SessionLocal - from src.database import ModelEndpoint from src.endpoint_resolver import normalize_base, build_chat_url, build_headers db = SessionLocal() try: - ep = db.query(ModelEndpoint).filter( - ModelEndpoint.is_enabled == True, - ).first() + # Owner-scoped first-enabled fallback: the caller's own rows + # + legacy null-owner shared rows only — never borrow another + # user's private endpoint/api_key. Same fix as the + # /api/v1/chat fallback (webhook_routes._first_enabled_endpoint). + ep = _owned_enabled_endpoint(db, user) if ep: base = normalize_base(ep.base_url) ep_url = build_chat_url(base) diff --git a/tests/test_research_endpoint_owner_scope.py b/tests/test_research_endpoint_owner_scope.py new file mode 100644 index 0000000..baa71d3 --- /dev/null +++ b/tests/test_research_endpoint_owner_scope.py @@ -0,0 +1,131 @@ +"""Owner-scope regression for /api/research/start endpoint resolution. + +`research_start()` resolves a CALLER-SUPPLIED `endpoint_id` (and, with nothing +configured, a bare first-enabled fallback) to a `ModelEndpoint` whose *decrypted* +api_key + base_url then drive the research LLM calls +(`start_research(llm_endpoint=, llm_headers=)`). Both lookups must be +owner-scoped — the caller's own rows plus legacy null-owner ("shared") rows — +so a research-privileged user (or a chat-scoped token) can't bind a research run +to ANOTHER user's PRIVATE endpoint and silently spend that owner's API key / +reach whatever internal base_url they configured. Mirrors the +webhook `_first_enabled_endpoint` (#1045) and session `_owned_endpoint` fixes. +""" + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +# The helper resolves `from src.database import ModelEndpoint` at call time. +# Stub the module so we can hand it a fake declarative class whose column +# comparisons return inspectable predicates (the real one is a SQLAlchemy +# class, MagicMock'd to oblivion by conftest). owner_filter stays REAL. +_sd = types.ModuleType("src.database") +_sd.ModelEndpoint = MagicMock() +sys.modules.setdefault("src.database", _sd) + +from routes.research_routes import _owned_enabled_endpoint # noqa: E402 + + +class _Predicate: + def __init__(self, check): + self._check = check + + def __call__(self, row): + return self._check(row) + + def __or__(self, other): + return _Predicate(lambda row: self(row) or other(row)) + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, value): + return _Predicate(lambda row: getattr(row, self.name) == value) + + +class _ModelEndpoint: + id = _Column("id") + is_enabled = _Column("is_enabled") + owner = _Column("owner") + + +class _Query: + def __init__(self, rows): + self._rows = list(rows) + + def filter(self, *predicates): + self._rows = [r for r in self._rows if all(p(r) for p in predicates)] + return self + + def first(self): + return self._rows[0] if self._rows else None + + +class _DB: + def __init__(self, rows): + self._rows = rows + + def query(self, model): + assert model is _ModelEndpoint + return _Query(self._rows) + + +def _ep(eid, owner, *, is_enabled=True): + return SimpleNamespace(id=eid, owner=owner, is_enabled=is_enabled, api_key="sk-secret") + + +def _resolve(rows, owner, endpoint_id=None): + sys.modules["src.database"].ModelEndpoint = _ModelEndpoint + return _owned_enabled_endpoint(_DB(rows), owner, endpoint_id) + + +# --- explicit endpoint_id (POST /api/research/start, body.endpoint_id) -------- + +def test_endpoint_id_rejects_another_owners_private_endpoint(): + # bob's private endpoint exists, but alice asking for it by id resolves None + # → the route raises 404 ("Endpoint not found or disabled"), never builds + # headers from bob's key. + rows = [_ep("ep-bob", "bob"), _ep("ep-alice", "alice")] + assert _resolve(rows, "alice", "ep-bob") is None + + +def test_endpoint_id_returns_callers_own_endpoint(): + rows = [_ep("ep-bob", "bob"), _ep("ep-alice", "alice")] + ep = _resolve(rows, "alice", "ep-alice") + assert ep is not None and ep.id == "ep-alice" + + +def test_endpoint_id_allows_legacy_null_owner_shared_row(): + rows = [_ep("ep-shared", None)] + ep = _resolve(rows, "alice", "ep-shared") + assert ep is not None and ep.id == "ep-shared" + + +def test_endpoint_id_skips_disabled_even_when_owned(): + rows = [_ep("ep-alice", "alice", is_enabled=False)] + assert _resolve(rows, "alice", "ep-alice") is None + + +# --- bare first-enabled fallback (no endpoint_id, nothing configured) --------- + +def test_fallback_never_picks_another_owners_endpoint(): + # bob's private endpoint is first in the table, alice must never borrow it. + rows = [_ep("ep-bob", "bob"), _ep("ep-shared", None)] + ep = _resolve(rows, "alice") + assert ep is not None and ep.id == "ep-shared" + + +def test_fallback_returns_none_when_only_others_endpoints(): + rows = [_ep("ep-bob", "bob"), _ep("ep-carol", "carol")] + assert _resolve(rows, "alice") is None + + +# --- legacy single-user / unresolved owner: owner_filter no-op --------------- + +def test_null_owner_is_legacy_single_user_noop(): + rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")] + ep = _resolve(rows, None, "ep-x") + assert ep is not None and ep.id == "ep-x"