feat(models): support pinned endpoint model IDs
This commit is contained in:
committed by
GitHub
parent
1284b14a13
commit
145f4fd2b4
@@ -1,6 +1,9 @@
|
||||
"""Tests for model route helper functions — pure logic, no server needed."""
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
_visible_models,
|
||||
_normalize_model_ids,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_probe_endpoint,
|
||||
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
|
||||
|
||||
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
|
||||
assert model_routes._docker_host_gateway_reachable() is False
|
||||
|
||||
|
||||
# ── pinned model IDs: normalization helper ──
|
||||
|
||||
|
||||
class TestNormalizeModelIds:
    """Expectations for ``_normalize_model_ids`` over the input shapes it accepts."""

    def test_list_passthrough_trims_and_dedupes(self):
        # Entries are expected to come back trimmed, without blanks or repeats.
        raw_ids = [" a ", "a", "b", ""]
        assert _normalize_model_ids(raw_ids) == ["a", "b"]

    def test_json_string_list(self):
        # A JSON-encoded array string is expected to be parsed and de-duplicated.
        encoded = '["x", "y", "x"]'
        assert _normalize_model_ids(encoded) == ["x", "y"]

    def test_comma_and_newline_string(self):
        # Commas and newlines are both expected to act as separators.
        free_text = "a, b\n c ,a"
        assert _normalize_model_ids(free_text) == ["a", "b", "c"]

    def test_none_and_empty(self):
        # None, the empty string and a blank string all yield an empty list.
        for blank in (None, "", " "):
            assert _normalize_model_ids(blank) == []

    def test_non_string_values_ignored(self):
        # Only the string entry is expected to survive.
        mixed = [1, "ok", None, {"a": 1}]
        assert _normalize_model_ids(mixed) == ["ok"]
|
||||
|
||||
|
||||
# ── pinned model IDs: _visible_models merge ──
|
||||
|
||||
|
||||
class TestVisibleModelsPinned:
    """Expectations for how ``_visible_models`` merges pinned IDs.

    The calls below pass the cached list, the hidden list and the pinned list
    positionally, in that order.
    """

    def test_includes_pinned_not_in_cached(self):
        # A pinned ID that was never discovered is appended after cached ones.
        merged = _visible_models(["a"], None, ["deploy-1"])
        assert merged == ["a", "deploy-1"]

    def test_cached_plus_pinned_dedup_preserves_order(self):
        # "b" appears in both lists but must be reported once, in first-seen order.
        merged = _visible_models(["a", "b"], None, ["b", "c"])
        assert merged == ["a", "b", "c"]

    def test_hidden_can_hide_a_pinned_model(self):
        # Hiding wins over pinning.
        merged = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
        assert merged == ["a"]

    def test_accepts_json_string_inputs(self):
        # All three arguments may arrive as JSON-encoded strings.
        merged = _visible_models('["a"]', '["a"]', '["b"]')
        assert merged == ["b"]
|
||||
|
||||
|
||||
# ── pinned model IDs: route behaviour ──
|
||||
|
||||
# Building the router exercises FastAPI's Form() routes, which require
|
||||
# python-multipart. The test env ships without it, so register a minimal stub
|
||||
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
|
||||
if "python_multipart" not in sys.modules:
    # Nothing registered yet: prefer the real package and only fall back to a
    # placeholder module when the import genuinely fails.
    try:
        import python_multipart  # noqa: F401
    except ImportError:
        # Bare module object carrying just a version string.
        # NOTE(review): "0.0.13" is presumably chosen to satisfy FastAPI's
        # minimum-version check — confirm before changing it.
        _mp_stub = types.ModuleType("python_multipart")
        _mp_stub.__version__ = "0.0.13"
        sys.modules["python_multipart"] = _mp_stub
|
||||
|
||||
|
||||
class _PinnedFakeQuery:
    """Query stand-in that ignores every filter and ordering and serves fixed rows."""

    def __init__(self, rows):
        # Snapshot the rows so later changes to the caller's iterable are not seen.
        self.rows = [*rows]

    def filter(self, *conditions):
        # Conditions are accepted and discarded; the call stays chainable.
        return self

    def order_by(self, *args):
        # Ordering is ignored; the call stays chainable.
        return self

    def first(self):
        # First stored row, or None when there are no rows.
        if not self.rows:
            return None
        return self.rows[0]

    def all(self):
        # Hand back a fresh list so callers cannot mutate the stored rows.
        return self.rows.copy()
|
||||
|
||||
|
||||
class _PinnedFakeDb:
    """Session stand-in that serves fixed rows and records writes.

    ``added`` collects every object passed to ``add`` and ``committed`` counts
    ``commit`` calls, so tests can assert on what a route tried to persist.
    """

    def __init__(self, rows):
        self.committed = 0
        self.added = []
        self.rows = rows

    def query(self, model):
        # The requested model is ignored; every query sees the same rows.
        return _PinnedFakeQuery(self.rows)

    def add(self, row):
        self.added.append(row)

    def commit(self):
        self.committed += 1

    def close(self):
        # Nothing to release.
        pass
