fix(models): stabilize proxy endpoint refresh behavior
* fix: support large proxy model endpoint refresh

  Large OpenAI-compatible proxy endpoints can expose hundreds of models and make /v1/models slow. Treating those endpoints like local model servers caused model picker opens and background probes to repeatedly hit /models, producing timeouts and making otherwise usable endpoints appear offline.

  Make model endpoint discovery cached-first for normal UI usage, add explicit proxy/API classification and refresh policy fields, exclude proxy/API endpoints from aggressive local probing, and preserve cached models when refresh fails.

  Manual Test/Add/Refresh actions still fetch the full model list with longer timeouts so users can intentionally import large proxy model lists without blocking normal model picker usage.

* fix: preserve endpoint ping status semantics
This commit is contained in:
@@ -342,6 +342,14 @@ class ModelEndpoint(TimestampMixin, Base):
|
||||
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
|
||||
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
|
||||
model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
|
||||
# auto = classify by URL; local = self-hosted server; api/proxy = external
|
||||
# OpenAI-compatible API even when reachable through a private/tailnet IP.
|
||||
endpoint_kind = Column(String, nullable=True, default="auto")
|
||||
# auto = background refresh with TTL/backoff; manual/disabled = cached-first
|
||||
# only unless an explicit endpoint probe is requested.
|
||||
model_refresh_mode = Column(String, nullable=True, default="auto")
|
||||
model_refresh_interval = Column(Integer, nullable=True, default=None)
|
||||
model_refresh_timeout = Column(Integer, nullable=True, default=None)
|
||||
# Whether models on this endpoint accept OpenAI-style function
|
||||
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
|
||||
# register time from `--enable-auto-tool-choice` in the serve cmd;
|
||||
@@ -809,6 +817,29 @@ def _migrate_add_model_type_column():
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
|
||||
|
||||
def _migrate_add_model_endpoint_refresh_columns():
|
||||
"""Add endpoint classification / refresh policy columns if missing."""
|
||||
import sqlite3
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
if columns and "endpoint_kind" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN endpoint_kind TEXT DEFAULT 'auto'")
|
||||
if columns and "model_refresh_mode" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_mode TEXT DEFAULT 'auto'")
|
||||
if columns and "model_refresh_interval" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_interval INTEGER")
|
||||
if columns and "model_refresh_timeout" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
|
||||
|
||||
def _migrate_add_task_run_model_column():
|
||||
"""Add model column to task_runs if it doesn't exist (records which model ran)."""
|
||||
import sqlite3
|
||||
@@ -1539,6 +1570,7 @@ def init_db():
|
||||
_migrate_add_pinned_models_column()
|
||||
_migrate_add_notes_sort_order()
|
||||
_migrate_add_model_type_column()
|
||||
_migrate_add_model_endpoint_refresh_columns()
|
||||
_migrate_add_model_endpoint_owner_column()
|
||||
_migrate_add_supports_tools_column()
|
||||
_migrate_add_task_run_model_column()
|
||||
|
||||
@@ -11,7 +11,7 @@ import httpx
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request, Response
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
|
||||
@@ -335,6 +335,141 @@ def _truthy(value: str | None) -> bool:
|
||||
return (value or "").strip().lower() in ("true", "1", "yes", "on")
|
||||
|
||||
|
||||
_ENDPOINT_KINDS = {"auto", "local", "api", "proxy"}
|
||||
_REFRESH_MODES = {"auto", "manual", "disabled"}
|
||||
|
||||
|
||||
def _normalize_endpoint_kind(value: Any) -> str:
|
||||
kind = str(value or "auto").strip().lower()
|
||||
return kind if kind in _ENDPOINT_KINDS else "auto"
|
||||
|
||||
|
||||
def _normalize_refresh_mode(value: Any, endpoint_kind: str = "auto") -> str:
    """Resolve the effective model-refresh mode for an endpoint kind.

    ``"manual"`` and ``"disabled"`` are returned as given (after stripping
    and lower-casing). Every other value, including ``"auto"``, blank and
    unknown strings, resolves to ``"manual"`` for proxy endpoints and to
    ``"auto"`` for all other kinds.
    """
    requested = str(value or "").strip().lower()
    is_proxy = _normalize_endpoint_kind(endpoint_kind) == "proxy"

    # Explicit opt-outs are honoured for every endpoint kind.
    if requested == "manual" or requested == "disabled":
        return requested

    # Proxies default to manual cached-first behavior. Normal local/API
    # endpoints keep automatic bounded refreshes.
    if is_proxy:
        return "manual"
    return "auto"
|
||||
|
||||
|
||||
def _endpoint_kind(ep: Any) -> str:
    """Return the normalized ``endpoint_kind`` attribute of ``ep``.

    A missing attribute is treated as ``None`` and handed to
    ``_normalize_endpoint_kind``.
    """
    stored_kind = getattr(ep, "endpoint_kind", None)
    return _normalize_endpoint_kind(stored_kind)
|
||||
|
||||
|
||||
def _endpoint_refresh_mode(ep: Any, endpoint_kind: str | None = None) -> str:
    """Return the normalized refresh mode of ``ep``.

    ``endpoint_kind`` overrides the kind stored on the endpoint when it is
    truthy; otherwise the kind is read from ``ep`` itself.
    """
    stored_mode = getattr(ep, "model_refresh_mode", None)
    kind = endpoint_kind if endpoint_kind else _endpoint_kind(ep)
    return _normalize_refresh_mode(stored_mode, kind)
|
||||
|
||||
|
||||
def _endpoint_refresh_interval(ep: Any, category: str) -> float:
|
||||
raw = getattr(ep, "model_refresh_interval", None)
|
||||
try:
|
||||
val = int(raw) if raw is not None else 0
|
||||
except Exception:
|
||||
val = 0
|
||||
if val > 0:
|
||||
return float(max(30, val))
|
||||
return 60.0 if category == "local" else 3600.0
|
||||
|
||||
|
||||
def _endpoint_refresh_timeout(ep: Any, category: str) -> float:
|
||||
raw = getattr(ep, "model_refresh_timeout", None)
|
||||
try:
|
||||
val = int(raw) if raw is not None else 0
|
||||
except Exception:
|
||||
val = 0
|
||||
if val > 0:
|
||||
return float(max(1, min(30, val)))
|
||||
return 2.5 if category == "local" else 2.0
|
||||
|
||||
|
||||
def _manual_refresh_timeout(ep: Any, category: str, requested: Any = None) -> float:
    """Timeout for explicit user-triggered model-list refreshes.

    Background refreshes stay short. A manual refresh is the one path where a
    large proxy may legitimately need 15-30s to aggregate its catalog.

    Resolution order:
      1. ``requested``, when it parses as an integer >= 1 (capped at 60).
      2. For the ``"local"`` category: the endpoint's stored
         ``model_refresh_timeout`` (same 1..60 parsing), else the background
         timeout from ``_endpoint_refresh_timeout``.
      3. For every other category: the stored value, but never below 30.
    """
    override = _parse_positive_int(requested, minimum=1, maximum=60)
    if override is not None:
        return float(override)

    configured = _parse_positive_int(
        getattr(ep, "model_refresh_timeout", None), minimum=1, maximum=60
    )
    if category != "local":
        # Remote catalogs get at least 30 seconds on a manual refresh.
        return float(max(configured or 30, 30))
    if configured is None:
        return _endpoint_refresh_timeout(ep, category)
    return float(configured)
|
||||
|
||||
|
||||
def _parse_model_list(raw: Any) -> List[str]:
|
||||
"""Return a sanitized list of model ids from JSON/list/comma text."""
|
||||
if raw is None:
|
||||
return []
|
||||
value = raw
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, list):
|
||||
value = parsed
|
||||
else:
|
||||
value = re.split(r"[\n,]+", text)
|
||||
except Exception:
|
||||
value = re.split(r"[\n,]+", text)
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
out = []
|
||||
seen = set()
|
||||
for item in value:
|
||||
mid = str(item or "").strip()
|
||||
if not mid or mid in seen:
|
||||
continue
|
||||
seen.add(mid)
|
||||
out.append(mid)
|
||||
return out
|
||||
|
||||
|
||||
def _parse_positive_int(raw: Any, *, minimum: int = 1, maximum: int = 86400) -> Optional[int]:
|
||||
try:
|
||||
val = int(str(raw).strip())
|
||||
except Exception:
|
||||
return None
|
||||
if val < minimum:
|
||||
return None
|
||||
return min(val, maximum)
|
||||
|
||||
|
||||
def _explicit_model_list_timeout(base_url: str, endpoint_kind: str = "auto", requested: Any = None) -> float:
    """Timeout for explicit user-triggered model-list fetches during setup.

    A ``requested`` value that parses as an integer >= 1 (capped at 60) wins.
    Otherwise endpoints that are declared ``api``/``proxy`` or that classify
    as ``"api"`` get 30 seconds, Ollama-looking bases get 3 seconds, and
    everything else gets 2 seconds.
    """
    override = _parse_positive_int(requested, minimum=1, maximum=60)
    if override is not None:
        return float(override)

    kind = _normalize_endpoint_kind(endpoint_kind)
    category = _classify_endpoint(base_url, kind)
    is_remote = kind in ("api", "proxy") or category == "api"
    if is_remote:
        return 30.0
    if _is_ollama_base(base_url):
        return 3.0
    return 2.0
|
||||
|
||||
|
||||
def _cached_model_ids(ep: Any) -> List[str]:
    """Return the model ids parsed from the ``cached_models`` attribute of ``ep``."""
    stored = getattr(ep, "cached_models", None)
    return _parse_model_list(stored)
|
||||
|
||||
|
||||
def _hidden_model_ids(ep: Any) -> set:
    """Return the set of model ids parsed from the ``hidden_models`` attribute of ``ep``."""
    stored = getattr(ep, "hidden_models", None)
    hidden_ids = _parse_model_list(stored)
    return set(hidden_ids)
