fix(models): stabilize proxy endpoint refresh behavior

* fix: support large proxy model endpoint refresh

Large OpenAI-compatible proxy endpoints can expose hundreds of models and make /v1/models slow. Treating those endpoints like local model servers caused model picker opens and background probes to repeatedly hit /models, producing timeouts and making otherwise usable endpoints appear offline.

Make model endpoint discovery cached-first for normal UI usage, add explicit proxy/API classification and refresh policy fields, exclude proxy/API endpoints from aggressive local probing, and preserve cached models when refresh fails.

Manual Test/Add/Refresh actions still fetch the full model list with longer timeouts so users can intentionally import large proxy model lists without blocking normal model picker usage.

* fix: preserve endpoint ping status semantics
This commit is contained in:
Yuri
2026-06-04 00:56:11 -03:00
committed by GitHub
parent eee2167502
commit a2e691da2b
10 changed files with 1323 additions and 231 deletions

View File

@@ -342,6 +342,14 @@ class ModelEndpoint(TimestampMixin, Base):
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list) cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models) pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
model_type = Column(String, nullable=True, default="llm") # "llm" or "image" model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
# auto = classify by URL; local = self-hosted server; api/proxy = external
# OpenAI-compatible API even when reachable through a private/tailnet IP.
endpoint_kind = Column(String, nullable=True, default="auto")
# auto = background refresh with TTL/backoff; manual/disabled = cached-first
# only unless an explicit endpoint probe is requested.
model_refresh_mode = Column(String, nullable=True, default="auto")
model_refresh_interval = Column(Integer, nullable=True, default=None)
model_refresh_timeout = Column(Integer, nullable=True, default=None)
# Whether models on this endpoint accept OpenAI-style function # Whether models on this endpoint accept OpenAI-style function
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto- # schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
# register time from `--enable-auto-tool-choice` in the serve cmd; # register time from `--enable-auto-tool-choice` in the serve cmd;
@@ -809,6 +817,29 @@ def _migrate_add_model_type_column():
except Exception as e: except Exception as e:
logging.getLogger(__name__).warning(f"model_type migration failed: {e}") logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
def _migrate_add_model_endpoint_refresh_columns():
    """Add endpoint classification / refresh policy columns if missing.

    Best-effort SQLite migration for the ``model_endpoints`` table. Each of
    ``endpoint_kind``, ``model_refresh_mode``, ``model_refresh_interval`` and
    ``model_refresh_timeout`` is added only when ``PRAGMA table_info`` does
    not already list it. Nothing happens when the database file is absent or
    when the table reports no columns. Errors are logged as warnings and
    never propagated to the caller.
    """
    import sqlite3
    from contextlib import closing

    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return

    # (column name, type/default clause), in the order they are added.
    wanted_columns = (
        ("endpoint_kind", "TEXT DEFAULT 'auto'"),
        ("model_refresh_mode", "TEXT DEFAULT 'auto'"),
        ("model_refresh_interval", "INTEGER"),
        ("model_refresh_timeout", "INTEGER"),
    )
    try:
        # closing() releases the connection even when an ALTER fails midway;
        # previously close() was only reached on the success path.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute("PRAGMA table_info(model_endpoints)")
            existing = {row[1] for row in cursor.fetchall()}
            # An empty column set means the table is not there; skip the ALTERs.
            if existing:
                for column, clause in wanted_columns:
                    if column not in existing:
                        conn.execute(f"ALTER TABLE model_endpoints ADD COLUMN {column} {clause}")
            conn.commit()
    except Exception as e:
        logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
def _migrate_add_task_run_model_column(): def _migrate_add_task_run_model_column():
"""Add model column to task_runs if it doesn't exist (records which model ran).""" """Add model column to task_runs if it doesn't exist (records which model ran)."""
import sqlite3 import sqlite3
@@ -1539,6 +1570,7 @@ def init_db():
_migrate_add_pinned_models_column() _migrate_add_pinned_models_column()
_migrate_add_notes_sort_order() _migrate_add_notes_sort_order()
_migrate_add_model_type_column() _migrate_add_model_type_column()
_migrate_add_model_endpoint_refresh_columns()
_migrate_add_model_endpoint_owner_column() _migrate_add_model_endpoint_owner_column()
_migrate_add_supports_tools_column() _migrate_add_supports_tools_column()
_migrate_add_task_run_model_column() _migrate_add_task_run_model_column()

View File

@@ -11,7 +11,7 @@ import httpx
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request from fastapi import APIRouter, HTTPException, Form, Query, Body, Request, Response
from pydantic import BaseModel from pydantic import BaseModel
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from core.database import SessionLocal, ModelEndpoint, Session as DbSession from core.database import SessionLocal, ModelEndpoint, Session as DbSession
@@ -335,6 +335,141 @@ def _truthy(value: str | None) -> bool:
return (value or "").strip().lower() in ("true", "1", "yes", "on") return (value or "").strip().lower() in ("true", "1", "yes", "on")
_ENDPOINT_KINDS = {"auto", "local", "api", "proxy"}
_REFRESH_MODES = {"auto", "manual", "disabled"}
def _normalize_endpoint_kind(value: Any) -> str:
kind = str(value or "auto").strip().lower()
return kind if kind in _ENDPOINT_KINDS else "auto"
def _normalize_refresh_mode(value: Any, endpoint_kind: str = "auto") -> str:
    """Resolve a refresh mode to ``"auto"``, ``"manual"`` or ``"disabled"``.

    ``"manual"`` and ``"disabled"`` are returned as given. Every other value,
    including an explicit ``"auto"``, falls back to the default for the
    endpoint kind: ``"manual"`` for proxies, ``"auto"`` for everything else.
    """
    requested = str(value or "").strip().lower()
    is_proxy = _normalize_endpoint_kind(endpoint_kind) == "proxy"
    if requested == "manual" or requested == "disabled":
        return requested
    # Proxies default to manual cached-first behavior. Normal local/API
    # endpoints keep automatic bounded refreshes.
    if is_proxy:
        return "manual"
    return "auto"
def _endpoint_kind(ep: Any) -> str:
    """Return the normalized ``endpoint_kind`` attribute of ``ep``."""
    stored_kind = getattr(ep, "endpoint_kind", None)
    return _normalize_endpoint_kind(stored_kind)
def _endpoint_refresh_mode(ep: Any, endpoint_kind: str | None = None) -> str:
    """Return the normalized refresh mode of ``ep``.

    ``endpoint_kind`` overrides the kind used for the default; when it is
    empty or ``None`` the kind is read from ``ep`` itself.
    """
    stored_mode = getattr(ep, "model_refresh_mode", None)
    kind = endpoint_kind if endpoint_kind else _endpoint_kind(ep)
    return _normalize_refresh_mode(stored_mode, kind)
def _endpoint_refresh_interval(ep: Any, category: str) -> float:
raw = getattr(ep, "model_refresh_interval", None)
try:
val = int(raw) if raw is not None else 0
except Exception:
val = 0
if val > 0:
return float(max(30, val))
return 60.0 if category == "local" else 3600.0
def _endpoint_refresh_timeout(ep: Any, category: str) -> float:
raw = getattr(ep, "model_refresh_timeout", None)
try:
val = int(raw) if raw is not None else 0
except Exception:
val = 0
if val > 0:
return float(max(1, min(30, val)))
return 2.5 if category == "local" else 2.0
def _manual_refresh_timeout(ep: Any, category: str, requested: Any = None) -> float:
    """Timeout in seconds for an explicit, user-triggered model-list refresh.

    Background refreshes stay short; a manual refresh is the one path where a
    large proxy may legitimately need 15-30s to aggregate its catalog.

    Precedence: a valid ``requested`` value (1-60) wins. Otherwise the
    ``model_refresh_timeout`` stored on ``ep`` (1-60) is consulted: local
    endpoints use it as-is or fall back to the background timeout, while all
    other categories get at least 30 seconds.
    """
    override = _parse_positive_int(requested, minimum=1, maximum=60)
    if override is not None:
        return float(override)

    saved = _parse_positive_int(getattr(ep, "model_refresh_timeout", None), minimum=1, maximum=60)
    if category != "local":
        # Non-local endpoints never drop below the 30s floor.
        return float(max(saved or 30, 30))
    if saved is None:
        return _endpoint_refresh_timeout(ep, category)
    return float(saved)
def _parse_model_list(raw: Any) -> List[str]:
"""Return a sanitized list of model ids from JSON/list/comma text."""
if raw is None:
return []
value = raw
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
parsed = json.loads(text)
if isinstance(parsed, list):
value = parsed
else:
value = re.split(r"[\n,]+", text)
except Exception:
value = re.split(r"[\n,]+", text)
if not isinstance(value, list):
return []
out = []
seen = set()
for item in value:
mid = str(item or "").strip()
if not mid or mid in seen:
continue
seen.add(mid)
out.append(mid)
return out
def _parse_positive_int(raw: Any, *, minimum: int = 1, maximum: int = 86400) -> Optional[int]:
try:
val = int(str(raw).strip())
except Exception:
return None
if val < minimum:
return None
return min(val, maximum)
def _explicit_model_list_timeout(base_url: str, endpoint_kind: str = "auto", requested: Any = None) -> float:
    """Timeout for explicit user-triggered model-list fetches during setup.

    A valid ``requested`` value (1-60 seconds) takes precedence. Otherwise
    endpoints that are API/proxy by kind or by classification get 30 seconds,
    Ollama-looking bases get 3 seconds, and everything else gets 2 seconds.
    """
    override = _parse_positive_int(requested, minimum=1, maximum=60)
    if override is not None:
        return float(override)

    kind = _normalize_endpoint_kind(endpoint_kind)
    category = _classify_endpoint(base_url, kind)
    is_external = category == "api" or kind in ("api", "proxy")
    if is_external:
        return 30.0
    if _is_ollama_base(base_url):
        return 3.0
    return 2.0
def _cached_model_ids(ep: Any) -> List[str]:
    """Return the sanitized model ids parsed from ``ep.cached_models``."""
    stored = getattr(ep, "cached_models", None)
    return _parse_model_list(stored)
def _hidden_model_ids(ep: Any) -> set:
    """Return the set of model ids parsed from ``ep.hidden_models``."""
    stored = getattr(ep, "hidden_models", None)
    return set(_parse_model_list(stored))
def _is_ollama_base(base_url: str) -> bool:
try:
parsed = urlparse(base_url)
host = (parsed.hostname or "").lower()
return parsed.port == 11434 or "ollama" in host
except Exception:
return "ollama" in (base_url or "").lower()
# Prefixes/substrings for models that are NOT chat-completions-capable # Prefixes/substrings for models that are NOT chat-completions-capable
_NON_CHAT_PREFIXES = ( _NON_CHAT_PREFIXES = (
"dall-e", "tts-", "whisper", "text-embedding", "embedding", "dall-e", "tts-", "whisper", "text-embedding", "embedding",
@@ -441,10 +576,15 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
_TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.") _TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.")
def _classify_endpoint(base_url: str) -> str: def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
"""Return 'local' if the endpoint URL points to a private/local address, else 'api'. """Return 'local' if the endpoint URL points to a private/local address, else 'api'.
Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted
servers (e.g. Cookbook serve endpoints) get reachability-probed too.""" servers (e.g. Cookbook serve endpoints) get reachability-probed too."""
kind = _normalize_endpoint_kind(endpoint_kind)
if kind == "local":
return "local"
if kind in ("api", "proxy"):
return "api"
try: try:
host = urlparse(base_url).hostname or "" host = urlparse(base_url).hostname or ""
if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES): if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES):
@@ -456,6 +596,21 @@ def _classify_endpoint(base_url: str) -> str:
return "api" return "api"
def _effective_endpoint_kind(ep: Any, base_url: str) -> str:
    """Return explicit kind, with a legacy proxy heuristic for keyed /v1 URLs.

    An explicit non-"auto" kind on ``ep`` is returned unchanged. Otherwise an
    endpoint that has an API key, does not look like Ollama, and whose URL
    path ends in ``/v1`` or contains ``/openai`` is reported as ``"proxy"``.
    Everything else stays ``"auto"``.
    """
    explicit = _endpoint_kind(ep)
    if explicit != "auto":
        return explicit

    if not getattr(ep, "api_key", None) or _is_ollama_base(base_url):
        return "auto"

    try:
        url_path = (urlparse(base_url).path or "").rstrip("/")
        looks_like_proxy = url_path.endswith("/v1") or "/openai" in url_path
    except Exception:
        # Unparseable URL: no heuristic match, keep the automatic kind.
        looks_like_proxy = False
    return "proxy" if looks_like_proxy else "auto"
