diff --git a/routes/compare_routes.py b/routes/compare_routes.py index 2d06e95..7886762 100644 --- a/routes/compare_routes.py +++ b/routes/compare_routes.py @@ -18,6 +18,26 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/compare", tags=["compare"]) +def _owned_endpoint_by_url(db, base_url, owner): + """ModelEndpoint whose base_url == `base_url` and is VISIBLE to `owner` + (their own rows + legacy null-owner "shared" rows); None otherwise. + + Owner-scoped on purpose. ModelEndpoint is per-user (core/database.py: non-null + owner = private, "the model picker only shows the endpoint to that user") and + holds a decrypted `api_key`. start_comparison copies the matched row's api_key + into the caller-owned [CMP] session's headers, which then drives that session's + /api/chat_stream calls — so an UNSCOPED base_url match would let a user mint a + comparison bound to ANOTHER user's private endpoint and spend that owner's + api_key / reach whatever base_url they configured. Mirrors + session_routes._owned_endpoint. A null/empty owner is a no-op (single-user / + legacy mode). + """ + from core.database import ModelEndpoint + from src.auth_helpers import owner_filter + q = db.query(ModelEndpoint).filter(ModelEndpoint.base_url == base_url) + return owner_filter(q, ModelEndpoint, owner).first() + + class RecordVoteRequest(BaseModel): prompt: str models: List[str] @@ -61,13 +81,11 @@ def setup_compare_routes(session_manager: SessionManager): # Copy API key from endpoint config db = SessionLocal() try: - from core.database import ModelEndpoint from src.endpoint_resolver import build_headers, normalize_base - # Find matching endpoint by URL + # Find matching endpoint by URL, scoped to the caller so a + # comparison can't borrow another user's private endpoint key. base = normalize_base(endpoint) - ep = db.query(ModelEndpoint).filter( - ModelEndpoint.base_url == base - ).first() + ep = _owned_endpoint_by_url(db, base, user) if ep and ep.api_key: s = session_manager.sessions.get(sid) if s: diff --git a/tests/test_compare_endpoint_owner_scope.py b/tests/test_compare_endpoint_owner_scope.py new file mode 100644 index 0000000..42a016c --- /dev/null +++ b/tests/test_compare_endpoint_owner_scope.py @@ -0,0 +1,121 @@ +"""Owner-scope regression for /api/compare/start endpoint-key resolution. + +start_comparison() takes caller-supplied endpoint URLs (endpoint_a/endpoint_b), +matches a ModelEndpoint by base_url, and copies that row's *decrypted* api_key +into the caller-owned [CMP] session's headers — which then drive that session's +/api/chat_stream calls. The match must be owner-scoped (the caller's own rows + +legacy null-owner shared rows) so a user can't mint a comparison bound to +ANOTHER user's private endpoint and spend their api_key / reach their base_url. +Mirrors the session `_owned_endpoint` and research `_owned_enabled_endpoint` +fixes. +""" + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +# Stub core.database so importing routes.compare_routes (which drags in +# core.session_manager) is cheap under the sqlalchemy MagicMock stubs. The +# helper resolves ModelEndpoint at call time; we swap in a fake declarative +# class below. owner_filter stays REAL. +if "core.database" not in sys.modules: + sys.modules["core.database"] = types.ModuleType("core.database") +_cd = sys.modules["core.database"] +_cd.Base = MagicMock() +for _name in ( + "Session", "ChatMessage", "Document", "DocumentVersion", "GalleryImage", + "GalleryAlbum", "SessionLocal", "Comparison", "ModelEndpoint", +): + if not hasattr(_cd, _name): + setattr(_cd, _name, MagicMock()) + +from routes.compare_routes import _owned_endpoint_by_url # noqa: E402 + + +class _Predicate: + def __init__(self, check): + self._check = check + + def __call__(self, row): + return self._check(row) + + def __or__(self, other): + return _Predicate(lambda row: self(row) or other(row)) + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, value): + return _Predicate(lambda row: getattr(row, self.name) == value) + + +class _ModelEndpoint: + base_url = _Column("base_url") + owner = _Column("owner") + + +class _Query: + def __init__(self, rows): + self._rows = list(rows) + + def filter(self, *predicates): + self._rows = [r for r in self._rows if all(p(r) for p in predicates)] + return self + + def first(self): + return self._rows[0] if self._rows else None + + +class _DB: + def __init__(self, rows): + self._rows = rows + + def query(self, model): + assert model is _ModelEndpoint + return _Query(self._rows) + + +def _ep(base_url, owner): + return SimpleNamespace(base_url=base_url, owner=owner, api_key="sk-secret") + + +def _resolve(rows, base_url, owner): + sys.modules["core.database"].ModelEndpoint = _ModelEndpoint + return _owned_endpoint_by_url(_DB(rows), base_url, owner) + + +URL = "https://api.example.com/v1" + + +def test_rejects_another_owners_private_endpoint(): + # bob owns the only endpoint at URL; alice supplying that URL gets None + # → no headers, no key copied into her comparison session. + rows = [_ep(URL, "bob")] + assert _resolve(rows, URL, "alice") is None + + +def test_returns_callers_own_endpoint(): + rows = [_ep(URL, "bob"), _ep(URL, "alice")] + ep = _resolve(rows, URL, "alice") + assert ep is not None and ep.owner == "alice" + + +def test_allows_legacy_null_owner_shared_row(): + rows = [_ep(URL, None)] + ep = _resolve(rows, URL, "alice") + assert ep is not None and ep.owner is None + + +def test_no_match_returns_none(): + rows = [_ep("https://other.example/v1", "alice")] + assert _resolve(rows, URL, "alice") is None + + +def test_null_owner_is_legacy_single_user_noop(): + # Single-user / unresolved owner: owner_filter no-op, exact URL match wins. + rows = [_ep(URL, "bob")] + ep = _resolve(rows, URL, None) + assert ep is not None and ep.owner == "bob"