|
||||
|
||||
|
||||
def _is_ollama_base(base_url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(base_url)
|
||||
host = (parsed.hostname or "").lower()
|
||||
return parsed.port == 11434 or "ollama" in host
|
||||
except Exception:
|
||||
return "ollama" in (base_url or "").lower()
|
||||
|
||||
|
||||
# Prefixes/substrings for models that are NOT chat-completions-capable
|
||||
_NON_CHAT_PREFIXES = (
|
||||
"dall-e", "tts-", "whisper", "text-embedding", "embedding",
|
||||
@@ -441,10 +576,15 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||
_TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.")
|
||||
|
||||
|
||||
def _classify_endpoint(base_url: str) -> str:
|
||||
def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
|
||||
"""Return 'local' if the endpoint URL points to a private/local address, else 'api'.
|
||||
Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted
|
||||
servers (e.g. Cookbook serve endpoints) get reachability-probed too."""
|
||||
kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
if kind == "local":
|
||||
return "local"
|
||||
if kind in ("api", "proxy"):
|
||||
return "api"
|
||||
try:
|
||||
host = urlparse(base_url).hostname or ""
|
||||
if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES):
|
||||
@@ -456,6 +596,21 @@ def _classify_endpoint(base_url: str) -> str:
|
||||
return "api"
|
||||
|
||||
|
||||
def _effective_endpoint_kind(ep: Any, base_url: str) -> str:
    """Return explicit kind, with a legacy proxy heuristic for keyed /v1 URLs.

    An endpoint whose stored kind is anything other than ``"auto"`` keeps it.
    Otherwise an endpoint that has an API key, does not look like Ollama, and
    whose URL path ends in ``/v1`` or contains ``/openai`` is reported as
    ``"proxy"``. Everything else stays ``"auto"``.
    """
    explicit = _endpoint_kind(ep)
    if explicit != "auto":
        return explicit

    has_key = getattr(ep, "api_key", None)
    if not has_key or _is_ollama_base(base_url):
        return "auto"

    try:
        url_path = (urlparse(base_url).path or "").rstrip("/")
        looks_proxied = url_path.endswith("/v1") or "/openai" in url_path
    except Exception:
        # Unparsable URL: leave the endpoint unclassified.
        looks_proxied = False
    return "proxy" if looks_proxied else "auto"
|
||||
|
||||
|
||||
|
||||
def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]:
|
||||
"""Probe a base URL's /models endpoint and return list of model IDs.
|
||||
@@ -546,30 +701,18 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
||||
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the
|
||||
# base is the host root or the native /api root (e.g. http://localhost:11434,
|
||||
# http://localhost:11434/api) because /models lives under /v1 there. Treat
|
||||
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
|
||||
# than as a definitive offline verdict — Ollama is reachable, it just
|
||||
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
|
||||
# marks an alive Ollama as offline whenever cached_models is empty (issue
|
||||
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
|
||||
# _ping_endpoint() was returning before that fallback could run.
|
||||
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
|
||||
# /models as a generic health check because large proxy catalogs can be slow.
|
||||
parsed_base = urlparse(base)
|
||||
looks_like_ollama = (
|
||||
parsed_base.port == 11434
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
url = base + "/models"
|
||||
last_error: Optional[str] = None
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout)
|
||||
def _result_from_response(r) -> Dict[str, Any]:
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
if loc.startswith("/login") or "/login" in loc:
|
||||
@@ -579,13 +722,15 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
if r.status_code < 500 and not looks_like_ollama:
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
if 200 <= r.status_code < 300:
|
||||
return {
|
||||
"reachable": True,
|
||||
"status_code": r.status_code,
|
||||
"error": None,
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
|
||||
|
||||
last_error: Optional[str] = None
|
||||
|
||||
try:
|
||||
if looks_like_ollama:
|
||||
@@ -597,14 +742,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
for path in ("/api/version", "/api/tags"):
|
||||
try:
|
||||
r = httpx.get(root + path, timeout=timeout)
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
result = _result_from_response(r)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
last_error = result.get("error")
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
r = httpx.get(base, headers=headers, timeout=timeout)
|
||||
return _result_from_response(r)
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
@@ -715,17 +867,71 @@ def setup_model_routes(model_discovery):
|
||||
flip)."""
|
||||
_models_cache.clear()
|
||||
|
||||
# Track endpoints that have failed recently so we back off probing dead ones.
|
||||
_probe_failures = {} # ep_id → (last_fail_ts, consecutive_fails)
|
||||
# Track model-list refreshes by URL+key. This prevents repeated picker/API
|
||||
# opens from starting duplicate /models probes, and gives slow/offline
|
||||
# providers a cooldown after failures.
|
||||
_refresh_state: Dict[str, Dict[str, Any]] = {}
|
||||
_refresh_inflight = {"v": False} # coarse single-flight guard
|
||||
_REFRESH_FAILURE_BASE = 300.0
|
||||
_REFRESH_FAILURE_MAX = 3600.0
|
||||
|
||||
def _refresh_caches_bg():
|
||||
"""Background thread: re-probe all endpoints in PARALLEL with a tight
|
||||
timeout, skipping endpoints that have been failing repeatedly.
|
||||
def _refresh_key(base: str, api_key: Optional[str]) -> str:
|
||||
return f"{base.rstrip('/')}\x00{api_key or ''}"
|
||||
|
||||
Was the cause of gradual server degradation: sequential 3s-timeout
|
||||
probes against many endpoints (some offline) tied up the threadpool
|
||||
for 15-30s every cache cycle, eventually exhausting it."""
|
||||
def _ts(value: Any) -> float:
|
||||
try:
|
||||
return float(value.timestamp()) if value else 0.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _failure_delay(fails: int) -> float:
|
||||
if fails <= 0:
|
||||
return 0.0
|
||||
return min(_REFRESH_FAILURE_BASE * (2 ** max(0, fails - 1)), _REFRESH_FAILURE_MAX)
|
||||
|
||||
def _should_refresh_endpoint(ep: Any, now: float, force: bool = False) -> tuple[bool, Dict[str, Any]]:
    """Decide whether the model list of ``ep`` is due for a refresh.

    Returns ``(should_refresh, info)``. ``info`` always carries the
    endpoint's id, normalized base URL, API key, kind, category, refresh
    mode, refresh-state key and probe timeout, whether or not a refresh is
    due.

    A refresh is refused when the endpoint has no base URL or a refresh for
    the same key is already in flight. Unless ``force`` is set it is also
    refused when the mode is manual/disabled, when the failure backoff window
    has not elapsed, or when a cached model list exists and the refresh
    interval has not elapsed since the last known good time.
    """
    api_key = getattr(ep, "api_key", None)
    base = _normalize_base(getattr(ep, "base_url", "") or "")
    kind = _effective_endpoint_kind(ep, base)
    category = _classify_endpoint(base, kind)
    mode = _endpoint_refresh_mode(ep, kind)
    cached = _cached_model_ids(ep)
    key = _refresh_key(base, api_key)
    state = _refresh_state.get(key, {})

    info = {
        "id": getattr(ep, "id", ""),
        "base": base,
        "api_key": api_key,
        "kind": kind,
        "category": category,
        "mode": mode,
        "key": key,
        "timeout": _endpoint_refresh_timeout(ep, category),
    }

    # Hard stops that apply even to forced refreshes.
    if not base or state.get("inflight"):
        return False, info
    if not force and mode in ("manual", "disabled"):
        return False, info

    failures = int(state.get("fail_count") or 0)
    if force:
        return True, info

    # Failure backoff: wait out the cooldown from the most recent failure.
    if failures:
        since_failure = now - float(state.get("last_failure") or 0.0)
        if since_failure < _failure_delay(failures):
            return False, info

    # TTL: with a cached list in hand, only refresh once the interval is up.
    if cached:
        interval = _endpoint_refresh_interval(ep, category)
        last_good = (
            float(state.get("last_success") or 0.0)
            or _ts(getattr(ep, "updated_at", None))
            or _ts(getattr(ep, "created_at", None))
        )
        if last_good and now - last_good < interval:
            return False, info

    return True, info
