tests: cover companion models route filtering
This commit is contained in:
committed by
GitHub
parent
97528be0f4
commit
5607db85d4
@@ -9,6 +9,7 @@ rows, and legacy null-owner rows must not widen a token's access.
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -24,13 +25,129 @@ if "core.database" not in sys.modules:
|
||||
_db.ModelEndpoint = MagicMock()
|
||||
sys.modules["core.database"] = _db
|
||||
|
||||
from companion.routes import token_owner, owner_can_see
|
||||
import companion.routes as companion_routes
|
||||
from companion.routes import setup_companion_routes, token_owner, owner_can_see
|
||||
|
||||
|
||||
def _request(**state):
|
||||
return SimpleNamespace(state=SimpleNamespace(**state))
|
||||
|
||||
|
||||
class _Predicate:
|
||||
def __init__(self, check):
|
||||
self._check = check
|
||||
|
||||
def __call__(self, row):
|
||||
return self._check(row)
|
||||
|
||||
def __or__(self, other):
|
||||
return _Predicate(lambda row: self(row) or other(row))
|
||||
|
||||
|
||||
class _Column:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, value): # noqa: D401
|
||||
return _Predicate(lambda row: getattr(row, self.name) == value)
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
model_type = _Column("model_type")
|
||||
owner = _Column("owner")
|
||||
|
||||
|
||||
class _Query:
|
||||
def __init__(self, rows):
|
||||
self._rows = list(rows)
|
||||
|
||||
def filter(self, *predicates):
|
||||
self._rows = [
|
||||
row for row in self._rows
|
||||
if all(predicate(row) for predicate in predicates)
|
||||
]
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return list(self._rows)
|
||||
|
||||
|
||||
class _DB:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
self.closed = False
|
||||
|
||||
def query(self, model):
|
||||
assert model is _ModelEndpoint
|
||||
return _Query(self._rows)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _ep(
|
||||
id,
|
||||
name,
|
||||
owner,
|
||||
*,
|
||||
is_enabled=True,
|
||||
model_type="llm",
|
||||
base_url=None,
|
||||
cached_models=None,
|
||||
hidden_models=None,
|
||||
supports_tools=False,
|
||||
api_key="secret-key",
|
||||
):
|
||||
return SimpleNamespace(
|
||||
id=id,
|
||||
name=name,
|
||||
owner=owner,
|
||||
is_enabled=is_enabled,
|
||||
model_type=model_type,
|
||||
base_url=base_url or f"https://{name}.example/v1",
|
||||
cached_models=json.dumps(cached_models or [f"{name}-model"]),
|
||||
hidden_models=json.dumps(hidden_models or []),
|
||||
supports_tools=supports_tools,
|
||||
api_key=api_key,
|
||||
headers={"Authorization": "Bearer secret-header"},
|
||||
)
|
||||
|
||||
|
||||
def _models_route():
|
||||
for route in setup_companion_routes().routes:
|
||||
if getattr(route, "path", "") == "/api/companion/models":
|
||||
assert "GET" in getattr(route, "methods", set())
|
||||
return route.endpoint
|
||||
raise AssertionError("GET /api/companion/models route not found")
|
||||
|
||||
|
||||
def _call_models_route(monkeypatch, rows, request):
|
||||
db = _DB(rows)
|
||||
db_mod = sys.modules["core.database"]
|
||||
monkeypatch.setattr(db_mod, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(db_mod, "ModelEndpoint", _ModelEndpoint)
|
||||
|
||||
endpoint_mod = sys.modules.get("src.endpoint_resolver")
|
||||
if endpoint_mod is None:
|
||||
endpoint_mod = types.ModuleType("src.endpoint_resolver")
|
||||
sys.modules["src.endpoint_resolver"] = endpoint_mod
|
||||
monkeypatch.setattr(
|
||||
endpoint_mod,
|
||||
"build_chat_url",
|
||||
lambda base_url: f"{base_url.rstrip('/')}/chat/completions",
|
||||
raising=False,
|
||||
)
|
||||
|
||||
response = _models_route()(request)
|
||||
assert db.closed is True
|
||||
return response["endpoints"]
|
||||
|
||||
|
||||
def _endpoint_names(endpoints):
|
||||
return [endpoint["name"] for endpoint in endpoints]
|
||||
|
||||
|
||||
# --- token_owner: who a request is attributed to ---------------------------
|
||||
|
||||
def test_token_owner_bearer_resolves_to_token_owner():
|
||||
@@ -76,3 +193,133 @@ def test_unauthenticated_owner_sees_only_shared_rows():
|
||||
# never any owned row.
|
||||
assert owner_can_see(None, None) is True
|
||||
assert owner_can_see("alice", None) is False
|
||||
|
||||
|
||||
# --- GET /api/companion/models: route-level scoping -----------------------
|
||||
|
||||
def test_models_route_scopes_cookie_user_to_owned_and_shared_rows(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "alice-endpoint", "alice"),
|
||||
_ep(2, "shared-endpoint", None),
|
||||
_ep(3, "bob-endpoint", "bob"),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice")
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=False, current_user="ignored"),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["alice-endpoint", "shared-endpoint"]
|
||||
|
||||
|
||||
def test_models_route_scopes_api_token_to_token_owner(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "alice-endpoint", "alice"),
|
||||
_ep(2, "shared-endpoint", None),
|
||||
_ep(3, "bob-endpoint", "bob"),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "api")
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=True, api_token_owner="alice", current_user="api"),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["alice-endpoint", "shared-endpoint"]
|
||||
|
||||
|
||||
def test_models_route_unresolved_owner_returns_only_shared_rows(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "alice-endpoint", "alice"),
|
||||
_ep(2, "shared-endpoint", None),
|
||||
_ep(3, "bob-endpoint", "bob"),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: None)
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=True, api_token_owner=None, current_user="api"),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["shared-endpoint"]
|
||||
|
||||
|
||||
def test_models_route_filters_hidden_models_and_secret_fields(monkeypatch):
|
||||
rows = [
|
||||
_ep(
|
||||
1,
|
||||
"alice-endpoint",
|
||||
"alice",
|
||||
base_url="https://alice.example/v1",
|
||||
cached_models=["visible-model", "hidden-model"],
|
||||
hidden_models=["hidden-model"],
|
||||
supports_tools=True,
|
||||
api_key="super-secret",
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice")
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=False, current_user="alice"),
|
||||
)
|
||||
|
||||
assert endpoints == [{
|
||||
"endpoint_id": 1,
|
||||
"name": "alice-endpoint",
|
||||
"endpoint_url": "https://alice.example/v1/chat/completions",
|
||||
"models": ["visible-model"],
|
||||
"supports_tools": True,
|
||||
}]
|
||||
returned = endpoints[0]
|
||||
assert "hidden-model" not in returned["models"]
|
||||
assert set(returned) == {
|
||||
"endpoint_id",
|
||||
"name",
|
||||
"endpoint_url",
|
||||
"models",
|
||||
"supports_tools",
|
||||
}
|
||||
assert "api_key" not in returned
|
||||
assert "headers" not in returned
|
||||
assert "base_url" not in returned
|
||||
assert "super-secret" not in repr(returned)
|
||||
|
||||
|
||||
def test_models_route_filters_disabled_and_non_llm_endpoints(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "enabled-llm", "alice", is_enabled=True, model_type="llm"),
|
||||
_ep(2, "legacy-null-type", "alice", is_enabled=True, model_type=None),
|
||||
_ep(3, "disabled-llm", "alice", is_enabled=False, model_type="llm"),
|
||||
_ep(4, "image-endpoint", "alice", is_enabled=True, model_type="image"),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice")
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=False, current_user="alice"),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["enabled-llm", "legacy-null-type"]
|
||||
|
||||
|
||||
def test_models_route_returns_built_chat_url(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "alice-endpoint", "alice", base_url="https://raw.example/v1"),
|
||||
]
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "alice")
|
||||
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=False, current_user="alice"),
|
||||
)
|
||||
|
||||
assert endpoints[0]["endpoint_url"] == "https://raw.example/v1/chat/completions"
|
||||
assert endpoints[0]["endpoint_url"] != "https://raw.example/v1"
|
||||
|
||||
Reference in New Issue
Block a user