def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]: def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]:
"""Probe a base URL's /models endpoint and return list of model IDs. """Probe a base URL's /models endpoint and return list of model IDs.
@@ -546,30 +701,18 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
"""Reachability probe that does not require installed/listed models.""" """Reachability probe that does not require installed/listed models."""
from src.endpoint_resolver import resolve_url from src.endpoint_resolver import resolve_url
base = resolve_url(_normalize_base(base_url)) base = resolve_url(_normalize_base(base_url))
headers = {} headers = build_headers(api_key, base)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version, # Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the # /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
# base is the host root or the native /api root (e.g. http://localhost:11434, # /models as a generic health check because large proxy catalogs can be slow.
# http://localhost:11434/api) because /models lives under /v1 there. Treat
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
# than as a definitive offline verdict — Ollama is reachable, it just
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
# marks an alive Ollama as offline whenever cached_models is empty (issue
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
# _ping_endpoint() was returning before that fallback could run.
parsed_base = urlparse(base) parsed_base = urlparse(base)
looks_like_ollama = ( looks_like_ollama = (
parsed_base.port == 11434 parsed_base.port == 11434
or "ollama" in (parsed_base.hostname or "").lower() or "ollama" in (parsed_base.hostname or "").lower()
) )
url = base + "/models" def _result_from_response(r) -> Dict[str, Any]:
last_error: Optional[str] = None
try:
r = httpx.get(url, headers=headers, timeout=timeout)
if 300 <= r.status_code < 400: if 300 <= r.status_code < 400:
loc = r.headers.get("location", "") loc = r.headers.get("location", "")
if loc.startswith("/login") or "/login" in loc: if loc.startswith("/login") or "/login" in loc:
@@ -579,13 +722,15 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.", "error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
} }
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"} return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
if r.status_code < 400: if 200 <= r.status_code < 300:
return {"reachable": True, "status_code": r.status_code, "error": None} return {
if r.status_code < 500 and not looks_like_ollama: "reachable": True,
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"} "status_code": r.status_code,
last_error = f"HTTP {r.status_code}" "error": None,
except Exception as e: }
last_error = str(e)[:120] return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
last_error: Optional[str] = None
try: try:
if looks_like_ollama: if looks_like_ollama:
@@ -597,14 +742,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
for path in ("/api/version", "/api/tags"): for path in ("/api/version", "/api/tags"):
try: try:
r = httpx.get(root + path, timeout=timeout) r = httpx.get(root + path, timeout=timeout)
if r.status_code < 400: result = _result_from_response(r)
return {"reachable": True, "status_code": r.status_code, "error": None} if result["reachable"]:
last_error = f"HTTP {r.status_code}" return result
last_error = result.get("error")
except Exception as e: except Exception as e:
last_error = str(e)[:120] last_error = str(e)[:120]
except Exception: except Exception:
pass pass
try:
r = httpx.get(base, headers=headers, timeout=timeout)
return _result_from_response(r)
except Exception as e:
last_error = str(e)[:120]
return {"reachable": False, "status_code": None, "error": last_error} return {"reachable": False, "status_code": None, "error": last_error}
@@ -715,17 +867,71 @@ def setup_model_routes(model_discovery):
flip).""" flip)."""
_models_cache.clear() _models_cache.clear()
# Track endpoints that have failed recently so we back off probing dead ones. # Track model-list refreshes by URL+key. This prevents repeated picker/API
_probe_failures = {} # ep_id → (last_fail_ts, consecutive_fails) # opens from starting duplicate /models probes, and gives slow/offline
# providers a cooldown after failures.
_refresh_state: Dict[str, Dict[str, Any]] = {}
_refresh_inflight = {"v": False} # coarse single-flight guard _refresh_inflight = {"v": False} # coarse single-flight guard
_REFRESH_FAILURE_BASE = 300.0
_REFRESH_FAILURE_MAX = 3600.0
def _refresh_caches_bg(): def _refresh_key(base: str, api_key: Optional[str]) -> str:
"""Background thread: re-probe all endpoints in PARALLEL with a tight return f"{base.rstrip('/')}\x00{api_key or ''}"
timeout, skipping endpoints that have been failing repeatedly.
Was the cause of gradual server degradation: sequential 3s-timeout def _ts(value: Any) -> float:
probes against many endpoints (some offline) tied up the threadpool try:
for 15-30s every cache cycle, eventually exhausting it.""" return float(value.timestamp()) if value else 0.0
except Exception:
return 0.0
def _failure_delay(fails: int) -> float:
    """Return the cooldown in seconds after ``fails`` consecutive failures.

    Zero or negative counts mean no cooldown. Otherwise the delay starts at
    ``_REFRESH_FAILURE_BASE`` and doubles with each further failure, capped
    at ``_REFRESH_FAILURE_MAX``.
    """
    if fails <= 0:
        return 0.0
    # Clamp the exponent: multiplying a float by 2 ** n raises OverflowError
    # once the integer no longer fits in a float. With the base and cap used
    # here the maximum is reached after a handful of doublings, so clamping
    # does not change any result that was previously computable.
    doublings = min(fails - 1, 32)
    return min(_REFRESH_FAILURE_BASE * (2 ** doublings), _REFRESH_FAILURE_MAX)
def _should_refresh_endpoint(ep: Any, now: float, force: bool = False) -> tuple[bool, Dict[str, Any]]:
    """Decide whether the model list of ``ep`` should be re-probed now.

    Returns ``(should_refresh, info)``. ``info`` is filled in regardless of
    the verdict and carries the endpoint id, normalized base URL, API key,
    effective kind, category, refresh mode, refresh-state key and background
    probe timeout.

    A refresh is refused when the base URL is empty or a probe for the same
    key is already in flight. Unless ``force`` is set it is also refused for
    manual/disabled modes, while a failure cooldown is still running, and
    while a non-empty cached list is younger than the refresh interval.
    """
    api_key = getattr(ep, "api_key", None)
    base = _normalize_base(getattr(ep, "base_url", "") or "")
    kind = _effective_endpoint_kind(ep, base)
    category = _classify_endpoint(base, kind)
    mode = _endpoint_refresh_mode(ep, kind)
    has_cached_models = bool(_cached_model_ids(ep))
    key = _refresh_key(base, api_key)
    state = _refresh_state.get(key, {})
    info = {
        "id": getattr(ep, "id", ""),
        "base": base,
        "api_key": api_key,
        "kind": kind,
        "category": category,
        "mode": mode,
        "key": key,
        "timeout": _endpoint_refresh_timeout(ep, category),
    }

    # Hard stops that apply even to forced refreshes.
    if not base or state.get("inflight"):
        return False, info
    if force:
        return True, info

    if mode in ("manual", "disabled"):
        return False, info

    # Back off after consecutive failures.
    fail_count = int(state.get("fail_count") or 0)
    if fail_count:
        failed_at = float(state.get("last_failure") or 0.0)
        if now - failed_at < _failure_delay(fail_count):
            return False, info

    # A cached list is considered fresh until the interval has elapsed; the
    # row timestamps stand in when no success was recorded in this process.
    if has_cached_models:
        interval = _endpoint_refresh_interval(ep, category)
        fresh_since = (
            float(state.get("last_success") or 0.0)
            or _ts(getattr(ep, "updated_at", None))
            or _ts(getattr(ep, "created_at", None))
        )
        if fresh_since and now - fresh_since < interval:
            return False, info

    return True, info