|
||||
|
||||
def _refresh_caches_bg(force: bool = False):
|
||||
"""Background thread: safely refresh model caches with per-base single-flight.
|
||||
|
||||
The public /api/models path stays cached-first. This refresh never clears
|
||||
a non-empty cached model list on timeout/failure, and proxy/manual
|
||||
endpoints are skipped unless explicitly forced."""
|
||||
import threading
|
||||
if _refresh_inflight["v"]:
|
||||
return # already running
|
||||
@@ -735,44 +941,63 @@ def setup_model_routes(model_discovery):
|
||||
try:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
db = SessionLocal()
|
||||
changed = False
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
# Skip endpoints that have failed 3+ times in a row in the last 5 min
|
||||
now = _time.time()
|
||||
to_probe = []
|
||||
groups: Dict[str, Dict[str, Any]] = {}
|
||||
for ep in endpoints:
|
||||
ts, fails = _probe_failures.get(ep.id, (0, 0))
|
||||
if fails >= 3 and (now - ts) < 300:
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
to_probe.append(ep)
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
"timeout": info["timeout"],
|
||||
"endpoint_ids": [],
|
||||
})["endpoint_ids"].append(info["id"])
|
||||
|
||||
def _probe_one(ep):
|
||||
base = _normalize_base(ep.base_url)
|
||||
for key in groups:
|
||||
st = _refresh_state.setdefault(key, {})
|
||||
st["inflight"] = True
|
||||
st["last_attempt"] = now
|
||||
|
||||
def _probe_one(key: str, data: Dict[str, Any]):
|
||||
try:
|
||||
ids = _probe_endpoint(base, ep.api_key, timeout=2)
|
||||
return ep, ids, None
|
||||
ids = _probe_endpoint(data["base"], data.get("api_key"), timeout=data.get("timeout") or 2)
|
||||
return key, data["endpoint_ids"], ids, None
|
||||
except Exception as e:
|
||||
return ep, None, e
|
||||
return key, data["endpoint_ids"], None, e
|
||||
|
||||
if to_probe:
|
||||
# Bounded parallelism — 8 concurrent probes is plenty
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(to_probe))) as pool:
|
||||
futures = [pool.submit(_probe_one, ep) for ep in to_probe]
|
||||
if groups:
|
||||
with ThreadPoolExecutor(max_workers=min(4, len(groups))) as pool:
|
||||
futures = [pool.submit(_probe_one, key, data) for key, data in groups.items()]
|
||||
for fut in as_completed(futures):
|
||||
ep, ids, err = fut.result()
|
||||
key, endpoint_ids, ids, err = fut.result()
|
||||
st = _refresh_state.setdefault(key, {})
|
||||
if ids:
|
||||
ep.cached_models = json.dumps(ids)
|
||||
_probe_failures.pop(ep.id, None)
|
||||
for ep_id in endpoint_ids:
|
||||
ep_obj = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep_obj:
|
||||
ep_obj.cached_models = json.dumps(ids)
|
||||
changed = True
|
||||
st["last_success"] = _time.time()
|
||||
st["fail_count"] = 0
|
||||
st.pop("last_failure", None)
|
||||
else:
|
||||
prev = _probe_failures.get(ep.id, (0, 0))
|
||||
_probe_failures[ep.id] = (_time.time(), prev[1] + 1)
|
||||
st["last_failure"] = _time.time()
|
||||
st["fail_count"] = int(st.get("fail_count") or 0) + 1
|
||||
st["inflight"] = False
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
_invalidate_models_cache()
|
||||
if changed:
|
||||
_invalidate_models_cache()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
for st in _refresh_state.values():
|
||||
st["inflight"] = False
|
||||
_refresh_inflight["v"] = False
|
||||
threading.Thread(target=_do, daemon=True).start()
|
||||
|
||||
@@ -804,24 +1029,15 @@ def setup_model_routes(model_discovery):
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
# Use cached models — background refresh keeps them updated
|
||||
model_ids = []
|
||||
if ep.cached_models:
|
||||
try:
|
||||
model_ids = json.loads(ep.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
model_ids = _cached_model_ids(ep)
|
||||
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
||||
# Filter out hidden (probe-failed) models
|
||||
hidden = set()
|
||||
if ep.hidden_models:
|
||||
try:
|
||||
hidden = set(json.loads(ep.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
hidden = _hidden_model_ids(ep)
|
||||
model_ids = [m for m in model_ids if m not in hidden]
|
||||
# Build correct URL based on provider
|
||||
chat_url = build_chat_url(base)
|
||||
category = _classify_endpoint(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
category = _classify_endpoint(base, kind)
|
||||
|
||||
if model_ids:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
@@ -837,6 +1053,7 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_id": ep.id,
|
||||
"endpoint_name": ep.name,
|
||||
"category": category,
|
||||
"endpoint_kind": kind,
|
||||
"model_type": ep_model_type,
|
||||
})
|
||||
else:
|
||||
@@ -852,6 +1069,7 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_id": ep.id,
|
||||
"endpoint_name": ep.name,
|
||||
"category": category,
|
||||
"endpoint_kind": kind,
|
||||
"model_type": ep_model_type,
|
||||
"offline": True,
|
||||
})
|
||||
@@ -898,11 +1116,11 @@ def setup_model_routes(model_discovery):
|
||||
result = _fetch_models(owner=owner, is_admin=_is_admin)
|
||||
_models_cache[_cache_key] = {"data": result, "time": now}
|
||||
# Kick off background refresh to update caches from live endpoints
|
||||
_refresh_caches_bg()
|
||||
_refresh_caches_bg(force=refresh)
|
||||
return result
|
||||
|
||||
# Brief cache for local-probe results so picker-open doesn't hammer
|
||||
# /v1/models every time. 8s TTL — long enough to amortize cost,
|
||||
# endpoint health checks every time. 8s TTL — long enough to amortize cost,
|
||||
# short enough that a freshly-killed local server shows as offline
|
||||
# within ~8s of the user noticing.
|
||||
_LOCAL_PROBE_TTL = 8.0
|
||||
@@ -912,7 +1130,7 @@ def setup_model_routes(model_discovery):
|
||||
async def probe_local_endpoints(request: Request):
|
||||
"""Fast parallel reachability check for LOCAL endpoints only.
|
||||
Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are
|
||||
assumed up. Local endpoints get a 1.5s /models probe so the UI
|
||||
assumed up. Local endpoints get a 1.5s cheap reachability probe so the UI
|
||||
can dim stale entries pointing at dead vLLM servers. Returns
|
||||
{ep_id: {alive, latency_ms, error}}."""
|
||||
require_admin(request)
|
||||
@@ -924,36 +1142,44 @@ def setup_model_routes(model_discovery):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
local_eps = [
|
||||
(ep.id, _normalize_base(ep.base_url), ep.api_key)
|
||||
for ep in endpoints
|
||||
if _classify_endpoint(_normalize_base(ep.base_url)) == "local"
|
||||
]
|
||||
local_eps = []
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
if _classify_endpoint(base, kind) == "local":
|
||||
local_eps.append((ep.id, base, ep.api_key))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]:
|
||||
grouped: Dict[str, Dict[str, Any]] = {}
|
||||
for ep_id, base, api_key in local_eps:
|
||||
key = _refresh_key(base, api_key)
|
||||
grouped.setdefault(key, {"base": base, "api_key": api_key, "endpoint_ids": []})["endpoint_ids"].append(ep_id)
|
||||
|
||||
async def _probe_one(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
t0 = _time.time()
|
||||
try:
|
||||
models = _probe_endpoint(base, api_key, timeout=2.5)
|
||||
import asyncio as _asyncio
|
||||
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5)
|
||||
lat = round((_time.time() - t0) * 1000)
|
||||
return {
|
||||
"alive": bool(models),
|
||||
"alive": bool(ping.get("reachable")),
|
||||
"latency_ms": lat,
|
||||
"status_code": 200 if models else None,
|
||||
"error": None if models else "No models found",
|
||||
"status_code": ping.get("status_code"),
|
||||
"error": ping.get("error"),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
|
||||
|
||||
import asyncio as _asyncio
|
||||
results_list = await _asyncio.gather(
|
||||
*[_probe_one(eid, base, key) for eid, base, key in local_eps],
|
||||
*[_probe_one(data) for data in grouped.values()],
|
||||
return_exceptions=False,
|
||||
)
|
||||
results: Dict[str, Any] = {}
|
||||
for (eid, _, _), r in zip(local_eps, results_list):
|
||||
results[eid] = r
|
||||
for data, r in zip(grouped.values(), results_list):
|
||||
for eid in data["endpoint_ids"]:
|
||||
results[eid] = r
|
||||
|
||||
_local_probe_cache["data"] = results
|
||||
_local_probe_cache["time"] = now
|
||||
@@ -973,50 +1199,28 @@ def setup_model_routes(model_discovery):
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
cached_count = len(_cached_model_ids(ep))
|
||||
entry = {
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": base,
|
||||
"provider": provider,
|
||||
"category": _classify_endpoint(base),
|
||||
"category": _classify_endpoint(base, kind),
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
if provider == "anthropic":
|
||||
# Anthropic has no /models endpoint; just check connectivity
|
||||
try:
|
||||
t0 = _time.time()
|
||||
r = httpx.get(base.rstrip("/"), timeout=5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online"
|
||||
entry["model_count"] = len(ANTHROPIC_MODELS)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = 0
|
||||
else:
|
||||
url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
t0 = _time.time()
|
||||
r = httpx.get(url, headers=headers, timeout=5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
entry["status"] = "online"
|
||||
entry["model_count"] = len(models)
|
||||
except Exception as e:
|
||||
if "latency_ms" not in entry:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = 0
|
||||
try:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = cached_count
|
||||
results.append(entry)
|
||||
|
||||
return {"endpoints": results}
|
||||
@@ -1165,19 +1369,8 @@ def setup_model_routes(model_discovery):
|
||||
rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all()
|
||||
results = []
|
||||
for r in rows:
|
||||
# Use cached model list to avoid slow probe on every load
|
||||
all_models = []
|
||||
if r.cached_models:
|
||||
try:
|
||||
all_models = json.loads(r.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
hidden = set()
|
||||
if r.hidden_models:
|
||||
try:
|
||||
hidden = set(json.loads(r.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
all_models = _cached_model_ids(r)
|
||||
hidden = _hidden_model_ids(r)
|
||||
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
|
||||
visible = _visible_models(all_models, r.hidden_models, pinned)
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
@@ -1188,6 +1381,8 @@ def setup_model_routes(model_discovery):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
@@ -1202,6 +1397,11 @@ def setup_model_routes(model_discovery):
|
||||
"ping_error": (ping or {}).get("error") if ping else None,
|
||||
"model_type": getattr(r, "model_type", None) or "llm",
|
||||
"supports_tools": getattr(r, "supports_tools", None),
|
||||
"endpoint_kind": kind,
|
||||
"category": _classify_endpoint(base, kind),
|
||||
"model_refresh_mode": _endpoint_refresh_mode(r, kind),
|
||||
"model_refresh_interval": getattr(r, "model_refresh_interval", None),
|
||||
"model_refresh_timeout": getattr(r, "model_refresh_timeout", None),
|
||||
})
|
||||
return results
|
||||
finally:
|
||||
@@ -1216,6 +1416,10 @@ def setup_model_routes(model_discovery):
|
||||
skip_probe: str = Form("false"),
|
||||
require_models: str = Form("false"),
|
||||
model_type: str = Form("llm"),
|
||||
endpoint_kind: str = Form("auto"),
|
||||
model_refresh_mode: str = Form(""),
|
||||
model_refresh_interval: str = Form(""),
|
||||
model_refresh_timeout: str = Form(""),
|
||||
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
||||
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
|
||||
container_local: str = Form("false"),
|
||||
@@ -1240,8 +1444,15 @@ def setup_model_routes(model_discovery):
|
||||
if not name.strip():
|
||||
name = base_url.replace("http://", "").replace("https://", "").split("/")[0]
|
||||
|
||||
requested_kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
refresh_mode = _normalize_refresh_mode(model_refresh_mode, requested_kind)
|
||||
refresh_interval = _parse_positive_int(model_refresh_interval, minimum=30, maximum=86400)
|
||||
refresh_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
|
||||
require_model_list = _truthy(require_models)
|
||||
should_probe = require_model_list or not _truthy(skip_probe)
|
||||
should_probe = (
|
||||
require_model_list or requested_kind in ("api", "proxy") or not _truthy(skip_probe)
|
||||
)
|
||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||
|
||||
# Dedupe: if an endpoint with the same base_url already exists and
|
||||
# is reachable by the caller (shared or owned by them), return it
|
||||
@@ -1259,6 +1470,7 @@ def setup_model_routes(model_discovery):
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
changed = False
|
||||
# Persist any incoming pinned IDs onto the existing row. An
|
||||
# empty/omitted form field must not wipe previously pinned IDs.
|
||||
_incoming_pinned = _normalize_model_ids(pinned_models)
|
||||
@@ -1268,15 +1480,45 @@ def setup_model_routes(model_discovery):
|
||||
_incoming_pinned,
|
||||
)
|
||||
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
|
||||
changed = True
|
||||
existing_kind_for_probe = requested_kind if requested_kind != "auto" else _effective_endpoint_kind(existing, base_url)
|
||||
if requested_kind != "auto" and _endpoint_kind(existing) == "auto":
|
||||
existing.endpoint_kind = requested_kind
|
||||
changed = True
|
||||
if model_refresh_mode or (requested_kind == "proxy" and _endpoint_refresh_mode(existing, requested_kind) != refresh_mode):
|
||||
existing.model_refresh_mode = refresh_mode
|
||||
changed = True
|
||||
if refresh_interval is not None:
|
||||
existing.model_refresh_interval = refresh_interval
|
||||
changed = True
|
||||
if refresh_timeout is not None:
|
||||
existing.model_refresh_timeout = refresh_timeout
|
||||
changed = True
|
||||
if api_key.strip() and not existing.api_key:
|
||||
existing.api_key = api_key.strip()
|
||||
changed = True
|
||||
if should_probe:
|
||||
probed_models = _probe_endpoint(
|
||||
base_url,
|
||||
(api_key.strip() or existing.api_key or None),
|
||||
timeout=_explicit_model_list_timeout(base_url, existing_kind_for_probe, refresh_timeout),
|
||||
)
|
||||
if probed_models:
|
||||
existing.cached_models = json.dumps(probed_models)
|
||||
changed = True
|
||||
if changed:
|
||||
_db_dedup.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
existing_models = _cached_model_ids(existing)
|
||||
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
|
||||
existing_kind = _effective_endpoint_kind(existing, existing.base_url)
|
||||
return {
|
||||
"id": existing.id,
|
||||
"name": existing.name,
|
||||
"base_url": existing.base_url,
|
||||
"models": _visible_models(
|
||||
getattr(existing, "cached_models", None),
|
||||
existing_models,
|
||||
getattr(existing, "hidden_models", None),
|
||||
existing.pinned_models,
|
||||
),
|
||||
@@ -1284,16 +1526,16 @@ def setup_model_routes(model_discovery):
|
||||
"online": True,
|
||||
"status": "online",
|
||||
"existing": True,
|
||||
"endpoint_kind": existing_kind,
|
||||
"category": _classify_endpoint(existing.base_url, existing_kind),
|
||||
}
|
||||
finally:
|
||||
_db_dedup.close()
|
||||
|
||||
# Model list fetch using the timeout resolved for this endpoint kind — if the endpoint is slow, it'll update on next refresh
|
||||
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=explicit_timeout) if should_probe else []
|
||||
ping = {"reachable": False, "error": None}
|
||||
if should_probe and not model_ids:
|
||||
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout)
|
||||
if (should_probe or requested_kind in ("api", "proxy")) and not model_ids:
|
||||
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=min(explicit_timeout, 2.0))
|
||||
if require_model_list and not model_ids:
|
||||
raise HTTPException(400, _model_endpoint_error_message(base_url, ping))
|
||||
|
||||
@@ -1317,6 +1559,10 @@ def setup_model_routes(model_discovery):
|
||||
api_key=api_key.strip() or None,
|
||||
is_enabled=True,
|
||||
model_type=model_type.strip() if model_type else "llm",
|
||||
endpoint_kind=requested_kind,
|
||||
model_refresh_mode=refresh_mode,
|
||||
model_refresh_interval=refresh_interval,
|
||||
model_refresh_timeout=refresh_timeout,
|
||||
cached_models=json.dumps(model_ids) if model_ids else None,
|
||||
pinned_models=json.dumps(_pinned) if _pinned else None,
|
||||
supports_tools=_st,
|
||||
@@ -1349,6 +1595,8 @@ def setup_model_routes(model_discovery):
|
||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
"endpoint_kind": requested_kind,
|
||||
"category": _classify_endpoint(base_url, requested_kind),
|
||||
}
|
||||
|
||||
@router.post("/model-endpoints/test")
|
||||
@@ -1356,6 +1604,8 @@ def setup_model_routes(model_discovery):
|
||||
request: Request,
|
||||
base_url: str = Form(...),
|
||||
api_key: str = Form(""),
|
||||
endpoint_kind: str = Form("auto"),
|
||||
model_refresh_timeout: str = Form(""),
|
||||
):
|
||||
require_admin(request)
|
||||
base_url = _normalize_base(base_url)
|
||||
@@ -1364,9 +1614,11 @@ def setup_model_routes(model_discovery):
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base_url = resolve_url(base_url)
|
||||
base_url = _rewrite_loopback_for_docker(base_url)
|
||||
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
|
||||
requested_kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
configured_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
|
||||
probe_timeout = _explicit_model_list_timeout(base_url, requested_kind, configured_timeout)
|
||||
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=min(probe_timeout, 2.0))
|
||||
return {
|
||||
"base_url": base_url,
|
||||
"online": bool(models) or bool(ping.get("reachable")),
|
||||
@@ -1374,6 +1626,8 @@ def setup_model_routes(model_discovery):
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
"endpoint_kind": requested_kind,
|
||||
"category": _classify_endpoint(base_url, requested_kind),
|
||||
}
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/probe")
|
||||
@@ -1415,7 +1669,8 @@ def setup_model_routes(model_discovery):
|
||||
ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep_obj:
|
||||
ep_obj.hidden_models = json.dumps(failed) if failed else None
|
||||
ep_obj.cached_models = json.dumps(all_models) if all_models else None
|
||||
if all_models:
|
||||
ep_obj.cached_models = json.dumps(all_models)
|
||||
db2.commit()
|
||||
finally:
|
||||
db2.close()
|
||||
@@ -1426,7 +1681,13 @@ def setup_model_routes(model_discovery):
|
||||
return StreamingResponse(_stream(), media_type="text/event-stream")
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/models")
|
||||
def list_endpoint_models(ep_id: str, request: Request):
|
||||
def list_endpoint_models(
|
||||
ep_id: str,
|
||||
request: Request,
|
||||
response: Response,
|
||||
refresh: bool = False,
|
||||
refresh_timeout: Optional[int] = Query(None, ge=1, le=60),
|
||||
):
|
||||
"""List all discovered models for an endpoint with hidden/visible state."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
@@ -1434,23 +1695,28 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
hidden = set()
|
||||
if ep.hidden_models:
|
||||
hidden = _hidden_model_ids(ep)
|
||||
all_models = _cached_model_ids(ep)
|
||||
if refresh:
|
||||
base = _normalize_base(ep.base_url)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
hidden = set(json.loads(ep.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
# Try live probe, fall back to cached. Pinned IDs are admin-entered
|
||||
# and persist regardless of probe results — never overwritten here.
|
||||
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
|
||||
if all_models:
|
||||
ep.cached_models = json.dumps(all_models)
|
||||
db.commit()
|
||||
elif ep.cached_models:
|
||||
try:
|
||||
all_models = json.loads(ep.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
if probed:
|
||||
all_models = probed
|
||||
ep.cached_models = json.dumps(all_models)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
response.headers["X-Model-Refresh-Status"] = "refreshed"
|
||||
response.headers["X-Model-Refresh-Count"] = str(len(probed))
|
||||
else:
|
||||
response.headers["X-Model-Refresh-Status"] = "failed"
|
||||
response.headers["X-Model-Refresh-Warning"] = "Model refresh failed or returned no models; kept cached models."
|
||||
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
|
||||
pinned_set = set(pinned)
|
||||
return [
|
||||
@@ -1502,7 +1768,6 @@ def setup_model_routes(model_discovery):
|
||||
|
||||
@router.get("/default-chat")
|
||||
def get_default_chat(request: Request):
|
||||
import json as _json
|
||||
# SECURITY: resolve the default endpoint + model from the CALLER's
|
||||
# per-user prefs ONLY. We deliberately do NOT fall back to the
|
||||
# global `default_model` / `default_endpoint_id` in settings.json
|
||||
@@ -1635,6 +1900,16 @@ def setup_model_routes(model_discovery):
|
||||
if "pinned_models" in body:
|
||||
_pinned = _normalize_model_ids(body["pinned_models"])
|
||||
ep.pinned_models = json.dumps(_pinned) if _pinned else None
|
||||
if "endpoint_kind" in body:
|
||||
ep.endpoint_kind = _normalize_endpoint_kind(body.get("endpoint_kind"))
|
||||
if "model_refresh_mode" in body:
|
||||
ep.model_refresh_mode = _normalize_refresh_mode(body.get("model_refresh_mode"), _endpoint_kind(ep))
|
||||
if "model_refresh_interval" in body:
|
||||
interval = _parse_positive_int(body.get("model_refresh_interval"), minimum=30, maximum=86400)
|
||||
ep.model_refresh_interval = interval
|
||||
if "model_refresh_timeout" in body:
|
||||
timeout = _parse_positive_int(body.get("model_refresh_timeout"), minimum=1, maximum=60)
|
||||
ep.model_refresh_timeout = timeout
|
||||
# Rotating an API key used to require DELETE+POST, which wiped
|
||||
# endpoint_url/model from every session referencing the old base
|
||||
# URL. Allow in-place updates so the admin can change the key
|
||||
@@ -1664,6 +1939,10 @@ def setup_model_routes(model_discovery):
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||
"model_refresh_interval": getattr(ep, "model_refresh_interval", None),
|
||||
"model_refresh_timeout": getattr(ep, "model_refresh_timeout", None),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -743,8 +743,74 @@ def _normalize_anthropic_url(url: str) -> str:
|
||||
return url + "/messages"
|
||||
return url + "/v1/messages"
|
||||
|
||||
|
||||
def _model_list_base(url: str) -> str:
|
||||
"""Normalize model/chat URLs to the configured endpoint base."""
|
||||
base = (url or "").strip().rstrip("/")
|
||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
||||
if base.endswith(suffix):
|
||||
base = base[: -len(suffix)].rstrip("/")
|
||||
for suffix in ("/chat", "/tags", "/generate"):
|
||||
if base.endswith("/api" + suffix):
|
||||
base = base[: -len(suffix)].rstrip("/")
|
||||
return base
|
||||
|
||||
|
||||
def _parse_model_cache(raw) -> List[str]:
|
||||
if not raw:
|
||||
return []
|
||||
try:
|
||||
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(models, list):
|
||||
return []
|
||||
out = []
|
||||
seen = set()
|
||||
for item in models:
|
||||
mid = str(item or "").strip()
|
||||
if not mid or mid in seen:
|
||||
continue
|
||||
out.append(mid)
|
||||
seen.add(mid)
|
||||
return out
|
||||
|
||||
|
||||
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
    """Return cached, non-hidden model IDs for the configured endpoint matching ``endpoint_url``.

    The URL is normalized with ``_model_list_base`` and compared against the
    normalized ``base_url`` of every enabled ``ModelEndpoint`` row. The first
    matching row that has a non-empty cached model list wins; its hidden
    model IDs are filtered out of the result.

    This helper is best-effort and never raises: it returns ``[]`` when the
    URL normalizes to nothing, when ``src.database`` cannot be imported, when
    no matching row has cached models, or when any database error occurs
    (including failure to open the session).
    """
    target = _model_list_base(endpoint_url)
    if not target:
        return []
    try:
        from src.database import SessionLocal, ModelEndpoint
    except Exception:
        return []

    db = None
    try:
        # Session creation lives inside the guarded region so a failure here
        # degrades to "no cache" like every other error path.
        db = SessionLocal()
        rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()  # noqa: E712
        for ep in rows:
            if _model_list_base(getattr(ep, "base_url", "")) != target:
                continue
            # Prefer cached_models; fall back to a "models" attribute if the row has one.
            models = _parse_model_cache(getattr(ep, "cached_models", None) or getattr(ep, "models", None))
            if not models:
                continue
            hidden = set(_parse_model_cache(getattr(ep, "hidden_models", None)))
            return [m for m in models if m not in hidden]
    except Exception:
        return []
    finally:
        if db is not None:
            try:
                db.close()
            except Exception:
                # Closing is best-effort; a close failure must not mask the result.
                pass
    return []
|
||||
|
||||
|
||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
||||
"""List available model IDs from an endpoint."""
|
||||
cached = _configured_cached_model_ids(base_chat_url)
|
||||
if cached:
|
||||
return cached
|
||||
provider = _detect_provider(base_chat_url)
|
||||
if provider == "anthropic":
|
||||
return list(ANTHROPIC_MODELS)
|
||||
|
||||
@@ -6,6 +6,7 @@ Provides token estimation for context usage tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from urllib.parse import urlparse
|
||||
@@ -21,8 +22,55 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||
"172.30.", "172.31.", "192.168.", "100.")
|
||||
|
||||
|
||||
def _normalize_base_for_compare(url: str) -> str:
|
||||
url = (url or "").strip().rstrip("/")
|
||||
for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"):
|
||||
if url.endswith(suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
return url
|
||||
|
||||
|
||||
def _configured_endpoint_kind(url: str) -> Optional[str]:
|
||||
"""Return the configured endpoint kind for a chat/base URL when available.