|
||||
|
||||
|
||||
class _FakeCol:
    """Placeholder for a column object.

    Each supported operation (``==``, ``|``, ``is_`` and ``desc``) returns the
    same instance, so query expressions built from it evaluate without a real
    SQLAlchemy column.
    """

    # Defining __eq__ would drop hashing anyway; make that explicit.
    __hash__ = None

    def desc(self):
        return self

    def is_(self, other):
        return self

    def __or__(self, other):
        return self

    def __eq__(self, other):
        return self
|
||||
|
||||
|
||||
class _RecordingEndpoint:
    """Stand-in for ModelEndpoint that keeps its constructor kwargs as attributes.

    The class-level fake columns allow the class itself to be used where the
    dedupe lookup builds query expressions; values passed to ``__init__``
    become instance attributes and shadow those fakes on each row.
    """

    id = _FakeCol()
    base_url = _FakeCol()
    owner = _FakeCol()

    def __init__(self, **kwargs):
        # Store every keyword as an instance attribute in one step.
        vars(self).update(kwargs)
|
||||
|
||||
|
||||
class _PinnedFakeRequest:
    """Request stand-in exposing ``headers`` and an awaitable ``json()``."""

    def __init__(self, body=None, headers=None):
        # Only a missing body is replaced; an explicit empty body is kept as given.
        if body is None:
            body = {}
        self._body = body
        # Any falsy headers value collapses to an empty dict.
        self.headers = headers if headers else {}

    async def json(self):
        # Return the stored body unchanged.
        return self._body
|
||||
|
||||
|
||||
def _get_route(path, method):
    """Build the model router and return the handler registered for a route.

    The router is created with ``model_discovery=None``. The first route whose
    ``path`` equals ``path`` and whose ``methods`` contain ``method`` wins.

    Raises:
        AssertionError: if no registered route matches.
    """
    from routes.model_routes import setup_model_routes

    candidates = setup_model_routes(model_discovery=None).routes
    for candidate in candidates:
        # Skip on path first; methods are only inspected for a matching path.
        if getattr(candidate, "path", "") != path:
            continue
        if method in getattr(candidate, "methods", set()):
            return candidate.endpoint
    raise AssertionError(f"{method} {path} not found")
|
||||
|
||||
|
||||
def _make_endpoint(**kwargs):
    """Return a ``SimpleNamespace`` shaped like an endpoint row.

    Any keyword argument overrides the matching default or adds a new attribute.
    """
    defaults = {
        "id": "ep1",
        "name": "EP",
        "base_url": "http://localhost:9999/v1",
        "api_key": None,
        "is_enabled": True,
        "hidden_models": None,
        "cached_models": None,
        "pinned_models": None,
        "model_type": "llm",
        "supports_tools": None,
    }
    return SimpleNamespace(**{**defaults, **kwargs})