def _refresh_caches_bg(force: bool = False):
"""Background thread: safely refresh model caches with per-base single-flight.
The public /api/models path stays cached-first. This refresh never clears
a non-empty cached model list on timeout/failure, and proxy/manual
endpoints are skipped unless explicitly forced."""
import threading import threading
if _refresh_inflight["v"]: if _refresh_inflight["v"]:
return # already running return # already running
@@ -735,44 +941,63 @@ def setup_model_routes(model_discovery):
try: try:
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
db = SessionLocal() db = SessionLocal()
changed = False
try: try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
# Skip endpoints that have failed 3+ times in a row in the last 5 min
now = _time.time() now = _time.time()
to_probe = [] groups: Dict[str, Dict[str, Any]] = {}
for ep in endpoints: for ep in endpoints:
ts, fails = _probe_failures.get(ep.id, (0, 0)) ok, info = _should_refresh_endpoint(ep, now, force=force)
if fails >= 3 and (now - ts) < 300: if not ok:
continue continue
to_probe.append(ep) groups.setdefault(info["key"], {
"base": info["base"],
"api_key": info["api_key"],
"timeout": info["timeout"],
"endpoint_ids": [],
})["endpoint_ids"].append(info["id"])
def _probe_one(ep): for key in groups:
base = _normalize_base(ep.base_url) st = _refresh_state.setdefault(key, {})
st["inflight"] = True
st["last_attempt"] = now
def _probe_one(key: str, data: Dict[str, Any]):
try: try:
ids = _probe_endpoint(base, ep.api_key, timeout=2) ids = _probe_endpoint(data["base"], data.get("api_key"), timeout=data.get("timeout") or 2)
return ep, ids, None return key, data["endpoint_ids"], ids, None
except Exception as e: except Exception as e:
return ep, None, e return key, data["endpoint_ids"], None, e
if to_probe: if groups:
# Bounded parallelism — 8 concurrent probes is plenty with ThreadPoolExecutor(max_workers=min(4, len(groups))) as pool:
with ThreadPoolExecutor(max_workers=min(8, len(to_probe))) as pool: futures = [pool.submit(_probe_one, key, data) for key, data in groups.items()]
futures = [pool.submit(_probe_one, ep) for ep in to_probe]
for fut in as_completed(futures): for fut in as_completed(futures):
ep, ids, err = fut.result() key, endpoint_ids, ids, err = fut.result()
st = _refresh_state.setdefault(key, {})
if ids: if ids:
ep.cached_models = json.dumps(ids) for ep_id in endpoint_ids:
_probe_failures.pop(ep.id, None) ep_obj = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep_obj:
ep_obj.cached_models = json.dumps(ids)
changed = True
st["last_success"] = _time.time()
st["fail_count"] = 0
st.pop("last_failure", None)
else: else:
prev = _probe_failures.get(ep.id, (0, 0)) st["last_failure"] = _time.time()
_probe_failures[ep.id] = (_time.time(), prev[1] + 1) st["fail_count"] = int(st.get("fail_count") or 0) + 1
st["inflight"] = False
db.commit() db.commit()
finally: finally:
db.close() db.close()
_invalidate_models_cache() if changed:
_invalidate_models_cache()
except Exception: except Exception:
pass pass
finally: finally:
for st in _refresh_state.values():
st["inflight"] = False
_refresh_inflight["v"] = False _refresh_inflight["v"] = False
threading.Thread(target=_do, daemon=True).start() threading.Thread(target=_do, daemon=True).start()
@@ -804,24 +1029,15 @@ def setup_model_routes(model_discovery):
base = _normalize_base(ep.base_url) base = _normalize_base(ep.base_url)
provider = _detect_provider(base) provider = _detect_provider(base)
# Use cached models — background refresh keeps them updated # Use cached models — background refresh keeps them updated
model_ids = [] model_ids = _cached_model_ids(ep)
if ep.cached_models:
try:
model_ids = json.loads(ep.cached_models)
except Exception:
pass
ep_model_type = getattr(ep, "model_type", None) or "llm" ep_model_type = getattr(ep, "model_type", None) or "llm"
# Filter out hidden (probe-failed) models # Filter out hidden (probe-failed) models
hidden = set() hidden = _hidden_model_ids(ep)
if ep.hidden_models:
try:
hidden = set(json.loads(ep.hidden_models))
except Exception:
pass
model_ids = [m for m in model_ids if m not in hidden] model_ids = [m for m in model_ids if m not in hidden]
# Build correct URL based on provider # Build correct URL based on provider
chat_url = build_chat_url(base) chat_url = build_chat_url(base)
category = _classify_endpoint(base) kind = _effective_endpoint_kind(ep, base)
category = _classify_endpoint(base, kind)
if model_ids: if model_ids:
curated_key = _match_provider_curated(base, None) curated_key = _match_provider_curated(base, None)
@@ -837,6 +1053,7 @@ def setup_model_routes(model_discovery):
"endpoint_id": ep.id, "endpoint_id": ep.id,
"endpoint_name": ep.name, "endpoint_name": ep.name,
"category": category, "category": category,
"endpoint_kind": kind,
"model_type": ep_model_type, "model_type": ep_model_type,
}) })
else: else:
@@ -852,6 +1069,7 @@ def setup_model_routes(model_discovery):
"endpoint_id": ep.id, "endpoint_id": ep.id,
"endpoint_name": ep.name, "endpoint_name": ep.name,
"category": category, "category": category,
"endpoint_kind": kind,
"model_type": ep_model_type, "model_type": ep_model_type,
"offline": True, "offline": True,
}) })
@@ -898,11 +1116,11 @@ def setup_model_routes(model_discovery):
result = _fetch_models(owner=owner, is_admin=_is_admin) result = _fetch_models(owner=owner, is_admin=_is_admin)
_models_cache[_cache_key] = {"data": result, "time": now} _models_cache[_cache_key] = {"data": result, "time": now}
# Kick off background refresh to update caches from live endpoints # Kick off background refresh to update caches from live endpoints
_refresh_caches_bg() _refresh_caches_bg(force=refresh)
return result return result
# Brief cache for local-probe results so picker-open doesn't hammer # Brief cache for local-probe results so picker-open doesn't hammer
# /v1/models every time. 8s TTL — long enough to amortize cost, # endpoint health checks every time. 8s TTL — long enough to amortize cost,
# short enough that a freshly-killed local server shows as offline # short enough that a freshly-killed local server shows as offline
# within ~8s of the user noticing. # within ~8s of the user noticing.
_LOCAL_PROBE_TTL = 8.0 _LOCAL_PROBE_TTL = 8.0
@@ -912,7 +1130,7 @@ def setup_model_routes(model_discovery):
async def probe_local_endpoints(request: Request): async def probe_local_endpoints(request: Request):
"""Fast parallel reachability check for LOCAL endpoints only. """Fast parallel reachability check for LOCAL endpoints only.
Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are
assumed up. Local endpoints get a 1.5s /models probe so the UI assumed up. Local endpoints get a 1.5s cheap reachability probe so the UI
can dim stale entries pointing at dead vLLM servers. Returns can dim stale entries pointing at dead vLLM servers. Returns
{ep_id: {alive, latency_ms, error}}.""" {ep_id: {alive, latency_ms, error}}."""
require_admin(request) require_admin(request)
@@ -924,36 +1142,44 @@ def setup_model_routes(model_discovery):
db = SessionLocal() db = SessionLocal()
try: try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
local_eps = [ local_eps = []
(ep.id, _normalize_base(ep.base_url), ep.api_key) for ep in endpoints:
for ep in endpoints base = _normalize_base(ep.base_url)
if _classify_endpoint(_normalize_base(ep.base_url)) == "local" kind = _effective_endpoint_kind(ep, base)
] if _classify_endpoint(base, kind) == "local":
local_eps.append((ep.id, base, ep.api_key))
finally: finally:
db.close() db.close()
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]: grouped: Dict[str, Dict[str, Any]] = {}
for ep_id, base, api_key in local_eps:
key = _refresh_key(base, api_key)
grouped.setdefault(key, {"base": base, "api_key": api_key, "endpoint_ids": []})["endpoint_ids"].append(ep_id)
async def _probe_one(data: Dict[str, Any]) -> Dict[str, Any]:
t0 = _time.time() t0 = _time.time()
try: try:
models = _probe_endpoint(base, api_key, timeout=2.5) import asyncio as _asyncio
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5)
lat = round((_time.time() - t0) * 1000) lat = round((_time.time() - t0) * 1000)
return { return {
"alive": bool(models), "alive": bool(ping.get("reachable")),
"latency_ms": lat, "latency_ms": lat,
"status_code": 200 if models else None, "status_code": ping.get("status_code"),
"error": None if models else "No models found", "error": ping.get("error"),
} }
except Exception as e: except Exception as e:
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]} return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
import asyncio as _asyncio import asyncio as _asyncio
results_list = await _asyncio.gather( results_list = await _asyncio.gather(
*[_probe_one(eid, base, key) for eid, base, key in local_eps], *[_probe_one(data) for data in grouped.values()],
return_exceptions=False, return_exceptions=False,
) )
results: Dict[str, Any] = {} results: Dict[str, Any] = {}
for (eid, _, _), r in zip(local_eps, results_list): for data, r in zip(grouped.values(), results_list):
results[eid] = r for eid in data["endpoint_ids"]:
results[eid] = r
_local_probe_cache["data"] = results _local_probe_cache["data"] = results
_local_probe_cache["time"] = now _local_probe_cache["time"] = now
@@ -973,50 +1199,28 @@ def setup_model_routes(model_discovery):
for ep in endpoints: for ep in endpoints:
base = _normalize_base(ep.base_url) base = _normalize_base(ep.base_url)
provider = _detect_provider(base) provider = _detect_provider(base)
kind = _effective_endpoint_kind(ep, base)
cached_count = len(_cached_model_ids(ep))
entry = { entry = {
"id": ep.id, "id": ep.id,
"name": ep.name, "name": ep.name,
"base_url": base, "base_url": base,
"provider": provider, "provider": provider,
"category": _classify_endpoint(base), "category": _classify_endpoint(base, kind),
"endpoint_kind": kind,
} }
if provider == "anthropic": try:
# Anthropic has no /models endpoint; just check connectivity t0 = _time.time()
try: ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
t0 = _time.time() entry["latency_ms"] = round((_time.time() - t0) * 1000)
r = httpx.get(base.rstrip("/"), timeout=5) entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["latency_ms"] = round((_time.time() - t0) * 1000) entry["error"] = ping.get("error")
entry["status"] = "online" entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
entry["model_count"] = len(ANTHROPIC_MODELS) except Exception as e:
except Exception as e: entry["latency_ms"] = None
entry["latency_ms"] = None entry["status"] = "online" if cached_count else "offline"
entry["status"] = "offline" entry["error"] = str(e)
entry["error"] = str(e) entry["model_count"] = cached_count
entry["model_count"] = 0
else:
url = build_models_url(base)
headers = build_headers(ep.api_key, base)
try:
t0 = _time.time()
r = httpx.get(url, headers=headers, timeout=5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
entry["status"] = "online"
entry["model_count"] = len(models)
except Exception as e:
if "latency_ms" not in entry:
entry["latency_ms"] = None
entry["status"] = "offline"
entry["error"] = str(e)
entry["model_count"] = 0
results.append(entry) results.append(entry)
return {"endpoints": results} return {"endpoints": results}
@@ -1165,19 +1369,8 @@ def setup_model_routes(model_discovery):
rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all() rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all()
results = [] results = []
for r in rows: for r in rows:
# Use cached model list to avoid slow probe on every load all_models = _cached_model_ids(r)
all_models = [] hidden = _hidden_model_ids(r)
if r.cached_models:
try:
all_models = json.loads(r.cached_models)
except Exception:
pass
hidden = set()
if r.hidden_models:
try:
hidden = set(json.loads(r.hidden_models))
except Exception:
pass
pinned = _normalize_model_ids(getattr(r, "pinned_models", None)) pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
visible = _visible_models(all_models, r.hidden_models, pinned) visible = _visible_models(all_models, r.hidden_models, pinned)
# Endpoint counts as reachable if it has any model — including # Endpoint counts as reachable if it has any model — including
@@ -1188,6 +1381,8 @@ def setup_model_routes(model_discovery):
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"): if ping.get("reachable"):
status = "empty" status = "empty"
base = _normalize_base(r.base_url)
kind = _effective_endpoint_kind(r, base)
results.append({ results.append({
"id": r.id, "id": r.id,
"name": r.name, "name": r.name,
@@ -1202,6 +1397,11 @@ def setup_model_routes(model_discovery):
"ping_error": (ping or {}).get("error") if ping else None, "ping_error": (ping or {}).get("error") if ping else None,
"model_type": getattr(r, "model_type", None) or "llm", "model_type": getattr(r, "model_type", None) or "llm",
"supports_tools": getattr(r, "supports_tools", None), "supports_tools": getattr(r, "supports_tools", None),
"endpoint_kind": kind,
"category": _classify_endpoint(base, kind),
"model_refresh_mode": _endpoint_refresh_mode(r, kind),
"model_refresh_interval": getattr(r, "model_refresh_interval", None),
"model_refresh_timeout": getattr(r, "model_refresh_timeout", None),
}) })
return results return results
finally: finally:
@@ -1216,6 +1416,10 @@ def setup_model_routes(model_discovery):
skip_probe: str = Form("false"), skip_probe: str = Form("false"),
require_models: str = Form("false"), require_models: str = Form("false"),
model_type: str = Form("llm"), model_type: str = Form("llm"),
endpoint_kind: str = Form("auto"),
model_refresh_mode: str = Form(""),
model_refresh_interval: str = Form(""),
model_refresh_timeout: str = Form(""),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown) supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
container_local: str = Form("false"), container_local: str = Form("false"),
@@ -1240,8 +1444,15 @@ def setup_model_routes(model_discovery):
if not name.strip(): if not name.strip():
name = base_url.replace("http://", "").replace("https://", "").split("/")[0] name = base_url.replace("http://", "").replace("https://", "").split("/")[0]
requested_kind = _normalize_endpoint_kind(endpoint_kind)
refresh_mode = _normalize_refresh_mode(model_refresh_mode, requested_kind)
refresh_interval = _parse_positive_int(model_refresh_interval, minimum=30, maximum=86400)
refresh_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
require_model_list = _truthy(require_models) require_model_list = _truthy(require_models)
should_probe = require_model_list or not _truthy(skip_probe) should_probe = (
require_model_list or requested_kind in ("api", "proxy") or not _truthy(skip_probe)
)
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
# Dedupe: if an endpoint with the same base_url already exists and # Dedupe: if an endpoint with the same base_url already exists and
# is reachable by the caller (shared or owned by them), return it # is reachable by the caller (shared or owned by them), return it
@@ -1259,6 +1470,7 @@ def setup_model_routes(model_discovery):
.first() .first()
) )
if existing: if existing:
changed = False
# Persist any incoming pinned IDs onto the existing row. An # Persist any incoming pinned IDs onto the existing row. An
# empty/omitted form field must not wipe previously pinned IDs. # empty/omitted form field must not wipe previously pinned IDs.
_incoming_pinned = _normalize_model_ids(pinned_models) _incoming_pinned = _normalize_model_ids(pinned_models)
@@ -1268,15 +1480,45 @@ def setup_model_routes(model_discovery):
_incoming_pinned, _incoming_pinned,
) )
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
changed = True
existing_kind_for_probe = requested_kind if requested_kind != "auto" else _effective_endpoint_kind(existing, base_url)
if requested_kind != "auto" and _endpoint_kind(existing) == "auto":
existing.endpoint_kind = requested_kind
changed = True
if model_refresh_mode or (requested_kind == "proxy" and _endpoint_refresh_mode(existing, requested_kind) != refresh_mode):
existing.model_refresh_mode = refresh_mode
changed = True
if refresh_interval is not None:
existing.model_refresh_interval = refresh_interval
changed = True
if refresh_timeout is not None:
existing.model_refresh_timeout = refresh_timeout
changed = True
if api_key.strip() and not existing.api_key:
existing.api_key = api_key.strip()
changed = True
if should_probe:
probed_models = _probe_endpoint(
base_url,
(api_key.strip() or existing.api_key or None),
timeout=_explicit_model_list_timeout(base_url, existing_kind_for_probe, refresh_timeout),
)
if probed_models:
existing.cached_models = json.dumps(probed_models)
changed = True
if changed:
_db_dedup.commit() _db_dedup.commit()
_invalidate_models_cache() _invalidate_models_cache()
_local_probe_cache["data"] = None
existing_models = _cached_model_ids(existing)
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None)) _existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
existing_kind = _effective_endpoint_kind(existing, existing.base_url)
return { return {
"id": existing.id, "id": existing.id,
"name": existing.name, "name": existing.name,
"base_url": existing.base_url, "base_url": existing.base_url,
"models": _visible_models( "models": _visible_models(
getattr(existing, "cached_models", None), existing_models,
getattr(existing, "hidden_models", None), getattr(existing, "hidden_models", None),
existing.pinned_models, existing.pinned_models,
), ),
@@ -1284,16 +1526,16 @@ def setup_model_routes(model_discovery):
"online": True, "online": True,
"status": "online", "status": "online",
"existing": True, "existing": True,
"endpoint_kind": existing_kind,
"category": _classify_endpoint(existing.base_url, existing_kind),
} }
finally: finally:
_db_dedup.close() _db_dedup.close()
# Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh) model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=explicit_timeout) if should_probe else []
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
ping = {"reachable": False, "error": None} ping = {"reachable": False, "error": None}
if should_probe and not model_ids: if (should_probe or requested_kind in ("api", "proxy")) and not model_ids:
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=min(explicit_timeout, 2.0))
if require_model_list and not model_ids: if require_model_list and not model_ids:
raise HTTPException(400, _model_endpoint_error_message(base_url, ping)) raise HTTPException(400, _model_endpoint_error_message(base_url, ping))
@@ -1317,6 +1559,10 @@ def setup_model_routes(model_discovery):
api_key=api_key.strip() or None, api_key=api_key.strip() or None,
is_enabled=True, is_enabled=True,
model_type=model_type.strip() if model_type else "llm", model_type=model_type.strip() if model_type else "llm",
endpoint_kind=requested_kind,
model_refresh_mode=refresh_mode,
model_refresh_interval=refresh_interval,
model_refresh_timeout=refresh_timeout,
cached_models=json.dumps(model_ids) if model_ids else None, cached_models=json.dumps(model_ids) if model_ids else None,
pinned_models=json.dumps(_pinned) if _pinned else None, pinned_models=json.dumps(_pinned) if _pinned else None,
supports_tools=_st, supports_tools=_st,
@@ -1349,6 +1595,8 @@ def setup_model_routes(model_discovery):
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")), "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"), "status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None, "ping_error": ping.get("error") if ping else None,
"endpoint_kind": requested_kind,
"category": _classify_endpoint(base_url, requested_kind),
} }
@router.post("/model-endpoints/test") @router.post("/model-endpoints/test")
@@ -1356,6 +1604,8 @@ def setup_model_routes(model_discovery):
request: Request, request: Request,
base_url: str = Form(...), base_url: str = Form(...),
api_key: str = Form(""), api_key: str = Form(""),
endpoint_kind: str = Form("auto"),
model_refresh_timeout: str = Form(""),
): ):
require_admin(request) require_admin(request)
base_url = _normalize_base(base_url) base_url = _normalize_base(base_url)
@@ -1364,9 +1614,11 @@ def setup_model_routes(model_discovery):
from src.endpoint_resolver import resolve_url from src.endpoint_resolver import resolve_url
base_url = resolve_url(base_url) base_url = resolve_url(base_url)
base_url = _rewrite_loopback_for_docker(base_url) base_url = _rewrite_loopback_for_docker(base_url)
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2 requested_kind = _normalize_endpoint_kind(endpoint_kind)
configured_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
probe_timeout = _explicit_model_list_timeout(base_url, requested_kind, configured_timeout)
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout) models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout) ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=min(probe_timeout, 2.0))
return { return {
"base_url": base_url, "base_url": base_url,
"online": bool(models) or bool(ping.get("reachable")), "online": bool(models) or bool(ping.get("reachable")),
@@ -1374,6 +1626,8 @@ def setup_model_routes(model_discovery):
"ping_error": ping.get("error") if ping else None, "ping_error": ping.get("error") if ping else None,
"models": models, "models": models,
"count": len(models), "count": len(models),
"endpoint_kind": requested_kind,
"category": _classify_endpoint(base_url, requested_kind),
} }
@router.get("/model-endpoints/{ep_id}/probe") @router.get("/model-endpoints/{ep_id}/probe")
@@ -1415,7 +1669,8 @@ def setup_model_routes(model_discovery):
ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep_obj: if ep_obj:
ep_obj.hidden_models = json.dumps(failed) if failed else None ep_obj.hidden_models = json.dumps(failed) if failed else None
ep_obj.cached_models = json.dumps(all_models) if all_models else None if all_models:
ep_obj.cached_models = json.dumps(all_models)
db2.commit() db2.commit()
finally: finally:
db2.close() db2.close()
@@ -1426,7 +1681,13 @@ def setup_model_routes(model_discovery):
return StreamingResponse(_stream(), media_type="text/event-stream") return StreamingResponse(_stream(), media_type="text/event-stream")
@router.get("/model-endpoints/{ep_id}/models") @router.get("/model-endpoints/{ep_id}/models")
def list_endpoint_models(ep_id: str, request: Request): def list_endpoint_models(
ep_id: str,
request: Request,
response: Response,
refresh: bool = False,
refresh_timeout: Optional[int] = Query(None, ge=1, le=60),
):
"""List all discovered models for an endpoint with hidden/visible state.""" """List all discovered models for an endpoint with hidden/visible state."""
require_admin(request) require_admin(request)
db = SessionLocal() db = SessionLocal()
@@ -1434,23 +1695,28 @@ def setup_model_routes(model_discovery):
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if not ep: if not ep:
raise HTTPException(404, "Endpoint not found") raise HTTPException(404, "Endpoint not found")
hidden = set() hidden = _hidden_model_ids(ep)
if ep.hidden_models: all_models = _cached_model_ids(ep)
if refresh:
base = _normalize_base(ep.base_url)
kind = _effective_endpoint_kind(ep, base)
category = _classify_endpoint(base, kind)
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
try: try:
hidden = set(json.loads(ep.hidden_models)) probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
except Exception: except Exception as exc:
pass logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
# Try live probe, fall back to cached. Pinned IDs are admin-entered probed = []
# and persist regardless of probe results — never overwritten here. if probed:
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) all_models = probed
if all_models: ep.cached_models = json.dumps(all_models)
ep.cached_models = json.dumps(all_models) db.commit()
db.commit() _invalidate_models_cache()
elif ep.cached_models: response.headers["X-Model-Refresh-Status"] = "refreshed"
try: response.headers["X-Model-Refresh-Count"] = str(len(probed))
all_models = json.loads(ep.cached_models) else:
except Exception: response.headers["X-Model-Refresh-Status"] = "failed"
pass response.headers["X-Model-Refresh-Warning"] = "Model refresh failed or returned no models; kept cached models."
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None)) pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
pinned_set = set(pinned) pinned_set = set(pinned)
return [ return [
@@ -1502,7 +1768,6 @@ def setup_model_routes(model_discovery):
@router.get("/default-chat") @router.get("/default-chat")
def get_default_chat(request: Request): def get_default_chat(request: Request):
import json as _json
# SECURITY: resolve the default endpoint + model from the CALLER's # SECURITY: resolve the default endpoint + model from the CALLER's
# per-user prefs ONLY. We deliberately do NOT fall back to the # per-user prefs ONLY. We deliberately do NOT fall back to the
# global `default_model` / `default_endpoint_id` in settings.json # global `default_model` / `default_endpoint_id` in settings.json
@@ -1635,6 +1900,16 @@ def setup_model_routes(model_discovery):
if "pinned_models" in body: if "pinned_models" in body:
_pinned = _normalize_model_ids(body["pinned_models"]) _pinned = _normalize_model_ids(body["pinned_models"])
ep.pinned_models = json.dumps(_pinned) if _pinned else None ep.pinned_models = json.dumps(_pinned) if _pinned else None
if "endpoint_kind" in body:
ep.endpoint_kind = _normalize_endpoint_kind(body.get("endpoint_kind"))
if "model_refresh_mode" in body:
ep.model_refresh_mode = _normalize_refresh_mode(body.get("model_refresh_mode"), _endpoint_kind(ep))
if "model_refresh_interval" in body:
interval = _parse_positive_int(body.get("model_refresh_interval"), minimum=30, maximum=86400)
ep.model_refresh_interval = interval
if "model_refresh_timeout" in body:
timeout = _parse_positive_int(body.get("model_refresh_timeout"), minimum=1, maximum=60)
ep.model_refresh_timeout = timeout
# Rotating an API key used to require DELETE+POST, which wiped # Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base # endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key # URL. Allow in-place updates so the admin can change the key
@@ -1664,6 +1939,10 @@ def setup_model_routes(model_discovery):
"model_type": ep.model_type, "model_type": ep.model_type,
"base_url": ep.base_url, "base_url": ep.base_url,
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)), "pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
"model_refresh_interval": getattr(ep, "model_refresh_interval", None),
"model_refresh_timeout": getattr(ep, "model_refresh_timeout", None),
} }
finally: finally:
db.close() db.close()