The URL is normalized with ``_normalize_base_for_compare`` and matched against
the normalized ``base_url`` of enabled ``ModelEndpoint`` rows. A row matches
when the target equals its base or starts with ``base + "/"``.

Returns:
    "local", "api" or "proxy" when a matching row stores one of those kinds;
    "proxy" when a matching row has an API key and its base URL is not on
    port 11434, has no "ollama" in the host, and its path ends with "/v1" or
    contains "/openai"; "auto" via the final fallback return; ``None`` when
    the URL normalizes to nothing, when ``core.database`` is not already in
    ``sys.modules``, or when any exception is raised during the lookup.
"""
|
||||
target = _normalize_base_for_compare(url)
|
||||
if not target:
|
||||
return None
|
||||
# Only consult the database when core.database is already loaded; the import
# below then resolves to the existing module rather than loading it afresh.
if "core.database" not in sys.modules:
|
||||
return None
|
||||
try:
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in rows:
|
||||
base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "")
|
||||
if not base:
|
||||
continue
|
||||
# Skip rows unless the target is this base or a sub-path of it.
if target != base and not target.startswith(base + "/"):
|
||||
continue
|
||||
kind = (getattr(ep, "endpoint_kind", None) or "auto").strip().lower()
|
||||
if kind in ("local", "api", "proxy"):
|
||||
return kind
|
||||
# No explicit local/api/proxy kind stored: apply a URL heuristic, but only
# for rows that carry an API key.
if getattr(ep, "api_key", None):
|
||||
parsed = urlparse(base)
|
||||
host = (parsed.hostname or "").lower()
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path):
|
||||
return "proxy"
|
||||
# NOTE(review): indentation is not recoverable from this listing, so it is
# unclear whether this return sits inside the row loop (first matching row
# wins) or after it (no row classified) — confirm against the real file.
return "auto"
|
||||
finally:
|
||||
db.close()
|
||||
# Best-effort lookup: any failure is reported as "kind unknown".
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _is_local_endpoint(url: str) -> bool:
|
||||
"""Check if URL points to a local/private/tailscale address."""
|
||||
kind = _configured_endpoint_kind(url)
|
||||
if kind in ("api", "proxy"):
|
||||
return False
|
||||
if kind == "local":
|
||||
return True
|
||||
try:
|
||||
host = urlparse(url).hostname or ""
|
||||
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
|
||||
@@ -170,6 +218,7 @@ def get_context_length(endpoint_url: str, model: str) -> int:
|
||||
or context_window fields. Caches result per model ID.
|
||||
Falls back to DEFAULT_CONTEXT if unavailable.
|
||||
"""
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
is_local = _is_local_endpoint(endpoint_url)
|
||||
if not is_local and model in _context_cache:
|
||||
return _context_cache[model]
|
||||
@@ -178,7 +227,7 @@ def get_context_length(endpoint_url: str, model: str) -> int:
|
||||
# Only cache non-default values to allow retry on next request.
|
||||
# Local endpoints can restart with a different --max-model-len while keeping
|
||||
# the same model id, so always re-query them instead of serving stale cache.
|
||||
if not is_local and ctx != DEFAULT_CONTEXT:
|
||||
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
||||
_context_cache[model] = ctx
|
||||
logger.info(f"Context length for {model}: {ctx}")
|
||||
return ctx
|
||||
@@ -207,6 +256,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
"""Query the model API for context length."""
|
||||
known = _lookup_known(model)
|
||||
api_ctx = None
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
|
||||
# Large OpenAI-compatible proxies can make /models expensive. If the
|
||||
# endpoint is explicitly configured as API/proxy, prefer known context
|
||||
# metadata (or the default) over downloading the full catalog.
|
||||
if configured_kind in ("api", "proxy"):
|
||||
if known:
|
||||
logger.info(f"Using known context window for {model}: {known}")
|
||||
return known
|
||||
return DEFAULT_CONTEXT
|
||||
|
||||
# Try llama.cpp /slots endpoint first — reports actual serving context
|
||||
if _is_local_endpoint(endpoint_url):
|
||||
|
||||
@@ -2079,6 +2079,10 @@
|
||||
</select>
|
||||
<div class="admin-model-form-row">
|
||||
<input id="adm-epApiKey" type="password" placeholder="API key">
|
||||
<select id="adm-epKind" style="padding:5px;width:82px;">
|
||||
<option value="proxy">Proxy</option>
|
||||
<option value="api">API</option>
|
||||
</select>
|
||||
<select id="adm-epType" style="padding:5px;width:80px;">
|
||||
<option value="llm">LLM</option>
|
||||
<option value="image">Image</option>
|
||||
|
||||
@@ -371,7 +371,7 @@ async function loadEndpoints() {
|
||||
const listLegacy = el('adm-epList');
|
||||
// Refresh model picker so new endpoints show up in chat
|
||||
if (window.modelsModule && window.modelsModule.refreshModels) {
|
||||
window.modelsModule.refreshModels(true);
|
||||
window.modelsModule.refreshModels();
|
||||
setTimeout(() => {
|
||||
if (window.sessionModule && window.sessionModule.updateModelPicker) {
|
||||
window.sessionModule.updateModelPicker();
|
||||
@@ -411,12 +411,15 @@ async function loadEndpoints() {
|
||||
? `<span class="admin-badge">${visibleCount}/${totalCount} models enabled</span>`
|
||||
: '<span class="admin-badge admin-badge-off">offline</span>';
|
||||
const justAddedClass = (_recentlyAddedEpId && String(ep.id) === _recentlyAddedEpId) ? ' adm-ep-just-added' : '';
|
||||
const category = ep.category || (_isLocalEndpoint(ep.base_url) ? 'local' : 'api');
|
||||
const kindLabel = ep.endpoint_kind && ep.endpoint_kind !== 'auto' ? ep.endpoint_kind.toUpperCase() : '';
|
||||
return `
|
||||
<div class="admin-user-row${ep.is_enabled ? '' : ' admin-ep-disabled'}${justAddedClass}" data-adm-ep-id="${ep.id}">
|
||||
<div style="display:flex;align-items:center;justify-content:space-between;${hasModels ? 'cursor:pointer;' : ''}padding:4px 0;" data-adm-ep-header="${ep.id}">
|
||||
<div class="admin-user-info" style="flex:1;flex-wrap:wrap;gap:0.3rem;">
|
||||
<span class="admin-user-name">${esc(ep.name)}</span>
|
||||
${ep.model_type === 'image' ? '<span class="admin-badge" style="background:color-mix(in srgb, var(--accent) 20%, transparent);color:var(--accent);">Image</span>' : ''}
|
||||
${kindLabel ? `<span class="admin-badge">${esc(kindLabel)}</span>` : ''}
|
||||
${statusBadge}
|
||||
${ep.is_enabled ? '' : '<span class="admin-badge admin-badge-off">disabled</span>'}
|
||||
${hasModels ? '<span style="font-size:10px;opacity:0.4;">Click to manage models</span>' : ''}
|
||||
@@ -427,7 +430,7 @@ async function loadEndpoints() {
|
||||
${hasModels ? '<svg class="admin-user-chevron" width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" style="opacity:0.3;transition:transform 0.2s,opacity 0.2s;"><polyline points="6 9 12 15 18 9"/></svg>' : ''}
|
||||
</div>
|
||||
</div>
|
||||
<div class="admin-ep-detail">${esc(ep.base_url)}${_isLocalEndpoint(ep.base_url) ? `<button type="button" class="admin-ep-copy-btn" data-adm-copy-url="${esc(ep.base_url)}" title="Copy URL" aria-label="Copy URL" style="background:none;border:none;padding:0 2px;margin-left:6px;cursor:pointer;color:inherit;opacity:0.45;vertical-align:-2px;line-height:1;"><svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg></button>` : ''}${ep.has_key ? ' (key set)' : ''}</div>
|
||||
<div class="admin-ep-detail">${esc(ep.base_url)}${category === 'local' ? `<button type="button" class="admin-ep-copy-btn" data-adm-copy-url="${esc(ep.base_url)}" title="Copy URL" aria-label="Copy URL" style="background:none;border:none;padding:0 2px;margin-left:6px;cursor:pointer;color:inherit;opacity:0.45;vertical-align:-2px;line-height:1;"><svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg></button>` : ''}${ep.has_key ? ' (key set)' : ''}</div>
|
||||
${hasModels ? `<div class="mcp-tools-panel hidden" data-adm-ep-models-panel="${ep.id}"></div>` : ''}
|
||||
</div>`;
|
||||
});
|
||||
@@ -446,7 +449,7 @@ async function loadEndpoints() {
|
||||
container.innerHTML = indices.map(i => rowHtml[i]).join('');
|
||||
};
|
||||
const localIdx = [], apiIdx = [];
|
||||
data.forEach((ep, i) => (_isLocalEndpoint(ep.base_url) ? localIdx : apiIdx).push(i));
|
||||
data.forEach((ep, i) => ((ep.category || (_isLocalEndpoint(ep.base_url) ? 'local' : 'api')) === 'local' ? localIdx : apiIdx).push(i));
|
||||
// Sort each section: enabled endpoints first, disabled at the bottom.
|
||||
// Preserve original order within each group via stable sort.
|
||||
const _sortByEnabled = (a, b) => Number(!!data[b].is_enabled) - Number(!!data[a].is_enabled);
|
||||
@@ -552,22 +555,48 @@ async function loadEndpoints() {
|
||||
} catch (_) {}
|
||||
panel.appendChild(_ld);
|
||||
const _stopSpin = () => { try { _modelsSpin && _modelsSpin.stop(); } catch (_) {} };
|
||||
try {
|
||||
const res = await fetch(`/api/model-endpoints/${epId}/models`, { credentials: 'same-origin' });
|
||||
const models = await res.json();
|
||||
_stopSpin();
|
||||
const _loadingHtml = (label) => `<span style="opacity:0.55;font-size:11px;display:inline-flex;align-items:center;gap:8px;">${esc(label)}</span>`;
|
||||
const renderModels = (models, warning = '') => {
|
||||
const sortedModels = sortModelObjects(models);
|
||||
if (!sortedModels.length) { panel.innerHTML = '<span style="opacity:0.5;font-size:11px;">No models</span>'; return; }
|
||||
const warningHtml = warning ? `<div class="admin-error" style="font-size:11px;margin:6px 0;">${esc(warning)}</div>` : '';
|
||||
const attachRefresh = () => {
    // Wire the panel's "Refresh" link (if rendered) to a forced model refresh.
    const refreshLink = panel.querySelector(`[data-ep-refresh-models="${epId}"]`);
    if (!refreshLink) return;

    refreshLink.addEventListener('click', async (evt) => {
        evt.preventDefault();
        panel.innerHTML = _loadingHtml('Refreshing models...');
        try {
            const refreshUrl = `/api/model-endpoints/${epId}/models?refresh=true&refresh_timeout=60`;
            const response = await fetch(refreshUrl, { credentials: 'same-origin' });
            // The server reports a failed or empty refresh through this header
            // while still returning the cached list in the body.
            const warningText = response.headers.get('X-Model-Refresh-Warning') || '';
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
            const freshModels = await response.json();
            renderModels(freshModels, warningText);
            if (warningText && uiModule?.showToast) {
                uiModule.showToast(warningText, 6000);
            }
        } catch (_) {
            // Request or parse failure: re-render what was already shown.
            renderModels(sortedModels, 'Model refresh failed; kept cached models.');
        }
    });
};
|
||||
if (!sortedModels.length) {
|
||||
panel.innerHTML = `<div class="mcp-tools-header">
|
||||
<span>Models</span>
|
||||
<span style="display:flex;gap:8px;align-items:center;">
|
||||
<span class="mcp-tools-count">0/0 enabled</span>
|
||||
<a href="#" data-ep-refresh-models="${epId}">Refresh</a>
|
||||
</span>
|
||||
</div>${warningHtml}<span style="opacity:0.5;font-size:11px;">No models</span>`;
|
||||
attachRefresh();
|
||||
return;
|
||||
}
|
||||
const hiddenSet = new Set(sortedModels.filter(m => m.is_hidden).map(m => m.id));
|
||||
const showSearch = sortedModels.length >= 8;
|
||||
panel.innerHTML = `<div class="mcp-tools-header">
|
||||
<span>Models</span>
|
||||
<span style="display:flex;gap:8px;align-items:center;">
|
||||
<span class="mcp-tools-count">${sortedModels.length - hiddenSet.size}/${sortedModels.length} enabled</span>
|
||||
<a href="#" data-ep-refresh-models="${epId}">Refresh</a>
|
||||
<a href="#" data-ep-select-all="${epId}">All</a>
|
||||
<a href="#" data-ep-select-none="${epId}">None</a>
|
||||
</span>
|
||||
</div>${showSearch ? `<input type="search" class="mcp-tools-search" placeholder="Search ${sortedModels.length} models..." data-ep-search="${epId}">` : ''}<div class="mcp-tools-list">` + sortedModels.map(m =>
|
||||
</div>${warningHtml}${showSearch ? `<input type="search" class="mcp-tools-search" placeholder="Search ${sortedModels.length} models..." data-ep-search="${epId}">` : ''}<div class="mcp-tools-list">` + sortedModels.map(m =>
|
||||
`<label title="${esc(m.id)}" data-ep-model-row data-search="${esc((m.display + ' ' + m.id).toLowerCase())}" class="adm-model-row">
|
||||
<input type="checkbox" class="adm-cb-hidden" data-ep-model-id="${esc(m.id)}" ${!m.is_hidden ? 'checked' : ''}>
|
||||
<span class="adm-check-dot" aria-hidden="true"></span>
|
||||
@@ -580,6 +609,7 @@ async function loadEndpoints() {
|
||||
row.style.display = (!needle || row.dataset.search.includes(needle)) ? '' : 'none';
|
||||
});
|
||||
};
|
||||
attachRefresh();
|
||||
panel.querySelector(`[data-ep-search="${epId}"]`)?.addEventListener('input', (e) => filterRows(e.target.value));
|
||||
panel.querySelector(`[data-ep-select-all="${epId}"]`)?.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
@@ -598,6 +628,13 @@ async function loadEndpoints() {
|
||||
panel.querySelectorAll('input[type=checkbox]').forEach(cb => {
|
||||
cb.addEventListener('change', () => _saveEpModelState(epId, panel));
|
||||
});
|
||||
};
|
||||
try {
|
||||
const res = await fetch(`/api/model-endpoints/${epId}/models`, { credentials: 'same-origin' });
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||
const models = await res.json();
|
||||
_stopSpin();
|
||||
renderModels(models);
|
||||
} catch (e) { _stopSpin(); panel.innerHTML = '<span class="admin-error" style="font-size:11px;">Failed to load models</span>'; }
|
||||
}
|
||||
});
|
||||
@@ -637,6 +674,7 @@ async function _saveEpModelState(epId, panel) {
|
||||
function initEndpointForm() {
|
||||
const provider = el('adm-epProvider');
|
||||
const urlInput = el('adm-epUrl');
|
||||
const kindSel = el('adm-epKind');
|
||||
|
||||
// Custom provider picker — mirrors the (now hidden) <select id="adm-epProvider">
|
||||
// so the rest of this function (which reads provider.value and dispatches
|
||||
@@ -688,14 +726,20 @@ function initEndpointForm() {
|
||||
provider.addEventListener('change', () => {
|
||||
if (provider.value) urlInput.value = provider.value;
|
||||
else urlInput.value = '';
|
||||
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
|
||||
});
|
||||
urlInput.addEventListener('input', () => {
|
||||
if (provider.value && urlInput.value.trim() !== provider.value) {
|
||||
provider.value = '';
|
||||
if (kindSel) kindSel.value = 'proxy';
|
||||
_renderPickerMenu();
|
||||
_syncPickerCurrent();
|
||||
}
|
||||
});
|
||||
if (kindSel) kindSel.value = provider.value ? 'api' : (kindSel.value || 'proxy');
|
||||
function _apiEndpointKind() {
    // An explicit choice in the kind selector wins. When that control is
    // missing or has no value, infer the kind: a non-empty provider value
    // means 'api', otherwise 'proxy'.
    if (kindSel && kindSel.value) {
        return kindSel.value;
    }
    if (provider.value) {
        return 'api';
    }
    return 'proxy';
}
|
||||
function _normalizeBaseUrl(raw) {
|
||||
let u = raw.trim();
|
||||
// Fix common protocol typos
|
||||
@@ -784,6 +828,8 @@ function initEndpointForm() {
|
||||
try {
|
||||
const fd = new FormData();
|
||||
fd.append('base_url', url);
|
||||
fd.append('endpoint_kind', _apiEndpointKind());
|
||||
fd.append('model_refresh_timeout', '30');
|
||||
if (apiKey) fd.append('api_key', apiKey);
|
||||
const res = await fetch('/api/model-endpoints/test', {
|
||||
method: 'POST',
|
||||
@@ -828,6 +874,10 @@ function initEndpointForm() {
|
||||
try {
|
||||
const fd = new FormData();
|
||||
fd.append('base_url', url);
|
||||
const endpointKind = _apiEndpointKind();
|
||||
fd.append('endpoint_kind', endpointKind);
|
||||
fd.append('model_refresh_mode', endpointKind === 'proxy' ? 'manual' : 'auto');
|
||||
fd.append('model_refresh_timeout', '30');
|
||||
if (apiKey) fd.append('api_key', apiKey);
|
||||
if (provider.value && provider.selectedOptions && provider.selectedOptions[0]) {
|
||||
fd.append('name', provider.selectedOptions[0].textContent.trim());
|
||||
@@ -842,6 +892,7 @@ function initEndpointForm() {
|
||||
const count = d.models ? d.models.length : 0;
|
||||
urlInput.value = ''; urlInput.style.display = '';
|
||||
el('adm-epApiKey').value = ''; provider.value = '';
|
||||
if (kindSel) kindSel.value = 'proxy';
|
||||
if (epType) epType.value = 'llm';
|
||||
if (d.id) _recentlyAddedEpId = String(d.id);
|
||||
await loadEndpoints();
|
||||
@@ -904,6 +955,8 @@ function initEndpointForm() {
|
||||
const fd = new FormData();
|
||||
fd.append('base_url', url);
|
||||
if (apiKey) fd.append('api_key', apiKey);
|
||||
fd.append('endpoint_kind', 'local');
|
||||
fd.append('model_refresh_mode', 'auto');
|
||||
const lt = el('adm-epLocalType');
|
||||
if (lt) fd.append('model_type', lt.value);
|
||||
fd.append('skip_probe', 'false');
|
||||
@@ -986,6 +1039,8 @@ function initEndpointForm() {
|
||||
const base = item.url.replace('/chat/completions', '').replace(/\/$/, '');
|
||||
const fd = new FormData();
|
||||
fd.append('base_url', base);
|
||||
fd.append('endpoint_kind', 'local');
|
||||
fd.append('model_refresh_mode', 'auto');
|
||||
fd.append('skip_probe', 'false');
|
||||
const r = await fetch('/api/model-endpoints', { method: 'POST', body: fd });
|
||||
if (r.ok) {
|
||||
|
||||
@@ -561,7 +561,7 @@ function _initModelPickerDropdown() {
|
||||
menu.classList.remove('closing', 'hidden');
|
||||
_populate('');
|
||||
if (window.modelsModule && window.modelsModule.refreshModels) {
|
||||
window.modelsModule.refreshModels(true).then(() => {
|
||||
window.modelsModule.refreshModels().then(() => {
|
||||
if (!menu.classList.contains('hidden')) _populate(search.value || '');
|
||||
updateModelPicker();
|
||||
}).catch(() => {});
|
||||
|
||||
@@ -16,6 +16,7 @@ import { sortModelIds } from './modelSort.js';
|
||||
let API_BASE = '';
|
||||
let _cachedItems = []; // cached /api/models items for model-switch dropdown
|
||||
let _lastFetchTime = 0;
|
||||
let _fetchInflight = null;
|
||||
const _FETCH_CACHE_TTL = 30000; // 30s client-side cache for /api/models
|
||||
const COLLAPSE_KEY = 'odysseus-models-collapsed';
|
||||
const FAVORITES_KEY = 'odysseus-model-favorites';
|
||||
@@ -176,8 +177,15 @@ export async function refreshModels(force = false) {
|
||||
box.appendChild(_loadingSpinner.createElement());
|
||||
_loadingSpinner.start();
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/api/models`);
|
||||
const data = await res.json();
|
||||
if (!_fetchInflight) {
|
||||
_fetchInflight = fetch(`${API_BASE}/api/models`, { credentials: 'same-origin' })
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||
return res.json();
|
||||
})
|
||||
.finally(() => { _fetchInflight = null; });
|
||||
}
|
||||
const data = await _fetchInflight;
|
||||
_lastFetchTime = Date.now();
|
||||
_cachedItems = data.items || [];
|
||||
} catch (e) {
|
||||
|
||||
@@ -1,11 +1,59 @@
|
||||
"""Tests for model_context.py — local endpoint detection, token estimation, known model lookup."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
import src.model_context as model_context
|
||||
from src.model_context import _is_local_endpoint, estimate_tokens, _lookup_known
|
||||
|
||||
|
||||
class _Column:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, value):
|
||||
return ("eq", self.name, value)
|
||||
|
||||
|
||||
class _ModelEndpoint:
    """Fake ``ModelEndpoint`` class exposing only the ``is_enabled`` column."""

    # ``_ModelEndpoint.is_enabled == value`` produces an ("eq", ...) tuple.
    is_enabled = _Column("is_enabled")
|
||||
|
||||
|
||||
class _Query:
|
||||
def __init__(self, rows):
|
||||
self.rows = list(rows)
|
||||
|
||||
def filter(self, *conditions):
|
||||
for condition in conditions:
|
||||
if isinstance(condition, tuple) and condition[0] == "eq":
|
||||
_, field, value = condition
|
||||
self.rows = [row for row in self.rows if getattr(row, field) == value]
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _Db:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
|
||||
def query(self, model):
|
||||
return _Query(self.rows)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def _install_endpoint_db(monkeypatch, rows):
    """Install a fake ``core.database`` module whose sessions serve *rows*.

    The stub module exposes ``ModelEndpoint`` and ``SessionLocal`` and is
    registered in ``sys.modules`` through ``monkeypatch.setitem``.
    """
    mod = types.ModuleType("core.database")
    mod.ModelEndpoint = _ModelEndpoint
    # Each SessionLocal() call builds a new fake session over the same rows.
    mod.SessionLocal = lambda: _Db(rows)
    monkeypatch.setitem(sys.modules, "core.database", mod)
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
def test_localhost(self):
|
||||
assert _is_local_endpoint("http://localhost:5000/v1/chat/completions") is True
|
||||
@@ -23,6 +71,18 @@ class TestIsLocalEndpoint:
|
||||
# 100.64.0.0/10 is the CGNAT range Tailscale uses.
|
||||
assert _is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True
|
||||
|
||||
def test_configured_tailscale_proxy_is_remote(self, monkeypatch):
|
||||
_install_endpoint_db(monkeypatch, [
|
||||
types.SimpleNamespace(
|
||||
base_url="http://100.117.136.97:34521/v1",
|
||||
endpoint_kind="proxy",
|
||||
api_key="fake-key",
|
||||
is_enabled=True,
|
||||
)
|
||||
])
|
||||
|
||||
assert _is_local_endpoint("http://100.117.136.97:34521/v1/chat/completions") is False
|
||||
|
||||
def test_openai_is_remote(self):
|
||||
assert _is_local_endpoint("https://api.openai.com/v1/chat/completions") is False
|
||||
|
||||
@@ -164,3 +224,28 @@ class TestGetContextLength:
|
||||
assert first == 200000
|
||||
assert second == 200000
|
||||
assert len(calls) == 1
|
||||
|
||||
def test_configured_proxy_uses_default_without_model_listing(self, monkeypatch):
|
||||
_install_endpoint_db(monkeypatch, [
|
||||
types.SimpleNamespace(
|
||||
base_url="http://100.117.136.97:34521/v1",
|
||||
endpoint_kind="proxy",
|
||||
api_key="fake-key",
|
||||
is_enabled=True,
|
||||
)
|
||||
])
|
||||
calls = []
|
||||
|
||||
def fake_get(*args, **kwargs):
|
||||
calls.append(args)
|
||||
raise AssertionError("/models should not be queried for configured proxy context")
|
||||
|
||||
monkeypatch.setattr(model_context.httpx, "get", fake_get)
|
||||
|
||||
endpoint = "http://100.117.136.97:34521/v1/chat/completions"
|
||||
first = model_context.get_context_length(endpoint, "unknown-proxy-model")
|
||||
second = model_context.get_context_length(endpoint, "unknown-proxy-model")
|
||||
|
||||
assert first == model_context.DEFAULT_CONTEXT
|
||||
assert second == model_context.DEFAULT_CONTEXT
|
||||
assert calls == []
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -28,7 +30,9 @@ if "core.database" not in sys.modules:
|
||||
sys.modules["core.database"] = _core_db
|
||||
|
||||
import routes.model_routes as model_routes
|
||||
import src.database as src_database
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
import src.llm_core as llm_core
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
@@ -36,7 +40,11 @@ from routes.model_routes import (
|
||||
_normalize_model_ids,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_effective_endpoint_kind,
|
||||
_probe_endpoint,
|
||||
_ping_endpoint,
|
||||
_parse_model_list,
|
||||
_normalize_refresh_mode,
|
||||
_truthy,
|
||||
_speech_settings_using_endpoint,
|
||||
_clear_speech_settings_for_endpoint,
|
||||
@@ -304,6 +312,54 @@ class TestClassifyEndpoint:
|
||||
def test_malformed_url(self):
|
||||
assert _classify_endpoint("not-a-url") == "api"
|
||||
|
||||
def test_tailscale_auto_is_local(self):
|
||||
assert _classify_endpoint("http://100.117.136.97:34521/v1") == "local"
|
||||
|
||||
def test_tailscale_proxy_override_is_api(self):
|
||||
assert _classify_endpoint("http://100.117.136.97:34521/v1", "proxy") == "api"
|
||||
|
||||
def test_tailscale_api_override_is_api(self):
|
||||
assert _classify_endpoint("http://100.117.136.97:34521/v1", "api") == "api"
|
||||
|
||||
def test_public_local_override_is_local(self):
|
||||
assert _classify_endpoint("https://api.openai.com/v1", "local") == "local"
|
||||
|
||||
def test_keyed_legacy_v1_endpoint_is_effective_proxy(self):
|
||||
ep = SimpleNamespace(endpoint_kind="auto", api_key="fake-key")
|
||||
assert _effective_endpoint_kind(ep, "http://100.117.136.97:34521/v1") == "proxy"
|
||||
|
||||
def test_proxy_refresh_mode_defaults_manual(self):
|
||||
assert _normalize_refresh_mode("", "proxy") == "manual"
|
||||
assert _normalize_refresh_mode("auto", "proxy") == "manual"
|
||||
assert _normalize_refresh_mode("manual", "proxy") == "manual"
|
||||
assert _normalize_refresh_mode("auto", "api") == "auto"
|
||||
|
||||
def test_parse_model_list_accepts_json_and_text(self):
|
||||
assert _parse_model_list('["a", "b", "a"]') == ["a", "b"]
|
||||
assert _parse_model_list("a, b\nc") == ["a", "b", "c"]
|
||||
|
||||
def test_ping_endpoint_does_not_request_models_for_openai_style_proxy(self, monkeypatch):
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda url: url, raising=False)
|
||||
seen = []
|
||||
|
||||
def fake_head(*args, **kwargs):
|
||||
raise AssertionError("generic proxy health check should not use HEAD")
|
||||
|
||||
def fake_get(url, headers=None, timeout=None):
|
||||
seen.append(("GET", url))
|
||||
request = httpx.Request("GET", url)
|
||||
return httpx.Response(200, request=request)
|
||||
|
||||
monkeypatch.setattr(model_routes.httpx, "head", fake_head)
|
||||
monkeypatch.setattr(model_routes.httpx, "get", fake_get)
|
||||
|
||||
result = _ping_endpoint("http://100.117.136.97:34521/v1", "fake-key", timeout=1)
|
||||
|
||||
assert result["reachable"] is True
|
||||
assert result["status_code"] == 200
|
||||
assert seen == [("GET", "http://100.117.136.97:34521/v1")]
|
||||
assert all(not url.endswith("/models") for _, url in seen)
|
||||
|
||||
|
||||
# ── setup probing ──
|
||||
|
||||
@@ -534,77 +590,51 @@ if "python_multipart" not in sys.modules:
|
||||
sys.modules["python_multipart"] = _mp_stub
|
||||
|
||||
|
||||
class _PinnedFakeQuery:
|
||||
def __init__(self, rows):
|
||||
self.rows = list(rows)
|
||||
|
||||
def filter(self, *conditions):
|
||||
return self
|
||||
|
||||
def order_by(self, *args):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self.rows[0] if self.rows else None
|
||||
|
||||
def all(self):
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _PinnedFakeDb:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.added = []
|
||||
self.committed = 0
|
||||
|
||||
def query(self, model):
|
||||
return _PinnedFakeQuery(self.rows)
|
||||
|
||||
def add(self, row):
|
||||
self.added.append(row)
|
||||
|
||||
def commit(self):
|
||||
self.committed += 1
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class _FakeCol:
|
||||
"""Column stand-in: every comparison/operator just returns itself so the
|
||||
dedupe query expressions evaluate without a real SQLAlchemy column."""
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __eq__(self, other):
|
||||
return self
|
||||
|
||||
def is_(self, other):
|
||||
return self
|
||||
class _RouteCondition:
|
||||
def __init__(self, op, field, value):
|
||||
self.op = op
|
||||
self.field = field
|
||||
self.value = value
|
||||
|
||||
def __or__(self, other):
|
||||
return self
|
||||
return ("or", self, other)
|
||||
|
||||
|
||||
class _RouteColumn:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, value):
|
||||
return _RouteCondition("eq", self.name, value)
|
||||
|
||||
def is_(self, value):
|
||||
return _RouteCondition("eq", self.name, value)
|
||||
|
||||
def desc(self):
|
||||
return self
|
||||
|
||||
|
||||
class _RouteModelEndpoint:
    """ModelEndpoint stand-in that stores constructor kwargs as attributes.

    Class-level fake columns let it double as the query class in the dedupe
    lookup; instance attributes (set in __init__) shadow them per-row.
    """

    id = _RouteColumn("id")
    base_url = _RouteColumn("base_url")
    is_enabled = _RouteColumn("is_enabled")
    owner = _RouteColumn("owner")
    created_at = _RouteColumn("created_at")

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
|
||||
|
||||
|
||||
_RecordingEndpoint = _RouteModelEndpoint
|
||||
|
||||
|
||||
class _PinnedFakeRequest:
|
||||
def __init__(self, body=None, headers=None):
|
||||
self._body = body if body is not None else {}
|
||||
@@ -635,6 +665,13 @@ def _make_endpoint(**kwargs):
|
||||
pinned_models=None,
|
||||
model_type="llm",
|
||||
supports_tools=None,
|
||||
endpoint_kind="auto",
|
||||
model_refresh_mode="auto",
|
||||
model_refresh_interval=None,
|
||||
model_refresh_timeout=None,
|
||||
owner=None,
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
base.update(kwargs)
|
||||
return SimpleNamespace(**base)
|
||||
@@ -676,7 +713,7 @@ def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
|
||||
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
|
||||
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
|
||||
|
||||
result = endpoint("ep1", _PinnedFakeRequest())
|
||||
result = endpoint("ep1", _PinnedFakeRequest(), SimpleNamespace(headers={}))
|
||||
|
||||
ids = [row["id"] for row in result]
|
||||
assert ids == ["deploy-1"]
|
||||
@@ -730,6 +767,10 @@ def _create_form_kwargs(**overrides):
|
||||
skip_probe="true", # avoid any network probe in unit tests
|
||||
require_models="false",
|
||||
model_type="llm",
|
||||
endpoint_kind="auto",
|
||||
model_refresh_mode="",
|
||||
model_refresh_interval="",
|
||||
model_refresh_timeout="",
|
||||
supports_tools="",
|
||||
pinned_models="",
|
||||
container_local="false",
|
||||
@@ -772,6 +813,7 @@ def test_post_creates_endpoint_with_pinned_models(monkeypatch):
|
||||
|
||||
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
|
||||
existing = _make_endpoint(
|
||||
base_url="http://host:1234/v1",
|
||||
cached_models=json.dumps(["m1"]),
|
||||
hidden_models=None,
|
||||
pinned_models=json.dumps(["old-pin"]),
|
||||
@@ -798,6 +840,7 @@ def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
|
||||
|
||||
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
|
||||
existing = _make_endpoint(
|
||||
base_url="http://host:1234/v1",
|
||||
cached_models=json.dumps(["m1"]),
|
||||
pinned_models=json.dumps(["keep-me"]),
|
||||
)
|
||||
@@ -814,3 +857,464 @@ def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
|
||||
assert json.loads(existing.pinned_models) == ["keep-me"]
|
||||
assert result["pinned_models"] == ["keep-me"]
|
||||
assert db.committed == 0 # nothing to persist
|
||||
class _RouteQuery:
    """In-memory query that evaluates fake equality and OR conditions.

    Supported conditions are ``_RouteCondition`` objects with ``op == "eq"``
    and ``("or", cond, cond, ...)`` tuples of such objects. Anything else is
    ignored. Missing row attributes compare as ``None``.
    """

    def __init__(self, rows):
        # Copy so filtering never mutates the session's row list.
        self.rows = list(rows)

    def filter(self, *conditions):
        """Narrow ``self.rows`` in place by each condition; return self."""
        for condition in conditions:
            if isinstance(condition, _RouteCondition) and condition.op == "eq":
                self.rows = [
                    row for row in self.rows
                    if getattr(row, condition.field, None) == condition.value
                ]
            elif isinstance(condition, tuple) and condition and condition[0] == "or":
                alternatives = condition[1:]
                # Keep a row when at least one equality alternative matches.
                self.rows = [
                    row for row in self.rows
                    if any(
                        isinstance(part, _RouteCondition)
                        and part.op == "eq"
                        and getattr(row, part.field, None) == part.value
                        for part in alternatives
                    )
                ]
        return self

    def order_by(self, *args, **kwargs):
        """Ordering is not simulated; return self unchanged."""
        return self

    def all(self):
        """Return a copy of the remaining rows."""
        return list(self.rows)

    def first(self):
        """Return the first remaining row, or ``None`` when empty."""
        return self.rows[0] if self.rows else None
|
||||
|
||||
|
||||
class _RouteDb:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.added = []
|
||||
self.committed = 0
|
||||
self.commits = 0
|
||||
self.closed = False
|
||||
|
||||
def query(self, model):
|
||||
return _RouteQuery(self.rows)
|
||||
|
||||
def commit(self):
|
||||
self.committed += 1
|
||||
self.commits += 1
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def add(self, row):
|
||||
self.rows.append(row)
|
||||
self.added.append(row)
|
||||
|
||||
|
||||
_PinnedFakeDb = _RouteDb
|
||||
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target, daemon=None):
|
||||
self.target = target
|
||||
|
||||
def start(self):
|
||||
self.target()
|
||||
|
||||
|
||||
def _wait_for(predicate, timeout=2.0):
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
if predicate():
|
||||
return True
|
||||
time.sleep(0.01)
|
||||
return bool(predicate())
|
||||
|
||||
|
||||
def _route_endpoint(router, path, method="GET"):
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError(f"{method} {path} route not found")
|
||||
|
||||
|
||||
def _route_ep(
|
||||
id,
|
||||
base_url,
|
||||
*,
|
||||
cached_models=None,
|
||||
endpoint_kind="auto",
|
||||
api_key=None,
|
||||
name=None,
|
||||
pinned_models=None,
|
||||
refresh_mode="auto",
|
||||
refresh_timeout=None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
id=id,
|
||||
name=name or id,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
is_enabled=True,
|
||||
cached_models=json.dumps(cached_models) if cached_models is not None else None,
|
||||
hidden_models=None,
|
||||
pinned_models=json.dumps(pinned_models) if pinned_models is not None else None,
|
||||
model_type="llm",
|
||||
endpoint_kind=endpoint_kind,
|
||||
model_refresh_mode=refresh_mode,
|
||||
model_refresh_interval=None,
|
||||
model_refresh_timeout=refresh_timeout,
|
||||
supports_tools=None,
|
||||
owner=None,
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _route_request():
|
||||
return SimpleNamespace(
|
||||
state=SimpleNamespace(current_user=None),
|
||||
app=SimpleNamespace(state=SimpleNamespace(auth_manager=None)),
|
||||
)
|
||||
|
||||
|
||||
def test_api_models_returns_cached_proxy_models_without_refresh_probe(monkeypatch):
    """/api/models serves a manual proxy's cached models without probing it."""
    row = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([row])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "_auth_disabled", lambda: True)
    monkeypatch.setattr(model_routes, "build_chat_url", lambda base: f"{base}/chat/completions")

    def fail_probe(*args, **kwargs):
        raise AssertionError("/models probe should not run for cached manual proxy")

    monkeypatch.setattr(model_routes, "_probe_endpoint", fail_probe)
    # Run any background refresh inline so a stray probe would fail this test.
    monkeypatch.setattr(threading, "Thread", _ImmediateThread)

    result = _route_endpoint(router, "/api/models")(_route_request())

    assert result["items"][0]["models"] == ["cached-model"]
    assert result["items"][0]["category"] == "api"
    assert result["items"][0]["endpoint_kind"] == "proxy"
    assert "offline" not in result["items"][0]
    assert json.loads(row.cached_models) == ["cached-model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_probe_local_skips_tailscale_proxy_endpoint(monkeypatch):
    """probe-local pings only the local endpoint, never the tailnet proxy."""
    proxy = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
    )
    local = _route_ep("local", "http://127.0.0.1:8000/v1", endpoint_kind="local")
    db = _RouteDb([proxy, local])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    def fail_probe(*args, **kwargs):
        raise AssertionError("full probe should not run")

    monkeypatch.setattr(model_routes, "_probe_endpoint", fail_probe)

    pinged = []

    def fake_ping(base_url, api_key=None, timeout=1.5):
        pinged.append(base_url)
        return {"reachable": True, "status_code": 404, "error": "HTTP 404"}

    monkeypatch.setattr(model_routes, "_ping_endpoint", fake_ping)

    result = await _route_endpoint(router, "/api/model-endpoints/probe-local")(_route_request())

    assert set(result) == {"local"}
    assert pinged == ["http://127.0.0.1:8000/v1"]
|
||||
|
||||
|
||||
def test_background_refresh_deduplicates_same_base_url(monkeypatch):
    """Two endpoints sharing a base URL trigger one probe that updates both."""
    ep1 = _route_ep("a", "http://127.0.0.1:8000/v1", endpoint_kind="local")
    ep2 = _route_ep("b", "http://127.0.0.1:8000/v1", endpoint_kind="local")
    db = _RouteDb([ep1, ep2])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "_auth_disabled", lambda: True)
    monkeypatch.setattr(model_routes, "build_chat_url", lambda base: f"{base}/chat/completions")

    calls = []
    probe_done = threading.Event()

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append(base_url)
        probe_done.set()
        return ["live-model"]

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    _route_endpoint(router, "/api/models")(_route_request(), refresh=True)

    # The refresh runs in the background: wait for the probe, then for both
    # rows to receive the cached result.
    assert probe_done.wait(2)
    assert _wait_for(lambda: ep1.cached_models and ep2.cached_models)
    assert calls == ["http://127.0.0.1:8000/v1"]
    assert json.loads(ep1.cached_models) == ["live-model"]
    assert json.loads(ep2.cached_models) == ["live-model"]
|
||||
|
||||
|
||||
def test_background_refresh_failure_keeps_existing_cached_models(monkeypatch):
    """An empty probe result must not overwrite the previously cached models."""
    ep = _route_ep(
        "local",
        "http://127.0.0.1:8000/v1",
        cached_models=["cached-model"],
        endpoint_kind="local",
    )
    db = _RouteDb([ep])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "_auth_disabled", lambda: True)
    monkeypatch.setattr(model_routes, "build_chat_url", lambda base: f"{base}/chat/completions")
    probe_done = threading.Event()

    def fake_probe(*args, **kwargs):
        probe_done.set()
        return []

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    result = _route_endpoint(router, "/api/models")(_route_request(), refresh=True)

    assert probe_done.wait(2)
    assert _wait_for(lambda: db.commits > 0)
    assert result["items"][0]["models"] == ["cached-model"]
    assert json.loads(ep.cached_models) == ["cached-model"]
|
||||
|
||||
|
||||
def test_llm_core_list_model_ids_uses_cached_configured_proxy(monkeypatch):
    """list_model_ids returns a proxy's visible cached models without HTTP."""
    ep = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model", "hidden-model"],
        endpoint_kind="proxy",
    )
    ep.hidden_models = json.dumps(["hidden-model"])
    db = _RouteDb([ep])

    def fail_get(*args, **kwargs):
        raise AssertionError("/models should not be fetched")

    monkeypatch.setattr(src_database, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(src_database, "SessionLocal", lambda: db)
    monkeypatch.setattr(llm_core.httpx, "get", fail_get)

    assert llm_core.list_model_ids("http://100.117.136.97:34521/v1/chat/completions", timeout=1) == ["cached-model"]
|
||||
|
||||
|
||||
def test_explicit_proxy_test_fetches_models_with_long_timeout(monkeypatch):
    """The explicit Test action probes a proxy once with a 30s timeout."""
    router = model_routes.setup_model_routes(model_discovery=None)

    def fail_ping(*args, **kwargs):
        raise AssertionError("ping should not run when model listing succeeds")

    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_ping_endpoint", fail_ping)

    calls = []
    returned = ["NVIDIA NIM/openai/gpt-oss-120b", "mistral/mistral-small-2603"]

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return returned

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    result = _route_endpoint(router, "/api/model-endpoints/test", "POST")(
        _route_request(),
        base_url="http://100.117.136.97:34521/v1",
        api_key="fake-key",
        endpoint_kind="proxy",
    )

    assert result["online"] is True
    assert result["status"] == "online"
    assert result["models"] == returned
    assert calls == [{
        "base_url": "http://100.117.136.97:34521/v1",
        "api_key": "fake-key",
        "timeout": 30.0,
    }]
|
||||
|
||||
|
||||
def test_explicit_proxy_add_fetches_and_caches_models_with_long_timeout(monkeypatch):
    """Adding a proxy probes it once with a 30s timeout and caches the result."""
    db = _RouteDb([])
    router = model_routes.setup_model_routes(model_discovery=None)

    def fail_ping(*args, **kwargs):
        raise AssertionError("ping should not run when model listing succeeds")

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_load_settings", lambda: {})
    monkeypatch.setattr(model_routes, "_save_settings", lambda settings: None)
    monkeypatch.setattr("src.auth_helpers.get_current_user", lambda request: None)
    monkeypatch.setattr(model_routes, "_ping_endpoint", fail_ping)

    calls = []
    returned = ["NVIDIA NIM/openai/gpt-oss-120b", "mistral/mistral-small-2603"]

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return returned

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    result = _route_endpoint(router, "/api/model-endpoints", "POST")(
        _route_request(),
        name="Bifrost",
        base_url="http://100.117.136.97:34521/v1",
        api_key="fake-key",
        skip_probe="true",
        require_models="false",
        model_type="llm",
        endpoint_kind="proxy",
        model_refresh_mode="manual",
        model_refresh_interval="",
        model_refresh_timeout="",
        supports_tools="",
        container_local="false",
        shared="true",
    )

    assert result["online"] is True
    assert result["status"] == "online"
    assert result["models"] == returned
    assert calls == [{
        "base_url": "http://100.117.136.97:34521/v1",
        "api_key": "fake-key",
        "timeout": 30.0,
    }]
    # The new row is persisted with the fetched list and the proxy settings.
    assert len(db.rows) == 1
    assert json.loads(db.rows[0].cached_models) == returned
    assert db.rows[0].endpoint_kind == "proxy"
    assert db.rows[0].model_refresh_mode == "manual"
|
||||
|
||||
|
||||
def test_manual_refresh_uses_long_timeout_and_saves_full_model_list(monkeypatch):
    """A manual refresh honours refresh_timeout and stores the full new list."""
    ep = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([ep])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    calls = []
    refreshed = ["cached-model", "mistral/mistral-small-2603", "provider/nested/model/id"]

    def fake_probe(base_url, api_key=None, timeout=2):
        calls.append({"base_url": base_url, "api_key": api_key, "timeout": timeout})
        return refreshed

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    result = _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")(
        "proxy",
        _route_request(),
        response,
        refresh=True,
        refresh_timeout=60,
    )

    assert [m["id"] for m in result] == refreshed
    assert calls == [{
        "base_url": "http://100.117.136.97:34521/v1",
        "api_key": "fake-key",
        "timeout": 60.0,
    }]
    assert json.loads(ep.cached_models) == refreshed
    assert db.commits == 1
    assert response.headers["X-Model-Refresh-Status"] == "refreshed"
    assert response.headers["X-Model-Refresh-Count"] == "3"
|
||||
|
||||
|
||||
def test_manual_refresh_defaults_to_proxy_long_timeout(monkeypatch):
    """Without an explicit refresh_timeout a proxy refresh uses 30 seconds."""
    ep = _route_ep(
        "proxy",
        "https://proxy.example.test/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        refresh_mode="manual",
    )
    db = _RouteDb([ep])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    timeouts = []

    def fake_probe(base_url, api_key=None, timeout=2):
        timeouts.append(timeout)
        return ["cached-model", "new-model"]

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")(
        "proxy",
        _route_request(),
        response,
        refresh=True,
    )

    assert timeouts == [30.0]
    assert json.loads(ep.cached_models) == ["cached-model", "new-model"]
|
||||
|
||||
|
||||
def test_manual_refresh_timeout_keeps_cached_models_and_warns(monkeypatch):
    """A timed-out manual refresh keeps the cache and reports a warning header."""
    ep = _route_ep(
        "proxy",
        "http://100.117.136.97:34521/v1",
        cached_models=["cached-model"],
        endpoint_kind="proxy",
        api_key="fake-key",
        refresh_mode="manual",
    )
    db = _RouteDb([ep])
    router = model_routes.setup_model_routes(model_discovery=None)

    monkeypatch.setattr(model_routes, "ModelEndpoint", _RouteModelEndpoint)
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    def fake_probe(base_url, api_key=None, timeout=2):
        raise httpx.TimeoutException("timed out")

    monkeypatch.setattr(model_routes, "_probe_endpoint", fake_probe)

    response = SimpleNamespace(headers={})
    result = _route_endpoint(router, "/api/model-endpoints/{ep_id}/models")(
        "proxy",
        _route_request(),
        response,
        refresh=True,
        refresh_timeout=60,
    )

    assert [m["id"] for m in result] == ["cached-model"]
    assert json.loads(ep.cached_models) == ["cached-model"]
    # Nothing was persisted because the refresh failed.
    assert db.commits == 0
    assert response.headers["X-Model-Refresh-Status"] == "failed"
    assert "kept cached models" in response.headers["X-Model-Refresh-Warning"]
|
||||
|
||||
Reference in New Issue
Block a user