|
||||
|
||||
|
||||
def test_patch_models_saves_pinned_models(monkeypatch):
    """PATCH with ``pinned_models`` stores a de-duplicated JSON list on the row."""
    row = _make_endpoint()
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    # "deploy-1" is sent twice on purpose to exercise de-duplication.
    payload = {"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}
    outcome = asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    stored = json.loads(row.pinned_models)
    assert stored == ["deploy-1", "deploy-2"]
    assert outcome["pinned_count"] == 2
|
||||
|
||||
|
||||
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
    """A PATCH that only sends ``pinned_models`` leaves ``hidden_models`` intact."""
    row = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    payload = {"pinned_models": ["deploy-1"]}
    asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    # The hidden list is untouched and the pinned list is written.
    assert json.loads(row.hidden_models) == ["hide-me"]
    assert json.loads(row.pinned_models) == ["deploy-1"]
|
||||
|
||||
|
||||
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
    """GET models still lists pinned IDs when the probe discovers nothing."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    # Simulate an endpoint whose model listing comes back empty.
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
    list_models = _get_route("/api/model-endpoints/{ep_id}/models", "GET")

    listing = list_models("ep1", _PinnedFakeRequest())

    assert [entry["id"] for entry in listing] == ["deploy-1"]
    assert listing[0]["is_pinned"] is True
|
||||
|
||||
|
||||
def test_reprobe_preserves_pinned_models(monkeypatch):
    """Re-probing refreshes the cached list but leaves pinned IDs alone."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    # The probe discovers exactly one model, which counts as a chat model and
    # reports an "ok" status when probed individually.
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
    monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
    monkeypatch.setattr(
        model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
    )
    probe = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")

    stream = probe("ep1", _PinnedFakeRequest())

    async def _consume_stream():
        # Exhaust the response body so the route's work runs to completion.
        async for _chunk in stream.body_iterator:
            pass

    asyncio.run(_consume_stream())

    # Cached models reflect the probe; admin-pinned IDs must be unchanged.
    assert json.loads(row.pinned_models) == ["deploy-1"]
    assert json.loads(row.cached_models) == ["m1"]
|
||||
|
||||
|
||||
def test_visible_models_handles_malformed_strings():
    """Non-JSON string inputs are split as plain lists instead of raising."""
    # The cached value is a comma list, the hidden value is a bare ID, and the
    # pinned value is text that only looks like the start of JSON; it is
    # expected to survive as a literal ID.
    merged = _visible_models("a,b", "b", "{bad json")
    assert isinstance(merged, list)
    assert merged == ["a", "{bad json"]

    # Empty strings contribute nothing.
    assert _visible_models("", None, "") == []
    # A single bare ID with no hidden or pinned values passes through.
    assert _visible_models("only-cached", None, None) == ["only-cached"]
|
||||
|
||||
|
||||
def _create_form_kwargs(**overrides):
    """Return string defaults for the Form() params of create_model_endpoint.

    The route is invoked as a plain function in these tests, so FastAPI never
    parses a form; each Form() sentinel therefore has to be supplied as a real
    string. Keyword arguments replace the matching default.
    """
    defaults = {
        "name": "",
        "api_key": "",
        "skip_probe": "true",  # avoid any network probe in unit tests
        "require_models": "false",
        "model_type": "llm",
        "supports_tools": "",
        "pinned_models": "",
        "container_local": "false",
        "shared": "true",
    }
    return {**defaults, **overrides}
|
||||
|
||||
|
||||
def _patch_create_deps(monkeypatch, db):
    """Replace the collaborators the create route touches with in-memory fakes.

    ``db`` is returned by the patched ``SessionLocal``. URL normalisation,
    loopback rewriting and URL resolution become identity functions, the admin
    check and current-user lookup return None, and ``ModelEndpoint`` is swapped
    for ``_RecordingEndpoint``.
    """
    import src.auth_helpers as auth_helpers

    replacements = (
        (model_routes, "SessionLocal", lambda: db),
        (model_routes, "require_admin", lambda request: None),
        (model_routes, "ModelEndpoint", _RecordingEndpoint),
        (model_routes, "_normalize_base", lambda b: b),
        (model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b),
        # NOTE(review): a default endpoint id is reported as already set,
        # presumably so the route does not try to assign one — confirm.
        (model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"}),
        (endpoint_resolver, "resolve_url", lambda u: u),
        (auth_helpers, "get_current_user", lambda req: None),
    )
    for target, attribute, fake in replacements:
        monkeypatch.setattr(target, attribute, fake)
|
||||
|
||||
|
||||
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
    """POST on the fresh-create path normalises and persists pinned IDs."""
    fake_db = _PinnedFakeDb([])  # no existing row → fresh create path
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    # Mixed comma/newline input with a duplicate entry.
    form = _create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2")
    outcome = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    expected_ids = ["deploy-1", "deploy-2"]
    assert outcome["pinned_models"] == expected_ids
    assert outcome["models"] == expected_ids
    assert outcome["online"] is True
    # Exactly one row was added and it carries the pinned IDs as JSON.
    assert len(fake_db.added) == 1
    assert json.loads(fake_db.added[0].pinned_models) == expected_ids
|
||||
|
||||
|
||||
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
    """POST for an already-known endpoint merges new pins into the stored ones."""
    stored_row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        hidden_models=None,
        pinned_models=json.dumps(["old-pin"]),
    )
    fake_db = _PinnedFakeDb([stored_row])
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    form = _create_form_kwargs(pinned_models="new-pin")
    outcome = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    merged_pins = ["old-pin", "new-pin"]
    assert outcome["existing"] is True
    # The new pin is appended after the existing one; nothing is overwritten.
    assert json.loads(stored_row.pinned_models) == merged_pins
    assert outcome["pinned_models"] == merged_pins
    # Visible models are the cached IDs followed by pinned IDs, minus hidden.
    assert outcome["models"] == ["m1", "old-pin", "new-pin"]
    # The dedupe path must not insert a second row.
    assert fake_db.added == []
|
||||
|
||||
|
||||
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
    """POST without pins for a known endpoint keeps the stored pins untouched."""
    stored_row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        pinned_models=json.dumps(["keep-me"]),
    )
    fake_db = _PinnedFakeDb([stored_row])
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    form = _create_form_kwargs()  # pinned_models defaults to ""
    outcome = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    assert json.loads(stored_row.pinned_models) == ["keep-me"]
    assert outcome["pinned_models"] == ["keep-me"]
    assert fake_db.committed == 0  # nothing to persist
|
||||
|
||||
Reference in New Issue
Block a user