View File

@@ -743,8 +743,74 @@ def _normalize_anthropic_url(url: str) -> str:
return url + "/messages" return url + "/messages"
return url + "/v1/messages" return url + "/v1/messages"
def _model_list_base(url: str) -> str:
"""Normalize model/chat URLs to the configured endpoint base."""
base = (url or "").strip().rstrip("/")
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
if base.endswith(suffix):
base = base[: -len(suffix)].rstrip("/")
for suffix in ("/chat", "/tags", "/generate"):
if base.endswith("/api" + suffix):
base = base[: -len(suffix)].rstrip("/")
return base
def _parse_model_cache(raw) -> List[str]:
if not raw:
return []
try:
models = json.loads(raw) if isinstance(raw, str) else raw
except Exception:
return []
if not isinstance(models, list):
return []
out = []
seen = set()
for item in models:
mid = str(item or "").strip()
if not mid or mid in seen:
continue
out.append(mid)
seen.add(mid)
return out
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
    """Look up cached, non-hidden model IDs for the configured endpoint at endpoint_url.

    The URL is normalized with ``_model_list_base`` and compared against the
    normalized ``base_url`` of every enabled ``ModelEndpoint`` row. The first
    matching row that has a non-empty cache wins; its hidden models are
    filtered out of the result.

    Returns:
        The visible cached model IDs, or an empty list when the URL is empty,
        the database module cannot be imported, no matching endpoint has a
        cache, or the lookup raises.
    """
    wanted_base = _model_list_base(endpoint_url)
    if not wanted_base:
        return []

    # Imported lazily; a failed import simply means "no cache available".
    try:
        from src.database import SessionLocal, ModelEndpoint
    except Exception:
        return []

    session = SessionLocal()
    try:
        enabled_rows = session.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
        for endpoint in enabled_rows:
            if _model_list_base(getattr(endpoint, "base_url", "")) == wanted_base:
                raw_cache = getattr(endpoint, "cached_models", None) or getattr(endpoint, "models", None)
                cached_ids = _parse_model_cache(raw_cache)
                if cached_ids:
                    hidden_ids = set(_parse_model_cache(getattr(endpoint, "hidden_models", None)))
                    return [mid for mid in cached_ids if mid not in hidden_ids]
    except Exception:
        # Best-effort lookup: any database error is treated as a cache miss.
        return []
    finally:
        # Closing must never mask the result of the lookup.
        try:
            session.close()
        except Exception:
            pass
    return []
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]: def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
"""List available model IDs from an endpoint.""" """List available model IDs from an endpoint."""
cached = _configured_cached_model_ids(base_chat_url)
if cached:
return cached
provider = _detect_provider(base_chat_url) provider = _detect_provider(base_chat_url)
if provider == "anthropic": if provider == "anthropic":
return list(ANTHROPIC_MODELS) return list(ANTHROPIC_MODELS)

View File

@@ -6,6 +6,7 @@ Provides token estimation for context usage tracking.
""" """
import logging import logging
import sys
from typing import Dict, List, Optional from typing import Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -21,8 +22,55 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
"172.30.", "172.31.", "192.168.", "100.") "172.30.", "172.31.", "192.168.", "100.")
def _normalize_base_for_compare(url: str) -> str:
url = (url or "").strip().rstrip("/")
for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"):
if url.endswith(suffix):
url = url[: -len(suffix)].rstrip("/")
return url
def _configured_endpoint_kind(url: str) -> Optional[str]:
"""Return configured endpoint kind for a chat/base URL when available."""
# Result is one of "local", "api", "proxy", "auto", or None when nothing can
# be determined (empty URL, database module not loaded, lookup error).
target = _normalize_base_for_compare(url)
if not target:
return None
# Only consult the database if it has already been imported elsewhere;
# presumably this avoids triggering the import from this module — TODO confirm.
if "core.database" not in sys.modules:
return None
try:
from core.database import SessionLocal, ModelEndpoint
db = SessionLocal()
try:
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for ep in rows:
base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "")
if not base:
continue
# Match when the URL is the endpoint base itself or a path beneath it.
if target != base and not target.startswith(base + "/"):
continue
kind = (getattr(ep, "endpoint_kind", None) or "auto").strip().lower()
# An explicitly configured kind always wins.
if kind in ("local", "api", "proxy"):
return kind
# Heuristic for "auto" rows: an API key plus an OpenAI-style path
# ("/v1" suffix or "/openai" in the path), and no sign of Ollama
# (port 11434 or "ollama" in the host), is reported as a proxy.
if getattr(ep, "api_key", None):
parsed = urlparse(base)
host = (parsed.hostname or "").lower()
path = (parsed.path or "").rstrip("/")
if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path):
return "proxy"
# NOTE(review): confirm whether this fallback applies per matched endpoint
# or only after every row has been scanned.
return "auto"
finally:
db.close()
# Best-effort lookup: any failure is reported as "unknown" rather than raised.
except Exception:
return None
def _is_local_endpoint(url: str) -> bool: def _is_local_endpoint(url: str) -> bool:
"""Check if URL points to a local/private/tailscale address.""" """Check if URL points to a local/private/tailscale address."""
kind = _configured_endpoint_kind(url)
if kind in ("api", "proxy"):
return False
if kind == "local":
return True
try: try:
host = urlparse(url).hostname or "" host = urlparse(url).hostname or ""
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
@@ -170,6 +218,7 @@ def get_context_length(endpoint_url: str, model: str) -> int:
or context_window fields. Caches result per model ID. or context_window fields. Caches result per model ID.
Falls back to DEFAULT_CONTEXT if unavailable. Falls back to DEFAULT_CONTEXT if unavailable.
""" """
configured_kind = _configured_endpoint_kind(endpoint_url)
is_local = _is_local_endpoint(endpoint_url) is_local = _is_local_endpoint(endpoint_url)
if not is_local and model in _context_cache: if not is_local and model in _context_cache:
return _context_cache[model] return _context_cache[model]
@@ -178,7 +227,7 @@ def get_context_length(endpoint_url: str, model: str) -> int:
# Only cache non-default values to allow retry on next request. # Only cache non-default values to allow retry on next request.
# Local endpoints can restart with a different --max-model-len while keeping # Local endpoints can restart with a different --max-model-len while keeping
# the same model id, so always re-query them instead of serving stale cache. # the same model id, so always re-query them instead of serving stale cache.
if not is_local and ctx != DEFAULT_CONTEXT: if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
_context_cache[model] = ctx _context_cache[model] = ctx
logger.info(f"Context length for {model}: {ctx}") logger.info(f"Context length for {model}: {ctx}")
return ctx return ctx
@@ -207,6 +256,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
"""Query the model API for context length.""" """Query the model API for context length."""
known = _lookup_known(model) known = _lookup_known(model)
api_ctx = None api_ctx = None
configured_kind = _configured_endpoint_kind(endpoint_url)
# Large OpenAI-compatible proxies can make /models expensive. If the
# endpoint is explicitly configured as API/proxy, prefer known context
# metadata (or the default) over downloading the full catalog.
if configured_kind in ("api", "proxy"):
if known:
logger.info(f"Using known context window for {model}: {known}")
return known
return DEFAULT_CONTEXT
# Try llama.cpp /slots endpoint first — reports actual serving context # Try llama.cpp /slots endpoint first — reports actual serving context
if _is_local_endpoint(endpoint_url): if _is_local_endpoint(endpoint_url):

