def _migrate_add_pinned_models_column():
    """Add the ``pinned_models`` column to ``model_endpoints`` if it is missing.

    Best-effort, idempotent migration for SQLite database files:

    * Returns silently when the path derived from ``DATABASE_URL`` does not
      exist on disk.
    * Does nothing when ``PRAGMA table_info`` returns no rows or the column
      is already present.
    * Never raises: any failure is logged as a warning so startup can
      continue.
    """
    import sqlite3
    from contextlib import closing

    log = logging.getLogger(__name__)
    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return
    try:
        # closing() guarantees the connection is released even when the
        # PRAGMA / ALTER / commit raises; previously close() was skipped on
        # any error, leaking the handle.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute("PRAGMA table_info(model_endpoints)")
            # Index 1 of each table_info row is the column name.
            columns = [row[1] for row in cursor.fetchall()]
            # An empty result means there is nothing to alter here.
            if columns and "pinned_models" not in columns:
                conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
                conn.commit()
                log.info("Migrated: added 'pinned_models' column to model_endpoints")
    except Exception as e:
        # Deliberate best-effort boundary: a failed migration must not abort
        # application start-up, but it must be visible in the logs.
        log.warning(f"pinned_models migration failed: {e}")
_migrate_add_model_type_column() _migrate_add_model_endpoint_owner_column() diff --git a/routes/model_routes.py b/routes/model_routes.py index 0135d1c..f4153b0 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> return "No models found for that provider/key." -def _visible_models(cached_models, hidden_models): - """Filter cached model IDs by hidden_models. Returns list of visible IDs.""" - all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or []) +def _normalize_model_ids(value): + """Coerce a model-ID input into a clean, ordered list of strings. + + Accepts a list, a JSON-encoded list string, or a comma/newline separated + string (handy for form or backend API input). Trims whitespace, drops + empty and non-string values, and de-duplicates preserving first-seen order. + """ + if value is None: + return [] + items = value + if isinstance(value, str): + text = value.strip() + if not text: + return [] + try: + parsed = json.loads(text) + except Exception: + parsed = None + items = parsed if isinstance(parsed, list) else re.split(r"[,\n]", text) + if not isinstance(items, list): + return [] + out, seen = [], set() + for item in items: + if not isinstance(item, str): + continue + s = item.strip() + if not s or s in seen: + continue + seen.add(s) + out.append(s) + return out + + +def _merge_model_ids(*lists): + """Concatenate model-ID lists, de-duplicating and preserving order.""" + out, seen = [], set() + for ids in lists: + for m in (ids or []): + if not isinstance(m, str) or m in seen: + continue + seen.add(m) + out.append(m) + return out + + +def _visible_models(cached_models, hidden_models, pinned_models=None): + """Merge cached + pinned model IDs, then filter out hidden ones. + + Pinned IDs are admin-entered and may not appear in cached_models (e.g. 
+ cloud deployment IDs the provider does not list in /v1/models). Returns an + ordered, de-duplicated list of visible IDs. + """ + # Normalize each input so JSON strings, lists, comma/newline strings, and + # malformed strings are all handled without raising. + merged = _merge_model_ids( + _normalize_model_ids(cached_models), + _normalize_model_ids(pinned_models), + ) if not hidden_models: - return all_models - hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or [])) - return [m for m in all_models if m not in hidden] + return merged + hidden = set(_normalize_model_ids(hidden_models)) + return [m for m in merged if m not in hidden] def setup_model_routes(model_discovery): @@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery): hidden = set(json.loads(r.hidden_models)) except Exception: pass - visible = [m for m in all_models if m not in hidden] - status = "online" if all_models else "offline" + pinned = _normalize_model_ids(getattr(r, "pinned_models", None)) + visible = _visible_models(all_models, r.hidden_models, pinned) + # Endpoint counts as reachable if it has any model — including + # admin-pinned IDs that a probe would never surface. 
+ status = "online" if (all_models or pinned) else "offline" ping = None - if not all_models and r.is_enabled: + if not all_models and not pinned and r.is_enabled: ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) if ping.get("reachable"): status = "empty" @@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery): "has_key": bool(r.api_key), "is_enabled": r.is_enabled, "models": visible, + "pinned_models": pinned, "hidden_count": len(hidden), "online": status != "offline", "status": status, @@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery): require_models: str = Form("false"), model_type: str = Form("llm"), supports_tools: str = Form(""), # "true"/"false"/"" (unknown) + pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline container_local: str = Form("false"), # Default `shared=true` → endpoints are visible to all users (the # app's historical behaviour). Admins can pass `shared=false` to @@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery): .first() ) if existing: + # Persist any incoming pinned IDs onto the existing row. An + # empty/omitted form field must not wipe previously pinned IDs. 
+ _incoming_pinned = _normalize_model_ids(pinned_models) + if _incoming_pinned: + _merged_pinned = _merge_model_ids( + _normalize_model_ids(getattr(existing, "pinned_models", None)), + _incoming_pinned, + ) + existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None + _db_dedup.commit() + _invalidate_models_cache() + _existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None)) return { "id": existing.id, "name": existing.name, "base_url": existing.base_url, - "models": json.loads(existing.cached_models) if existing.cached_models else [], + "models": _visible_models( + getattr(existing, "cached_models", None), + getattr(existing, "hidden_models", None), + existing.pinned_models, + ), + "pinned_models": _existing_pinned, "online": True, "status": "online", "existing": True, @@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery): try: _st_raw = (supports_tools or "").strip().lower() _st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None) + _pinned = _normalize_model_ids(pinned_models) # Stamp owner so the picker only shows this endpoint to the admin # who added it. 
Pass `shared=true` to mark it null-owner (visible # to all users), preserving the pre-fix "everyone sees everything" @@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery): is_enabled=True, model_type=model_type.strip() if model_type else "llm", cached_models=json.dumps(model_ids) if model_ids else None, + pinned_models=json.dumps(_pinned) if _pinned else None, supports_tools=_st, owner=_owner_val, ) @@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery): "id": ep_id, "name": name.strip(), "base_url": base_url, - "models": model_ids, - "online": bool(model_ids) or bool(ping.get("reachable")), - "status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"), + "models": _merge_model_ids(model_ids, _pinned), + "pinned_models": _pinned, + "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")), + "status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"), "ping_error": ping.get("error") if ping else None, } @@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery): hidden = set(json.loads(ep.hidden_models)) except Exception: pass - # Try live probe, fall back to cached + # Try live probe, fall back to cached. Pinned IDs are admin-entered + # and persist regardless of probe results — never overwritten here. 
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) if all_models: ep.cached_models = json.dumps(all_models) @@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery): all_models = json.loads(ep.cached_models) except Exception: pass + pinned = _normalize_model_ids(getattr(ep, "pinned_models", None)) + pinned_set = set(pinned) return [ - {"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden} - for m in all_models + { + "id": m, + "display": m.split("/")[-1], + "is_hidden": m in hidden, + "is_pinned": m in pinned_set, + } + for m in _merge_model_ids(all_models, pinned) ] finally: db.close() @router.patch("/model-endpoints/{ep_id}/models") async def update_hidden_models(ep_id: str, request: Request): - """Bulk update hidden models list for an endpoint. + """Bulk update hidden and/or pinned model lists for an endpoint. - Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]} + Expects JSON body with optional keys: + {"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]} + Each key is updated only when present, so callers can patch one list + without clobbering the other. """ require_admin(request) db = SessionLocal() @@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery): if not ep: raise HTTPException(404, "Endpoint not found") body = await request.json() - hidden = body.get("hidden", []) - if not isinstance(hidden, list): - raise HTTPException(400, "hidden must be a list of model IDs") - ep.hidden_models = json.dumps(hidden) if hidden else None + if not isinstance(body, dict): + raise HTTPException(400, "Body must be a JSON object") + if "hidden" in body: + hidden = body.get("hidden") + if not isinstance(hidden, list): + raise HTTPException(400, "hidden must be a list of model IDs") + ep.hidden_models = json.dumps(hidden) if hidden else None + # Accept either "pinned" or "pinned_models" for the manual IDs list. 
+ if "pinned_models" in body or "pinned" in body: + pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned"))) + ep.pinned_models = json.dumps(pinned) if pinned else None db.commit() _invalidate_models_cache() - return {"id": ep_id, "hidden_count": len(hidden)} + hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0 + pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0 + return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count} finally: db.close() @@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery): return {"endpoint_id": "", "endpoint_url": "", "model": ""} base = _normalize_base(ep.base_url) chat_url = build_chat_url(base) - if not model and getattr(ep, "cached_models", None): + if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)): try: - visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None)) + visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None)) if visible: model = visible[0] except Exception: @@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery): ep.name = body["name"].strip() or ep.name if "model_type" in body and isinstance(body["model_type"], str): ep.model_type = body["model_type"].strip() or ep.model_type + if "pinned_models" in body: + _pinned = _normalize_model_ids(body["pinned_models"]) + ep.pinned_models = json.dumps(_pinned) if _pinned else None # Rotating an API key used to require DELETE+POST, which wiped # endpoint_url/model from every session referencing the old base # URL. 
Allow in-place updates so the admin can change the key @@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery): "name": ep.name, "model_type": ep.model_type, "base_url": ep.base_url, + "pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)), } finally: db.close() diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index be767e4..48d6293 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -1,6 +1,9 @@ """Tests for model route helper functions — pure logic, no server needed.""" +import asyncio +import json import sys import types +from types import SimpleNamespace from unittest.mock import MagicMock import httpx @@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver from routes.model_routes import ( _match_provider_curated, _curate_models, + _visible_models, + _normalize_model_ids, _is_chat_model, _classify_endpoint, _probe_endpoint, @@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable: monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail) assert model_routes._docker_host_gateway_reachable() is False + + +# ── pinned model IDs: normalization helper ── + + +class TestNormalizeModelIds: + def test_list_passthrough_trims_and_dedupes(self): + assert _normalize_model_ids([" a ", "a", "b", ""]) == ["a", "b"] + + def test_json_string_list(self): + assert _normalize_model_ids('["x", "y", "x"]') == ["x", "y"] + + def test_comma_and_newline_string(self): + assert _normalize_model_ids("a, b\n c ,a") == ["a", "b", "c"] + + def test_none_and_empty(self): + assert _normalize_model_ids(None) == [] + assert _normalize_model_ids("") == [] + assert _normalize_model_ids(" ") == [] + + def test_non_string_values_ignored(self): + assert _normalize_model_ids([1, "ok", None, {"a": 1}]) == ["ok"] + + +# ── pinned model IDs: _visible_models merge ── + + +class TestVisibleModelsPinned: + def test_includes_pinned_not_in_cached(self): + visible = _visible_models(["a"], None, ["deploy-1"]) + assert 
visible == ["a", "deploy-1"] + + def test_cached_plus_pinned_dedup_preserves_order(self): + visible = _visible_models(["a", "b"], None, ["b", "c"]) + assert visible == ["a", "b", "c"] + + def test_hidden_can_hide_a_pinned_model(self): + visible = _visible_models(["a"], ["deploy-1"], ["deploy-1"]) + assert visible == ["a"] + + def test_accepts_json_string_inputs(self): + visible = _visible_models('["a"]', '["a"]', '["b"]') + assert visible == ["b"] + + +# ── pinned model IDs: route behaviour ── + +# Building the router exercises FastAPI's Form() routes, which require +# python-multipart. The test env ships without it, so register a minimal stub +# (mirrors tests/test_review_regressions.py) only when it's genuinely missing. +if "python_multipart" not in sys.modules: + try: + import python_multipart # noqa: F401 + except ImportError: + _mp_stub = types.ModuleType("python_multipart") + _mp_stub.__version__ = "0.0.13" + sys.modules["python_multipart"] = _mp_stub + + +class _PinnedFakeQuery: + def __init__(self, rows): + self.rows = list(rows) + + def filter(self, *conditions): + return self + + def order_by(self, *args): + return self + + def first(self): + return self.rows[0] if self.rows else None + + def all(self): + return list(self.rows) + + +class _PinnedFakeDb: + def __init__(self, rows): + self.rows = rows + self.added = [] + self.committed = 0 + + def query(self, model): + return _PinnedFakeQuery(self.rows) + + def add(self, row): + self.added.append(row) + + def commit(self): + self.committed += 1 + + def close(self): + pass + + +class _FakeCol: + """Column stand-in: every comparison/operator just returns itself so the + dedupe query expressions evaluate without a real SQLAlchemy column.""" + + __hash__ = None + + def __eq__(self, other): + return self + + def is_(self, other): + return self + + def __or__(self, other): + return self + + def desc(self): + return self + + +class _RecordingEndpoint: + """ModelEndpoint stand-in that stores constructor kwargs as 
attributes. + + Class-level fake columns let it double as the query class in the dedupe + lookup; instance attributes (set in __init__) shadow them per-row. + """ + + id = _FakeCol() + base_url = _FakeCol() + owner = _FakeCol() + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +class _PinnedFakeRequest: + def __init__(self, body=None, headers=None): + self._body = body if body is not None else {} + self.headers = headers or {} + + async def json(self): + return self._body + + +def _get_route(path, method): + from routes.model_routes import setup_model_routes + router = setup_model_routes(model_discovery=None) + for route in router.routes: + if getattr(route, "path", "") == path and method in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError(f"{method} {path} not found") + + +def _make_endpoint(**kwargs): + base = dict( + id="ep1", + name="EP", + base_url="http://localhost:9999/v1", + api_key=None, + is_enabled=True, + hidden_models=None, + cached_models=None, + pinned_models=None, + model_type="llm", + supports_tools=None, + ) + base.update(kwargs) + return SimpleNamespace(**base) + + +def test_patch_models_saves_pinned_models(monkeypatch): + ep = _make_endpoint() + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH") + + request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}) + result = asyncio.run(endpoint("ep1", request)) + + assert json.loads(ep.pinned_models) == ["deploy-1", "deploy-2"] + assert result["pinned_count"] == 2 + + +def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch): + ep = _make_endpoint(hidden_models=json.dumps(["hide-me"])) + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + 
monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH") + + request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1"]}) + asyncio.run(endpoint("ep1", request)) + + assert json.loads(ep.hidden_models) == ["hide-me"] + assert json.loads(ep.pinned_models) == ["deploy-1"] + + +def test_get_models_returns_pinned_when_probe_empty(monkeypatch): + ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"])) + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: []) + endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET") + + result = endpoint("ep1", _PinnedFakeRequest()) + + ids = [row["id"] for row in result] + assert ids == ["deploy-1"] + assert result[0]["is_pinned"] is True + + +def test_reprobe_preserves_pinned_models(monkeypatch): + ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"])) + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"]) + monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True) + monkeypatch.setattr( + model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"} + ) + endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET") + + response = endpoint("ep1", _PinnedFakeRequest()) + + async def _drain(): + async for _ in response.body_iterator: + pass + + asyncio.run(_drain()) + + # Probe rewrites cached/hidden but must never touch admin-pinned IDs. 
+ assert json.loads(ep.pinned_models) == ["deploy-1"] + assert json.loads(ep.cached_models) == ["m1"] + + +def test_visible_models_handles_malformed_strings(): + # Non-JSON cached/pinned strings are treated as comma/newline lists and + # never raise; a malformed hidden string is normalized too. + result = _visible_models("a,b", "b", "{bad json") + assert isinstance(result, list) + assert result == ["a", "{bad json"] + assert _visible_models("", None, "") == [] + assert _visible_models("only-cached", None, None) == ["only-cached"] + + +def _create_form_kwargs(**overrides): + """Defaults for every Form() param create_model_endpoint reads directly. + + Calling the route as a plain function bypasses FastAPI form parsing, so the + Form() sentinels must be replaced with real strings. + """ + kwargs = dict( + name="", + api_key="", + skip_probe="true", # avoid any network probe in unit tests + require_models="false", + model_type="llm", + supports_tools="", + pinned_models="", + container_local="false", + shared="true", + ) + kwargs.update(overrides) + return kwargs + + +def _patch_create_deps(monkeypatch, db): + import src.auth_helpers as auth_helpers + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint) + monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b) + monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b) + monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"}) + monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u) + monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None) + + +def test_post_creates_endpoint_with_pinned_models(monkeypatch): + db = _PinnedFakeDb([]) # no existing row → fresh create path + _patch_create_deps(monkeypatch, db) + create = _get_route("/api/model-endpoints", "POST") + + 
result = create( + _PinnedFakeRequest(), + base_url="http://host:1234/v1", + **_create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2"), + ) + + assert result["pinned_models"] == ["deploy-1", "deploy-2"] + assert result["models"] == ["deploy-1", "deploy-2"] + assert result["online"] is True + # Persisted onto the created row. + assert len(db.added) == 1 + assert json.loads(db.added[0].pinned_models) == ["deploy-1", "deploy-2"] + + +def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch): + existing = _make_endpoint( + cached_models=json.dumps(["m1"]), + hidden_models=None, + pinned_models=json.dumps(["old-pin"]), + ) + db = _PinnedFakeDb([existing]) + _patch_create_deps(monkeypatch, db) + create = _get_route("/api/model-endpoints", "POST") + + result = create( + _PinnedFakeRequest(), + base_url="http://host:1234/v1", + **_create_form_kwargs(pinned_models="new-pin"), + ) + + assert result["existing"] is True + # Incoming pin merged onto the existing pins (no clobber, order preserved). + assert json.loads(existing.pinned_models) == ["old-pin", "new-pin"] + assert result["pinned_models"] == ["old-pin", "new-pin"] + # models = cached + pinned - hidden, visible merged list. + assert result["models"] == ["m1", "old-pin", "new-pin"] + # No new row created on the dedupe path. + assert db.added == [] + + +def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch): + existing = _make_endpoint( + cached_models=json.dumps(["m1"]), + pinned_models=json.dumps(["keep-me"]), + ) + db = _PinnedFakeDb([existing]) + _patch_create_deps(monkeypatch, db) + create = _get_route("/api/model-endpoints", "POST") + + result = create( + _PinnedFakeRequest(), + base_url="http://host:1234/v1", + **_create_form_kwargs(), # pinned_models defaults to "" + ) + + assert json.loads(existing.pinned_models) == ["keep-me"] + assert result["pinned_models"] == ["keep-me"] + assert db.committed == 0 # nothing to persist