diff --git a/tests/test_companion_readonly.py b/tests/test_companion_readonly.py index 2ea1edb..357cbab 100644 --- a/tests/test_companion_readonly.py +++ b/tests/test_companion_readonly.py @@ -9,6 +9,7 @@ rows, and legacy null-owner rows must not widen a token's access. import os import sys import types +import json from types import SimpleNamespace from unittest.mock import MagicMock @@ -24,13 +25,129 @@ if "core.database" not in sys.modules: _db.ModelEndpoint = MagicMock() sys.modules["core.database"] = _db -from companion.routes import token_owner, owner_can_see +import companion.routes as companion_routes +from companion.routes import setup_companion_routes, token_owner, owner_can_see def _request(**state): return SimpleNamespace(state=SimpleNamespace(**state)) +class _Predicate: + def __init__(self, check): + self._check = check + + def __call__(self, row): + return self._check(row) + + def __or__(self, other): + return _Predicate(lambda row: self(row) or other(row)) + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, value): # noqa: D401 + return _Predicate(lambda row: getattr(row, self.name) == value) + + +class _ModelEndpoint: + is_enabled = _Column("is_enabled") + model_type = _Column("model_type") + owner = _Column("owner") + + +class _Query: + def __init__(self, rows): + self._rows = list(rows) + + def filter(self, *predicates): + self._rows = [ + row for row in self._rows + if all(predicate(row) for predicate in predicates) + ] + return self + + def all(self): + return list(self._rows) + + +class _DB: + def __init__(self, rows): + self._rows = rows + self.closed = False + + def query(self, model): + assert model is _ModelEndpoint + return _Query(self._rows) + + def close(self): + self.closed = True + + +def _ep( + id, + name, + owner, + *, + is_enabled=True, + model_type="llm", + base_url=None, + cached_models=None, + hidden_models=None, + supports_tools=False, + api_key="secret-key", +): + return SimpleNamespace( + id=id, + name=name, + owner=owner, + is_enabled=is_enabled, + model_type=model_type, + base_url=base_url or f"https://{name}.example/v1", + cached_models=json.dumps(cached_models or [f"{name}-model"]), + hidden_models=json.dumps(hidden_models or []), + supports_tools=supports_tools, + api_key=api_key, + headers={"Authorization": "Bearer secret-header"}, + ) + + +def _models_route(): + for route in setup_companion_routes().routes: + if getattr(route, "path", "") == "/api/companion/models": + assert "GET" in getattr(route, "methods", set()) + return route.endpoint + raise AssertionError("GET /api/companion/models route not found") + + +def _call_models_route(monkeypatch, rows, request): + db = _DB(rows) + db_mod = sys.modules["core.database"] + monkeypatch.setattr(db_mod, "SessionLocal", lambda: db) + monkeypatch.setattr(db_mod, "ModelEndpoint", _ModelEndpoint) + + endpoint_mod = sys.modules.get("src.endpoint_resolver") + if endpoint_mod is None: + endpoint_mod = types.ModuleType("src.endpoint_resolver") + sys.modules["src.endpoint_resolver"] = endpoint_mod + monkeypatch.setattr( + endpoint_mod, + "build_chat_url", + lambda base_url: f"{base_url.rstrip('/')}/chat/completions", + raising=False, + ) + + response = _models_route()(request) + assert db.closed is True + return response["endpoints"] + + +def _endpoint_names(endpoints): + return [endpoint["name"] for endpoint in endpoints] + + # --- token_owner: who a request is attributed to --------------------------- def test_token_owner_bearer_resolves_to_token_owner(): @@ -76,3 +193,133 @@ def test_unauthenticated_owner_sees_only_shared_rows(): # never any owned row. assert owner_can_see(None, None) is True assert owner_can_see("alice", None) is False + + +# --- GET /api/companion/models: route-level scoping ----------------------- + +def test_models_route_scopes_cookie_user_to_owned_and_shared_rows(monkeypatch): + rows = [ + _ep(1, "alice-endpoint", "alice"), + _ep(2, "shared-endpoint", None), + _ep(3, "bob-endpoint", "bob"), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice") + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=False, current_user="ignored"), + ) + + assert _endpoint_names(endpoints) == ["alice-endpoint", "shared-endpoint"] + + +def test_models_route_scopes_api_token_to_token_owner(monkeypatch): + rows = [ + _ep(1, "alice-endpoint", "alice"), + _ep(2, "shared-endpoint", None), + _ep(3, "bob-endpoint", "bob"), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "api") + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=True, api_token_owner="alice", current_user="api"), + ) + + assert _endpoint_names(endpoints) == ["alice-endpoint", "shared-endpoint"] + + +def test_models_route_unresolved_owner_returns_only_shared_rows(monkeypatch): + rows = [ + _ep(1, "alice-endpoint", "alice"), + _ep(2, "shared-endpoint", None), + _ep(3, "bob-endpoint", "bob"), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: None) + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=True, api_token_owner=None, current_user="api"), + ) + + assert _endpoint_names(endpoints) == ["shared-endpoint"] + + +def test_models_route_filters_hidden_models_and_secret_fields(monkeypatch): + rows = [ + _ep( + 1, + "alice-endpoint", + "alice", + base_url="https://alice.example/v1", + cached_models=["visible-model", "hidden-model"], + hidden_models=["hidden-model"], + supports_tools=True, + api_key="super-secret", + ), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice") + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=False, current_user="alice"), + ) + + assert endpoints == [{ + "endpoint_id": 1, + "name": "alice-endpoint", + "endpoint_url": "https://alice.example/v1/chat/completions", + "models": ["visible-model"], + "supports_tools": True, + }] + returned = endpoints[0] + assert "hidden-model" not in returned["models"] + assert set(returned) == { + "endpoint_id", + "name", + "endpoint_url", + "models", + "supports_tools", + } + assert "api_key" not in returned + assert "headers" not in returned + assert "base_url" not in returned + assert "super-secret" not in repr(returned) + + +def test_models_route_filters_disabled_and_non_llm_endpoints(monkeypatch): + rows = [ + _ep(1, "enabled-llm", "alice", is_enabled=True, model_type="llm"), + _ep(2, "legacy-null-type", "alice", is_enabled=True, model_type=None), + _ep(3, "disabled-llm", "alice", is_enabled=False, model_type="llm"), + _ep(4, "image-endpoint", "alice", is_enabled=True, model_type="image"), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice") + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=False, current_user="alice"), + ) + + assert _endpoint_names(endpoints) == ["enabled-llm", "legacy-null-type"] + + +def test_models_route_returns_built_chat_url(monkeypatch): + rows = [ + _ep(1, "alice-endpoint", "alice", base_url="https://raw.example/v1"), + ] + monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice") + + endpoints = _call_models_route( + monkeypatch, + rows, + _request(api_token=False, current_user="alice"), + ) + + assert endpoints[0]["endpoint_url"] == "https://raw.example/v1/chat/completions" + assert endpoints[0]["endpoint_url"] != "https://raw.example/v1"