View File

@@ -2079,6 +2079,10 @@
</select> </select>
<div class="admin-model-form-row"> <div class="admin-model-form-row">
<input id="adm-epApiKey" type="password" placeholder="API key"> <input id="adm-epApiKey" type="password" placeholder="API key">
<select id="adm-epKind" style="padding:5px;width:82px;">
<option value="proxy">Proxy</option>
<option value="api">API</option>
</select>
<select id="adm-epType" style="padding:5px;width:80px;"> <select id="adm-epType" style="padding:5px;width:80px;">
<option value="llm">LLM</option> <option value="llm">LLM</option>
<option value="image">Image</option> <option value="image">Image</option>

View File

@@ -371,7 +371,7 @@ async function loadEndpoints() {
const listLegacy = el('adm-epList'); const listLegacy = el('adm-epList');
// Refresh model picker so new endpoints show up in chat // Refresh model picker so new endpoints show up in chat
if (window.modelsModule && window.modelsModule.refreshModels) { if (window.modelsModule && window.modelsModule.refreshModels) {
window.modelsModule.refreshModels(true); window.modelsModule.refreshModels();
setTimeout(() => { setTimeout(() => {
if (window.sessionModule && window.sessionModule.updateModelPicker) { if (window.sessionModule && window.sessionModule.updateModelPicker) {
window.sessionModule.updateModelPicker(); window.sessionModule.updateModelPicker();
@@ -411,12 +411,15 @@ async function loadEndpoints() {
? `<span class="admin-badge">${visibleCount}/${totalCount} models enabled</span>` ? `<span class="admin-badge">${visibleCount}/${totalCount} models enabled</span>`
: '<span class="admin-badge admin-badge-off">offline</span>'; : '<span class="admin-badge admin-badge-off">offline</span>';
const justAddedClass = (_recentlyAddedEpId && String(ep.id) === _recentlyAddedEpId) ? ' adm-ep-just-added' : ''; const justAddedClass = (_recentlyAddedEpId && String(ep.id) === _recentlyAddedEpId) ? ' adm-ep-just-added' : '';
const category = ep.category || (_isLocalEndpoint(ep.base_url) ? 'local' : 'api');
const kindLabel = ep.endpoint_kind && ep.endpoint_kind !== 'auto' ? ep.endpoint_kind.toUpperCase() : '';
return ` return `
<div class="admin-user-row${ep.is_enabled ? '' : ' admin-ep-disabled'}${justAddedClass}" data-adm-ep-id="${ep.id}"> <div class="admin-user-row${ep.is_enabled ? '' : ' admin-ep-disabled'}${justAddedClass}" data-adm-ep-id="${ep.id}">
<div style="display:flex;align-items:center;justify-content:space-between;${hasModels ? 'cursor:pointer;' : ''}padding:4px 0;" data-adm-ep-header="${ep.id}"> <div style="display:flex;align-items:center;justify-content:space-between;${hasModels ? 'cursor:pointer;' : ''}padding:4px 0;" data-adm-ep-header="${ep.id}">
<div class="admin-user-info" style="flex:1;flex-wrap:wrap;gap:0.3rem;"> <div class="admin-user-info" style="flex:1;flex-wrap:wrap;gap:0.3rem;">
<span class="admin-user-name">${esc(ep.name)}</span> <span class="admin-user-name">${esc(ep.name)}</span>
${ep.model_type === 'image' ? '<span class="admin-badge" style="background:color-mix(in srgb, var(--accent) 20%, transparent);color:var(--accent);">Image</span>' : ''} ${ep.model_type === 'image' ? '<span class="admin-badge" style="background:color-mix(in srgb, var(--accent) 20%, transparent);color:var(--accent);">Image</span>' : ''}
${kindLabel ? `<span class="admin-badge">${esc(kindLabel)}</span>` : ''}
${statusBadge} ${statusBadge}
${ep.is_enabled ? '' : '<span class="admin-badge admin-badge-off">disabled</span>'} ${ep.is_enabled ? '' : '<span class="admin-badge admin-badge-off">disabled</span>'}
${hasModels ? '<span style="font-size:10px;opacity:0.4;">Click to manage models</span>' : ''} ${hasModels ? '<span style="font-size:10px;opacity:0.4;">Click to manage models</span>' : ''}
@@ -427,7 +430,7 @@ async function loadEndpoints() {
${hasModels ? '<svg class="admin-user-chevron" width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" style="opacity:0.3;transition:transform 0.2s,opacity 0.2s;"><polyline points="6 9 12 15 18 9"/></svg>' : ''} ${hasModels ? '<svg class="admin-user-chevron" width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" style="opacity:0.3;transition:transform 0.2s,opacity 0.2s;"><polyline points="6 9 12 15 18 9"/></svg>' : ''}
</div> </div>
</div> </div>
<div class="admin-ep-detail">${esc(ep.base_url)}${_isLocalEndpoint(ep.base_url) ? `<button type="button" class="admin-ep-copy-btn" data-adm-copy-url="${esc(ep.base_url)}" title="Copy URL" aria-label="Copy URL" style="background:none;border:none;padding:0 2px;margin-left:6px;cursor:pointer;color:inherit;opacity:0.45;vertical-align:-2px;line-height:1;"><svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg></button>` : ''}${ep.has_key ? ' (key set)' : ''}</div> <div class="admin-ep-detail">${esc(ep.base_url)}${category === 'local' ? `<button type="button" class="admin-ep-copy-btn" data-adm-copy-url="${esc(ep.base_url)}" title="Copy URL" aria-label="Copy URL" style="background:none;border:none;padding:0 2px;margin-left:6px;cursor:pointer;color:inherit;opacity:0.45;vertical-align:-2px;line-height:1;"><svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg></button>` : ''}${ep.has_key ? ' (key set)' : ''}</div>
${hasModels ? `<div class="mcp-tools-panel hidden" data-adm-ep-models-panel="${ep.id}"></div>` : ''} ${hasModels ? `<div class="mcp-tools-panel hidden" data-adm-ep-models-panel="${ep.id}"></div>` : ''}
</div>`; </div>`;
}); });
@@ -446,7 +449,7 @@ async function loadEndpoints() {
container.innerHTML = indices.map(i => rowHtml[i]).join(''); container.innerHTML = indices.map(i => rowHtml[i]).join('');
}; };
const localIdx = [], apiIdx = []; const localIdx = [], apiIdx = [];
data.forEach((ep, i) => (_isLocalEndpoint(ep.base_url) ? localIdx : apiIdx).push(i)); data.forEach((ep, i) => ((ep.category || (_isLocalEndpoint(ep.base_url) ? 'local' : 'api')) === 'local' ? localIdx : apiIdx).push(i));
// Sort each section: enabled endpoints first, disabled at the bottom. // Sort each section: enabled endpoints first, disabled at the bottom.
// Preserve original order within each group via stable sort. // Preserve original order within each group via stable sort.
const _sortByEnabled = (a, b) => Number(!!data[b].is_enabled) - Number(!!data[a].is_enabled); const _sortByEnabled = (a, b) => Number(!!data[b].is_enabled) - Number(!!data[a].is_enabled);
@@ -552,22 +555,48 @@ async function loadEndpoints() {
} catch (_) {} } catch (_) {}
panel.appendChild(_ld); panel.appendChild(_ld);
const _stopSpin = () => { try { _modelsSpin && _modelsSpin.stop(); } catch (_) {} }; const _stopSpin = () => { try { _modelsSpin && _modelsSpin.stop(); } catch (_) {} };
try { const _loadingHtml = (label) => `<span style="opacity:0.55;font-size:11px;display:inline-flex;align-items:center;gap:8px;">${esc(label)}</span>`;
const res = await fetch(`/api/model-endpoints/${epId}/models`, { credentials: 'same-origin' }); const renderModels = (models, warning = '') => {
const models = await res.json();
_stopSpin();
const sortedModels = sortModelObjects(models); const sortedModels = sortModelObjects(models);
if (!sortedModels.length) { panel.innerHTML = '<span style="opacity:0.5;font-size:11px;">No models</span>'; return; } const warningHtml = warning ? `<div class="admin-error" style="font-size:11px;margin:6px 0;">${esc(warning)}</div>` : '';
const attachRefresh = () => {
panel.querySelector(`[data-ep-refresh-models="${epId}"]`)?.addEventListener('click', async (e) => {
e.preventDefault();
panel.innerHTML = _loadingHtml('Refreshing models...');
try {
const res = await fetch(`/api/model-endpoints/${epId}/models?refresh=true&refresh_timeout=60`, { credentials: 'same-origin' });
const refreshWarning = res.headers.get('X-Model-Refresh-Warning') || '';
if (!res.ok) throw new Error(`HTTP ${res.status}`);
const refreshedModels = await res.json();
renderModels(refreshedModels, refreshWarning);
if (refreshWarning && uiModule?.showToast) uiModule.showToast(refreshWarning, 6000);
} catch (_) {
renderModels(sortedModels, 'Model refresh failed; kept cached models.');
}
});
};
if (!sortedModels.length) {
panel.innerHTML = `<div class="mcp-tools-header">
<span>Models</span>
<span style="display:flex;gap:8px;align-items:center;">
<span class="mcp-tools-count">0/0 enabled</span>
<a href="#" data-ep-refresh-models="${epId}">Refresh</a>
</span>
</div>${warningHtml}<span style="opacity:0.5;font-size:11px;">No models</span>`;
attachRefresh();
return;
}
const hiddenSet = new Set(sortedModels.filter(m => m.is_hidden).map(m => m.id)); const hiddenSet = new Set(sortedModels.filter(m => m.is_hidden).map(m => m.id));
const showSearch = sortedModels.length >= 8; const showSearch = sortedModels.length >= 8;
panel.innerHTML = `<div class="mcp-tools-header"> panel.innerHTML = `<div class="mcp-tools-header">
<span>Models</span> <span>Models</span>
<span style="display:flex;gap:8px;align-items:center;"> <span style="display:flex;gap:8px;align-items:center;">
<span class="mcp-tools-count">${sortedModels.length - hiddenSet.size}/${sortedModels.length} enabled</span> <span class="mcp-tools-count">${sortedModels.length - hiddenSet.size}/${sortedModels.length} enabled</span>
<a href="#" data-ep-refresh-models="${epId}">Refresh</a>
<a href="#" data-ep-select-all="${epId}">All</a> <a href="#" data-ep-select-all="${epId}">All</a>
<a href="#" data-ep-select-none="${epId}">None</a> <a href="#" data-ep-select-none="${epId}">None</a>
</span> </span>
</div>${showSearch ? `<input type="search" class="mcp-tools-search" placeholder="Search ${sortedModels.length} models..." data-ep-search="${epId}">` : ''}<div class="mcp-tools-list">` + sortedModels.map(m => </div>${warningHtml}${showSearch ? `<input type="search" class="mcp-tools-search" placeholder="Search ${sortedModels.length} models..." data-ep-search="${epId}">` : ''}<div class="mcp-tools-list">` + sortedModels.map(m =>
`<label title="${esc(m.id)}" data-ep-model-row data-search="${esc((m.display + ' ' + m.id).toLowerCase())}" class="adm-model-row"> `<label title="${esc(m.id)}" data-ep-model-row data-search="${esc((m.display + ' ' + m.id).toLowerCase())}" class="adm-model-row">
<input type="checkbox" class="adm-cb-hidden" data-ep-model-id="${esc(m.id)}" ${!m.is_hidden ? 'checked' : ''}> <input type="checkbox" class="adm-cb-hidden" data-ep-model-id="${esc(m.id)}" ${!m.is_hidden ? 'checked' : ''}>
<span class="adm-check-dot" aria-hidden="true"></span> <span class="adm-check-dot" aria-hidden="true"></span>
@@ -580,6 +609,7 @@ async function loadEndpoints() {
row.style.display = (!needle || row.dataset.search.includes(needle)) ? '' : 'none'; row.style.display = (!needle || row.dataset.search.includes(needle)) ? '' : 'none';
}); });
}; };
attachRefresh();
panel.querySelector(`[data-ep-search="${epId}"]`)?.addEventListener('input', (e) => filterRows(e.target.value)); panel.querySelector(`[data-ep-search="${epId}"]`)?.addEventListener('input', (e) => filterRows(e.target.value));
panel.querySelector(`[data-ep-select-all="${epId}"]`)?.addEventListener('click', (e) => { panel.querySelector(`[data-ep-select-all="${epId}"]`)?.addEventListener('click', (e) => {
e.preventDefault(); e.preventDefault();
@@ -598,6 +628,13 @@ async function loadEndpoints() {
panel.querySelectorAll('input[type=checkbox]').forEach(cb => { panel.querySelectorAll('input[type=checkbox]').forEach(cb => {
cb.addEventListener('change', () => _saveEpModelState(epId, panel)); cb.addEventListener('change', () => _saveEpModelState(epId, panel));
}); });
};
try {
const res = await fetch(`/api/model-endpoints/${epId}/models`, { credentials: 'same-origin' });
if (!res.ok) throw new Error(`HTTP ${res.status}`);
const models = await res.json();
_stopSpin();
renderModels(models);
} catch (e) { _stopSpin(); panel.innerHTML = '<span class="admin-error" style="font-size:11px;">Failed to load models</span>'; } } catch (e) { _stopSpin(); panel.innerHTML = '<span class="admin-error" style="font-size:11px;">Failed to load models</span>'; }
} }
}); });
@@ -637,6 +674,7 @@ async function _saveEpModelState(epId, panel) {
function initEndpointForm() { function initEndpointForm() {
const provider = el('adm-epProvider'); const provider = el('adm-epProvider');
const urlInput = el('adm-epUrl'); const urlInput = el('adm-epUrl');
const kindSel = el('adm-epKind');
// Custom provider picker — mirrors the (now hidden) <select id="adm-epProvider"> // Custom provider picker — mirrors the (now hidden) <select id="adm-epProvider">
// so the rest of this function (which reads provider.value and dispatches // so the rest of this function (which reads provider.value and dispatches
@@ -688,14 +726,20 @@ function initEndpointForm() {
provider.addEventListener('change', () => { provider.addEventListener('change', () => {
if (provider.value) urlInput.value = provider.value; if (provider.value) urlInput.value = provider.value;
else urlInput.value = ''; else urlInput.value = '';
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
}); });
urlInput.addEventListener('input', () => { urlInput.addEventListener('input', () => {
if (provider.value && urlInput.value.trim() !== provider.value) { if (provider.value && urlInput.value.trim() !== provider.value) {
provider.value = ''; provider.value = '';
if (kindSel) kindSel.value = 'proxy';
_renderPickerMenu(); _renderPickerMenu();
_syncPickerCurrent(); _syncPickerCurrent();
} }
}); });
if (kindSel) kindSel.value = provider.value ? 'api' : (kindSel.value || 'proxy');
// Endpoint kind to submit: the kind selector's value when it exists and is
// non-empty, otherwise 'api' when a provider is selected and 'proxy' when not.
function _apiEndpointKind() {
    if (kindSel && kindSel.value) {
        return kindSel.value;
    }
    if (provider.value) {
        return 'api';
    }
    return 'proxy';
}
function _normalizeBaseUrl(raw) { function _normalizeBaseUrl(raw) {
let u = raw.trim(); let u = raw.trim();
// Fix common protocol typos // Fix common protocol typos
@@ -784,6 +828,8 @@ function initEndpointForm() {
try { try {
const fd = new FormData(); const fd = new FormData();
fd.append('base_url', url); fd.append('base_url', url);
fd.append('endpoint_kind', _apiEndpointKind());
fd.append('model_refresh_timeout', '30');
if (apiKey) fd.append('api_key', apiKey); if (apiKey) fd.append('api_key', apiKey);
const res = await fetch('/api/model-endpoints/test', { const res = await fetch('/api/model-endpoints/test', {
method: 'POST', method: 'POST',
@@ -828,6 +874,10 @@ function initEndpointForm() {
try { try {
const fd = new FormData(); const fd = new FormData();
fd.append('base_url', url); fd.append('base_url', url);
const endpointKind = _apiEndpointKind();
fd.append('endpoint_kind', endpointKind);
fd.append('model_refresh_mode', endpointKind === 'proxy' ? 'manual' : 'auto');
fd.append('model_refresh_timeout', '30');
if (apiKey) fd.append('api_key', apiKey); if (apiKey) fd.append('api_key', apiKey);
if (provider.value && provider.selectedOptions && provider.selectedOptions[0]) { if (provider.value && provider.selectedOptions && provider.selectedOptions[0]) {
fd.append('name', provider.selectedOptions[0].textContent.trim()); fd.append('name', provider.selectedOptions[0].textContent.trim());
@@ -842,6 +892,7 @@ function initEndpointForm() {
const count = d.models ? d.models.length : 0; const count = d.models ? d.models.length : 0;
urlInput.value = ''; urlInput.style.display = ''; urlInput.value = ''; urlInput.style.display = '';
el('adm-epApiKey').value = ''; provider.value = ''; el('adm-epApiKey').value = ''; provider.value = '';
if (kindSel) kindSel.value = 'proxy';
if (epType) epType.value = 'llm'; if (epType) epType.value = 'llm';
if (d.id) _recentlyAddedEpId = String(d.id); if (d.id) _recentlyAddedEpId = String(d.id);
await loadEndpoints(); await loadEndpoints();
@@ -904,6 +955,8 @@ function initEndpointForm() {
const fd = new FormData(); const fd = new FormData();
fd.append('base_url', url); fd.append('base_url', url);
if (apiKey) fd.append('api_key', apiKey); if (apiKey) fd.append('api_key', apiKey);
fd.append('endpoint_kind', 'local');
fd.append('model_refresh_mode', 'auto');
const lt = el('adm-epLocalType'); const lt = el('adm-epLocalType');
if (lt) fd.append('model_type', lt.value); if (lt) fd.append('model_type', lt.value);
fd.append('skip_probe', 'false'); fd.append('skip_probe', 'false');
@@ -986,6 +1039,8 @@ function initEndpointForm() {
const base = item.url.replace('/chat/completions', '').replace(/\/$/, ''); const base = item.url.replace('/chat/completions', '').replace(/\/$/, '');
const fd = new FormData(); const fd = new FormData();
fd.append('base_url', base); fd.append('base_url', base);
fd.append('endpoint_kind', 'local');
fd.append('model_refresh_mode', 'auto');
fd.append('skip_probe', 'false'); fd.append('skip_probe', 'false');
const r = await fetch('/api/model-endpoints', { method: 'POST', body: fd }); const r = await fetch('/api/model-endpoints', { method: 'POST', body: fd });
if (r.ok) { if (r.ok) {

View File

@@ -561,7 +561,7 @@ function _initModelPickerDropdown() {
menu.classList.remove('closing', 'hidden'); menu.classList.remove('closing', 'hidden');
_populate(''); _populate('');
if (window.modelsModule && window.modelsModule.refreshModels) { if (window.modelsModule && window.modelsModule.refreshModels) {
window.modelsModule.refreshModels(true).then(() => { window.modelsModule.refreshModels().then(() => {
if (!menu.classList.contains('hidden')) _populate(search.value || ''); if (!menu.classList.contains('hidden')) _populate(search.value || '');
updateModelPicker(); updateModelPicker();
}).catch(() => {}); }).catch(() => {});

View File

@@ -16,6 +16,7 @@ import { sortModelIds } from './modelSort.js';
let API_BASE = ''; let API_BASE = '';
let _cachedItems = []; // cached /api/models items for model-switch dropdown let _cachedItems = []; // cached /api/models items for model-switch dropdown
let _lastFetchTime = 0; let _lastFetchTime = 0;
let _fetchInflight = null;
const _FETCH_CACHE_TTL = 30000; // 30s client-side cache for /api/models const _FETCH_CACHE_TTL = 30000; // 30s client-side cache for /api/models
const COLLAPSE_KEY = 'odysseus-models-collapsed'; const COLLAPSE_KEY = 'odysseus-models-collapsed';
const FAVORITES_KEY = 'odysseus-model-favorites'; const FAVORITES_KEY = 'odysseus-model-favorites';
@@ -176,8 +177,15 @@ export async function refreshModels(force = false) {
box.appendChild(_loadingSpinner.createElement()); box.appendChild(_loadingSpinner.createElement());
_loadingSpinner.start(); _loadingSpinner.start();
try { try {
const res = await fetch(`${API_BASE}/api/models`); if (!_fetchInflight) {
const data = await res.json(); _fetchInflight = fetch(`${API_BASE}/api/models`, { credentials: 'same-origin' })
.then(async (res) => {
if (!res.ok) throw new Error(`HTTP ${res.status}`);
return res.json();
})
.finally(() => { _fetchInflight = null; });
}
const data = await _fetchInflight;
_lastFetchTime = Date.now(); _lastFetchTime = Date.now();
_cachedItems = data.items || []; _cachedItems = data.items || [];
} catch (e) { } catch (e) {

View File

@@ -1,11 +1,59 @@
"""Tests for model_context.py — local endpoint detection, token estimation, known model lookup.""" """Tests for model_context.py — local endpoint detection, token estimation, known model lookup."""
import sys
import types
import pytest import pytest
import src.model_context as model_context import src.model_context as model_context
from src.model_context import _is_local_endpoint, estimate_tokens, _lookup_known from src.model_context import _is_local_endpoint, estimate_tokens, _lookup_known
class _Column:
def __init__(self, name):
self.name = name
def __eq__(self, value):
return ("eq", self.name, value)
class _ModelEndpoint:
    """Fake ``ModelEndpoint`` exposing only the ``is_enabled`` column."""

    # Comparing this attribute with ``==`` yields a filter descriptor tuple.
    is_enabled = _Column("is_enabled")
class _Query:
def __init__(self, rows):
self.rows = list(rows)
def filter(self, *conditions):
for condition in conditions:
if isinstance(condition, tuple) and condition[0] == "eq":
_, field, value = condition
self.rows = [row for row in self.rows if getattr(row, field) == value]
return self
def all(self):
return list(self.rows)
class _Db:
def __init__(self, rows):
self.rows = rows
def query(self, model):
return _Query(self.rows)
def close(self):
pass
def _install_endpoint_db(monkeypatch, rows):
    """Register a fake ``core.database`` module in ``sys.modules``.

    The fake exposes ``ModelEndpoint`` and a ``SessionLocal`` factory whose
    sessions are ``_Db`` instances backed by ``rows``.
    """
    fake_module = types.ModuleType("core.database")
    fake_module.ModelEndpoint = _ModelEndpoint

    def _session_factory():
        return _Db(rows)

    fake_module.SessionLocal = _session_factory
    monkeypatch.setitem(sys.modules, "core.database", fake_module)
class TestIsLocalEndpoint: class TestIsLocalEndpoint:
def test_localhost(self): def test_localhost(self):
assert _is_local_endpoint("http://localhost:5000/v1/chat/completions") is True assert _is_local_endpoint("http://localhost:5000/v1/chat/completions") is True
@@ -23,6 +71,18 @@ class TestIsLocalEndpoint:
# 100.64.0.0/10 is the CGNAT range Tailscale uses. # 100.64.0.0/10 is the CGNAT range Tailscale uses.
assert _is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True assert _is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True
def test_configured_tailscale_proxy_is_remote(self, monkeypatch):
    """An address configured as a proxy endpoint is not reported as local."""
    proxy_row = types.SimpleNamespace(
        base_url="http://100.117.136.97:34521/v1",
        endpoint_kind="proxy",
        api_key="fake-key",
        is_enabled=True,
    )
    _install_endpoint_db(monkeypatch, [proxy_row])

    chat_url = "http://100.117.136.97:34521/v1/chat/completions"
    assert _is_local_endpoint(chat_url) is False
def test_openai_is_remote(self): def test_openai_is_remote(self):
assert _is_local_endpoint("https://api.openai.com/v1/chat/completions") is False assert _is_local_endpoint("https://api.openai.com/v1/chat/completions") is False
@@ -164,3 +224,28 @@ class TestGetContextLength:
assert first == 200000 assert first == 200000
assert second == 200000 assert second == 200000
assert len(calls) == 1 assert len(calls) == 1
def test_configured_proxy_uses_default_without_model_listing(self, monkeypatch):
    """Context length for a configured proxy is the default, with no HTTP GET issued."""
    proxy_row = types.SimpleNamespace(
        base_url="http://100.117.136.97:34521/v1",
        endpoint_kind="proxy",
        api_key="fake-key",
        is_enabled=True,
    )
    _install_endpoint_db(monkeypatch, [proxy_row])

    calls = []

    def fake_get(*args, **kwargs):
        # Record the attempt, then fail loudly: no GET is expected at all.
        calls.append(args)
        raise AssertionError("/models should not be queried for configured proxy context")

    monkeypatch.setattr(model_context.httpx, "get", fake_get)

    endpoint = "http://100.117.136.97:34521/v1/chat/completions"
    # Ask twice for the same endpoint/model pair.
    first, second = (
        model_context.get_context_length(endpoint, "unknown-proxy-model")
        for _ in range(2)
    )

    assert first == model_context.DEFAULT_CONTEXT
    assert second == model_context.DEFAULT_CONTEXT
    assert calls == []

View File

@@ -2,9 +2,11 @@
import asyncio import asyncio
import json import json
import sys import sys
import threading
import time
import types import types
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
from types import SimpleNamespace
import httpx import httpx
import pytest import pytest
@@ -28,7 +30,9 @@ if "core.database" not in sys.modules:
sys.modules["core.database"] = _core_db sys.modules["core.database"] = _core_db
import routes.model_routes as model_routes import routes.model_routes as model_routes
import src.database as src_database
import src.endpoint_resolver as endpoint_resolver import src.endpoint_resolver as endpoint_resolver
import src.llm_core as llm_core
from routes.model_routes import ( from routes.model_routes import (
_match_provider_curated, _match_provider_curated,
_curate_models, _curate_models,
@@ -36,7 +40,11 @@ from routes.model_routes import (
_normalize_model_ids, _normalize_model_ids,
_is_chat_model, _is_chat_model,
_classify_endpoint, _classify_endpoint,
_effective_endpoint_kind,
_probe_endpoint, _probe_endpoint,
_ping_endpoint,
_parse_model_list,
_normalize_refresh_mode,
_truthy, _truthy,
_speech_settings_using_endpoint, _speech_settings_using_endpoint,
_clear_speech_settings_for_endpoint, _clear_speech_settings_for_endpoint,
@@ -304,6 +312,54 @@ class TestClassifyEndpoint:
def test_malformed_url(self): def test_malformed_url(self):
assert _classify_endpoint("not-a-url") == "api" assert _classify_endpoint("not-a-url") == "api"
def test_tailscale_auto_is_local(self):
    """With no kind override, the 100.117.x.x address classifies as local."""
    category = _classify_endpoint("http://100.117.136.97:34521/v1")
    assert category == "local"
def test_tailscale_proxy_override_is_api(self):
    """An explicit "proxy" kind makes the same address classify as api."""
    category = _classify_endpoint("http://100.117.136.97:34521/v1", "proxy")
    assert category == "api"
def test_tailscale_api_override_is_api(self):
    """An explicit "api" kind makes the same address classify as api."""
    category = _classify_endpoint("http://100.117.136.97:34521/v1", "api")
    assert category == "api"
def test_public_local_override_is_local(self):
    """An explicit "local" kind wins even for a public hostname."""
    category = _classify_endpoint("https://api.openai.com/v1", "local")
    assert category == "local"
def test_keyed_legacy_v1_endpoint_is_effective_proxy(self):
    """An "auto" endpoint that carries an API key resolves to the proxy kind."""
    legacy_row = SimpleNamespace(endpoint_kind="auto", api_key="fake-key")
    effective = _effective_endpoint_kind(legacy_row, "http://100.117.136.97:34521/v1")
    assert effective == "proxy"
def test_proxy_refresh_mode_defaults_manual(self):
    """Proxy endpoints normalize to manual refresh; api endpoints may stay auto."""
    # (requested mode, endpoint kind, expected normalized mode)
    cases = (
        ("", "proxy", "manual"),
        ("auto", "proxy", "manual"),
        ("manual", "proxy", "manual"),
        ("auto", "api", "auto"),
    )
    for requested, kind, expected in cases:
        assert _normalize_refresh_mode(requested, kind) == expected
def test_parse_model_list_accepts_json_and_text(self):
    """JSON arrays are de-duplicated; plain text splits on commas and newlines."""
    from_json = _parse_model_list('["a", "b", "a"]')
    assert from_json == ["a", "b"]

    from_text = _parse_model_list("a, b\nc")
    assert from_text == ["a", "b", "c"]
def test_ping_endpoint_does_not_request_models_for_openai_style_proxy(self, monkeypatch):
    """The ping issues a single GET to the base URL and never touches /models or HEAD."""
    base = "http://100.117.136.97:34521/v1"
    monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda url: url, raising=False)
    seen = []

    def fake_head(*args, **kwargs):
        raise AssertionError("generic proxy health check should not use HEAD")

    def fake_get(url, headers=None, timeout=None):
        seen.append(("GET", url))
        return httpx.Response(200, request=httpx.Request("GET", url))

    monkeypatch.setattr(model_routes.httpx, "head", fake_head)
    monkeypatch.setattr(model_routes.httpx, "get", fake_get)

    outcome = _ping_endpoint(base, "fake-key", timeout=1)

    assert outcome["reachable"] is True
    assert outcome["status_code"] == 200
    assert seen == [("GET", base)]
    assert not any(url.endswith("/models") for _, url in seen)
# ── setup probing ── # ── setup probing ──
@@ -534,77 +590,51 @@ if "python_multipart" not in sys.modules:
sys.modules["python_multipart"] = _mp_stub sys.modules["python_multipart"] = _mp_stub
class _PinnedFakeQuery: class _RouteCondition:
def __init__(self, rows): def __init__(self, op, field, value):
self.rows = list(rows) self.op = op
self.field = field
def filter(self, *conditions): self.value = value
return self
def order_by(self, *args):
return self
def first(self):
return self.rows[0] if self.rows else None
def all(self):
return list(self.rows)
class _PinnedFakeDb:
def __init__(self, rows):
self.rows = rows
self.added = []
self.committed = 0
def query(self, model):
return _PinnedFakeQuery(self.rows)
def add(self, row):
self.added.append(row)
def commit(self):
self.committed += 1
def close(self):
pass
class _FakeCol:
"""Column stand-in: every comparison/operator just returns itself so the
dedupe query expressions evaluate without a real SQLAlchemy column."""
__hash__ = None
def __eq__(self, other):
return self
def is_(self, other):
return self
def __or__(self, other): def __or__(self, other):
return self return ("or", self, other)
class _RouteColumn:
def __init__(self, name):
self.name = name
def __eq__(self, value):
return _RouteCondition("eq", self.name, value)
def is_(self, value):
return _RouteCondition("eq", self.name, value)
def desc(self): def desc(self):
return self return self
class _RecordingEndpoint: class _RouteModelEndpoint:
"""ModelEndpoint stand-in that stores constructor kwargs as attributes. """ModelEndpoint stand-in that stores constructor kwargs as attributes.
Class-level fake columns let it double as the query class in the dedupe Class-level fake columns let it double as the query class in the dedupe
lookup; instance attributes (set in __init__) shadow them per-row. lookup; instance attributes (set in __init__) shadow them per-row.
""" """
id = _FakeCol() id = _RouteColumn("id")
base_url = _FakeCol() base_url = _RouteColumn("base_url")
owner = _FakeCol() is_enabled = _RouteColumn("is_enabled")
owner = _RouteColumn("owner")
created_at = _RouteColumn("created_at")
def __init__(self, **kwargs): def __init__(self, **kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(self, key, value) setattr(self, key, value)
_RecordingEndpoint = _RouteModelEndpoint
class _PinnedFakeRequest: class _PinnedFakeRequest:
def __init__(self, body=None, headers=None): def __init__(self, body=None, headers=None):
self._body = body if body is not None else {} self._body = body if body is not None else {}
@@ -635,6 +665,13 @@ def _make_endpoint(**kwargs):
pinned_models=None, pinned_models=None,
model_type="llm", model_type="llm",
supports_tools=None, supports_tools=None,
endpoint_kind="auto",
model_refresh_mode="auto",
model_refresh_interval=None,
model_refresh_timeout=None,
owner=None,
created_at=None,
updated_at=None,
) )
base.update(kwargs) base.update(kwargs)
return SimpleNamespace(**base) return SimpleNamespace(**base)
@@ -676,7 +713,7 @@ def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: []) monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET") endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
result = endpoint("ep1", _PinnedFakeRequest()) result = endpoint("ep1", _PinnedFakeRequest(), SimpleNamespace(headers={}))
ids = [row["id"] for row in result] ids = [row["id"] for row in result]
assert ids == ["deploy-1"] assert ids == ["deploy-1"]
@@ -730,6 +767,10 @@ def _create_form_kwargs(**overrides):
skip_probe="true", # avoid any network probe in unit tests skip_probe="true", # avoid any network probe in unit tests
require_models="false", require_models="false",
model_type="llm", model_type="llm",
endpoint_kind="auto",
model_refresh_mode="",
model_refresh_interval="",
model_refresh_timeout="",
supports_tools="", supports_tools="",
pinned_models="", pinned_models="",
container_local="false", container_local="false",
@@ -772,6 +813,7 @@ def test_post_creates_endpoint_with_pinned_models(monkeypatch):
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch): def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
existing = _make_endpoint( existing = _make_endpoint(
base_url="http://host:1234/v1",
cached_models=json.dumps(["m1"]), cached_models=json.dumps(["m1"]),
hidden_models=None, hidden_models=None,
pinned_models=json.dumps(["old-pin"]), pinned_models=json.dumps(["old-pin"]),
@@ -798,6 +840,7 @@ def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch): def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
existing = _make_endpoint( existing = _make_endpoint(
base_url="http://host:1234/v1",
cached_models=json.dumps(["m1"]), cached_models=json.dumps(["m1"]),
pinned_models=json.dumps(["keep-me"]), pinned_models=json.dumps(["keep-me"]),
) )
@@ -814,3 +857,464 @@ def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
assert json.loads(existing.pinned_models) == ["keep-me"] assert json.loads(existing.pinned_models) == ["keep-me"]
assert result["pinned_models"] == ["keep-me"] assert result["pinned_models"] == ["keep-me"]
assert db.committed == 0 # nothing to persist assert db.committed == 0 # nothing to persist
class _RouteQuery:
def __init__(self, rows):
self.rows = list(rows)
def filter(self, *conditions):
for condition in conditions:
if isinstance(condition, _RouteCondition) and condition.op == "eq":
self.rows = [row for row in self.rows if getattr(row, condition.field, None) == condition.value]
elif isinstance(condition, tuple) and condition and condition[0] == "or":
keep = []
for row in self.rows:
matched = False
for part in condition[1:]:
if isinstance(part, _RouteCondition) and part.op == "eq":
matched = matched or (getattr(row, part.field, None) == part.value)
if matched:
keep.append(row)
self.rows = keep
return self
def order_by(self, *args, **kwargs):
return self
def all(self):
return list(self.rows)
def first(self):
return self.rows[0] if self.rows else None
class _RouteDb:
def __init__(self, rows):
self.rows = rows
self.added = []
self.committed = 0
self.commits = 0
self.closed = False
def query(self, model):
return _RouteQuery(self.rows)
def commit(self):
self.committed += 1
self.commits += 1
def close(self):
self.closed = True
def add(self, row):
self.rows.append(row)
self.added.append(row)
_PinnedFakeDb = _RouteDb
class _ImmediateThread:
def __init__(self, target, daemon=None):
self.target = target
def start(self):
self.target()
def _wait_for(predicate, timeout=2.0):
deadline = time.time() + timeout
while time.time() < deadline:
if predicate():
return True
time.sleep(0.01)
return bool(predicate())
def _route_endpoint(router, path, method="GET"):
for route in router.routes:
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
return route.endpoint
raise AssertionError(f"{method} {path} route not found")
def _route_ep(
id,
base_url,
*,
cached_models=None,
endpoint_kind="auto",
api_key=None,
name=None,
pinned_models=None,
refresh_mode="auto",
refresh_timeout=None,
):
return SimpleNamespace(
id=id,
name=name or id,
base_url=base_url,
api_key=api_key,
is_enabled=True,
cached_models=json.dumps(cached_models) if cached_models is not None else None,
hidden_models=None,
pinned_models=json.dumps(pinned_models) if pinned_models is not None else None,
model_type="llm",
endpoint_kind=endpoint_kind,
model_refresh_mode=refresh_mode,
model_refresh_interval=None,
model_refresh_timeout=refresh_timeout,
supports_tools=None,
owner=None,
created_at=None,
updated_at=None,
)
def _route_request():
return SimpleNamespace(
state=SimpleNamespace(current_user=None),
app=SimpleNamespace(state=SimpleNamespace(auth_manager=None)),
)
def test_api_models_returns_cached_proxy_models_without_refresh_probe(monkeypatch):
    """/api/models serves a manual-refresh proxy from cache and never probes it."""
    proxy_row = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([proxy_row])
    router = model_routes.setup_model_routes(model_discovery=None)

    def fail_probe(*args, **kwargs):
        raise AssertionError("/models probe should not run for cached manual proxy")

    replacements = (
        ("ModelEndpoint", _RouteModelEndpoint),
        ("SessionLocal", lambda: db),
        ("_auth_disabled", lambda: True),
        ("build_chat_url", lambda base: f"{base}/chat/completions"),
        ("_probe_endpoint", fail_probe),
    )
    for attr, replacement in replacements:
        monkeypatch.setattr(model_routes, attr, replacement)
    # Run any background work inline so a stray probe would fail this test.
    monkeypatch.setattr(threading, "Thread", _ImmediateThread)

    handler = _route_endpoint(router, "/api/models")
    payload = handler(_route_request())

    item = payload["items"][0]
    assert item["models"] == ["cached-model"]
    assert item["category"] == "api"
    assert item["endpoint_kind"] == "proxy"
    assert "offline" not in item
    assert json.loads(proxy_row.cached_models) == ["cached-model"]
@pytest.mark.asyncio
async def test_probe_local_skips_tailscale_proxy_endpoint(monkeypatch):
    """probe-local pings only the local endpoint and leaves the proxy alone."""
    proxy_row = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
    )
    local_row = _route_ep("local", "http://127.0.0.1:8000/v1", endpoint_kind="local")
    db = _RouteDb([proxy_row, local_row])
    router = model_routes.setup_model_routes(model_discovery=None)

    def fail_probe(*a, **k):
        raise AssertionError("full probe should not run")

    pinged = []

    def fake_ping(base_url, api_key=None, timeout=1.5):
        pinged.append(base_url)
        return {"reachable": True, "status_code": 404, "error": "HTTP 404"}

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fail_probe)
    monkeypatch.setattr(model_routes, "_ping_endpoint", fake_ping)

    handler = _route_endpoint(router, "/api/model-endpoints/probe-local")
    result = await handler(_route_request())

    assert set(result) == {"local"}
    assert pinged == ["http://127.0.0.1:8000/v1"]
def test_background_refresh_deduplicates_same_base_url(monkeypatch):
    """Two endpoints sharing a base URL trigger one probe whose result fills both caches."""
    shared_url = "http://127.0.0.1:8000/v1"
    ep1 = _route_ep("a", shared_url, endpoint_kind="local")
    ep2 = _route_ep("b", shared_url, endpoint_kind="local")
    db = _RouteDb([ep1, ep2])
    router = model_routes.setup_model_routes(model_discovery=None)

    calls = []
    probe_done = threading.Event()

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append(base_url)
        probe_done.set()
        return ["live-model"]

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "_auth_disabled", lambda: True)
    monkeypatch.setattr(model_routes, "build_chat_url", lambda base: f"{base}/chat/completions")
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    handler = _route_endpoint(router, "/api/models")
    handler(_route_request(), refresh=True)

    # The refresh runs off-thread: wait for the probe, then for both cache writes.
    assert probe_done.wait(2)
    assert _wait_for(lambda: ep1.cached_models and ep2.cached_models)
    assert calls == ["http://127.0.0.1:8000/v1"]
    for endpoint in (ep1, ep2):
        assert json.loads(endpoint.cached_models) == ["live-model"]
def test_background_refresh_failure_keeps_existing_cached_models(monkeypatch):
    """An empty probe result must not overwrite the previously cached models."""
    endpoint_row = _route_ep(
        "local",
        "http://127.0.0.1:8000/v1",
        cached_models=["cached-model"],
        endpoint_kind="local",
    )
    db = _RouteDb([endpoint_row])
    router = model_routes.setup_model_routes(model_discovery=None)

    probe_done = threading.Event()

    def fake_probe(*args, **kwargs):
        # Simulate a refresh that finds nothing.
        probe_done.set()
        return []

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "_auth_disabled", lambda: True)
    monkeypatch.setattr(model_routes, "build_chat_url", lambda base: f"{base}/chat/completions")
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    handler = _route_endpoint(router, "/api/models")
    result = handler(_route_request(), refresh=True)

    assert probe_done.wait(2)
    assert _wait_for(lambda: db.commits > 0)
    assert result["items"][0]["models"] == ["cached-model"]
    assert json.loads(endpoint_row.cached_models) == ["cached-model"]
def test_llm_core_list_model_ids_uses_cached_configured_proxy(monkeypatch):
    """list_model_ids returns the cached, non-hidden models without an HTTP GET."""
    proxy_row = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model", "hidden-model"],
        endpoint_kind="proxy",
    )
    proxy_row.hidden_models = json.dumps(["hidden-model"])
    db = _RouteDb([proxy_row])

    def fail_get(*a, **k):
        raise AssertionError("/models should not be fetched")

    monkeypatch.setattr(src_database, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(src_database, "SessionLocal", lambda: db)
    monkeypatch.setattr(llm_core.httpx, "get", fail_get)

    chat_url = "http://100.117.136.97:34521/v1/chat/completions"
    model_ids = llm_core.list_model_ids(chat_url, timeout=1)
    assert model_ids == ["cached-model"]
def test_explicit_proxy_test_fetches_models_with_long_timeout(monkeypatch):
    """The test route probes the proxy once with a 30 s timeout and never pings."""
    router = model_routes.setup_model_routes(model_discovery=None)
    base = "http://100.117.136.97:34521/v1"
    returned = ["NVIDIA NIM/openai/gpt-oss-120b", "mistral/mistral-small-2603"]
    calls = []

    def fail_ping(*a, **k):
        raise AssertionError("ping should not run when model listing succeeds")

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return returned

    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_ping_endpoint", fail_ping)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    handler = _route_endpoint(router, "/api/model-endpoints/test", "POST")
    result = handler(
        _route_request(),
        base_url=base,
        api_key="fake-key",
        endpoint_kind="proxy",
    )

    assert result["online"] is True
    assert result["status"] == "online"
    assert result["models"] == returned
    expected_call = {"base_url": base, "api_key": "fake-key", "timeout": 30.0}
    assert calls == [expected_call]
def test_explicit_proxy_add_fetches_and_caches_models_with_long_timeout(monkeypatch):
    """Adding a proxy probes once with a 30 s timeout and persists the model list."""
    db = _RouteDb([])
    router = model_routes.setup_model_routes(model_discovery=None)
    base = "http://100.117.136.97:34521/v1"
    returned = ["NVIDIA NIM/openai/gpt-oss-120b", "mistral/mistral-small-2603"]
    calls = []

    def fail_ping(*a, **k):
        raise AssertionError("ping should not run when model listing succeeds")

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return returned

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_load_settings", lambda: {})
    monkeypatch.setattr(model_routes, "_save_settings", lambda settings: None)
    monkeypatch.setattr("src.auth_helpers.get_current_user", lambda request: None)
    monkeypatch.setattr(model_routes, "_ping_endpoint", fail_ping)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    form = dict(
        name="Bifrost",
        base_url=base,
        api_key="fake-key",
        skip_probe="true",
        require_models="false",
        model_type="llm",
        endpoint_kind="proxy",
        model_refresh_mode="manual",
        model_refresh_interval="",
        model_refresh_timeout="",
        supports_tools="",
        container_local="false",
        shared="true",
    )
    handler = _route_endpoint(router, "/api/model-endpoints", "POST")
    result = handler(_route_request(), **form)

    assert result["online"] is True
    assert result["status"] == "online"
    assert result["models"] == returned
    expected_call = {"base_url": base, "api_key": "fake-key", "timeout": 30.0}
    assert calls == [expected_call]

    # Exactly one row was stored, carrying the fetched list and proxy settings.
    assert len(db.rows) == 1
    stored = db.rows[0]
    assert json.loads(stored.cached_models) == returned
    assert stored.endpoint_kind == "proxy"
    assert stored.model_refresh_mode == "manual"
def test_manual_refresh_uses_long_timeout_and_saves_full_model_list(monkeypatch):
    """An explicit refresh honours ``refresh_timeout`` and stores the new list.

    With ``refresh=True`` and ``refresh_timeout=60`` the probe must be called
    once with a 60 second timeout, the complete probed list must replace the
    cached one with a single commit, and the refresh headers must report
    success and the model count.
    """
    endpoint = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([endpoint])
    router = model_routes.setup_model_routes(model_discovery=None)

    fresh_models = ["cached-model", "mistral/mistral-small-2603", "provider/nested/model/id"]
    probe_calls = []

    def fake_probe(base_url, api_key=None, timeout=2):
        probe_calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return fresh_models

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    list_models = _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")
    result = list_models(
        "proxy",
        _route_request(),
        response,
        refresh=True,
        refresh_timeout=60,
    )

    returned_ids = [entry["id"] for entry in result]
    assert returned_ids == fresh_models

    expected_call = {
        "base_url": "http://100.117.136.97:34521/v1",
        "api_key": "fake-key",
        "timeout": 60.0,
    }
    assert probe_calls == [expected_call]

    # The cache now holds the full refreshed list, written in one commit.
    assert json.loads(endpoint.cached_models) == fresh_models
    assert db.commits == 1

    headers = response.headers
    assert headers["X-Model-Refresh-Status"] == "refreshed"
    assert headers["X-Model-Refresh-Count"] == "3"
def test_manual_refresh_defaults_to_proxy_long_timeout(monkeypatch):
    """A refresh without ``refresh_timeout`` probes a proxy with 30 seconds.

    The probe must be called exactly once with a timeout of 30.0, and the
    list it returns must be written to the endpoint's cached models.
    """
    endpoint = _route_ep(
        "proxy",
        "https://proxy.example.test/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        refresh_mode="manual",
    )
    db = _RouteDb([endpoint])
    router = model_routes.setup_model_routes(model_discovery=None)

    seen_timeouts = []

    def fake_probe(base_url, api_key=None, timeout=2):
        # Only the timeout matters here; the URL and key are ignored.
        seen_timeouts.append(timeout)
        return ["cached-model", "new-model"]

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    list_models = _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")
    list_models(
        "proxy",
        _route_request(),
        response,
        refresh=True,
    )

    assert seen_timeouts == [30.0]
    assert json.loads(endpoint.cached_models) == ["cached-model", "new-model"]
def test_manual_refresh_timeout_keeps_cached_models_and_warns(monkeypatch):
    """A refresh whose probe times out falls back to the cached model list.

    When the probe raises ``httpx.TimeoutException`` the route must still
    return the cached models, leave the stored cache unchanged without
    committing, and flag the failure through the refresh headers.
    """
    endpoint = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([endpoint])
    router = model_routes.setup_model_routes(model_discovery=None)

    def fake_probe(base_url, api_key=None, timeout=2):
        # Simulate the upstream model listing never answering in time.
        raise httpx.TimeoutException("timed out")

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    list_models = _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")
    result = list_models(
        "proxy",
        _route_request(),
        response,
        refresh=True,
        refresh_timeout=60,
    )

    returned_ids = [entry["id"] for entry in result]
    assert returned_ids == ["cached-model"]

    # Nothing was persisted: the cache is intact and no commit happened.
    assert json.loads(endpoint.cached_models) == ["cached-model"]
    assert db.commits == 0

    headers = response.headers
    assert headers["X-Model-Refresh-Status"] == "failed"
    assert "kept cached models" in headers["X-Model-Refresh-Warning"]