diff --git a/core/database.py b/core/database.py index cbd4bac..4788a45 100644 --- a/core/database.py +++ b/core/database.py @@ -342,6 +342,14 @@ class ModelEndpoint(TimestampMixin, Base): cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list) pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models) model_type = Column(String, nullable=True, default="llm") # "llm" or "image" + # auto = classify by URL; local = self-hosted server; api/proxy = external + # OpenAI-compatible API even when reachable through a private/tailnet IP. + endpoint_kind = Column(String, nullable=True, default="auto") + # auto = background refresh with TTL/backoff; manual/disabled = cached-first + # only unless an explicit endpoint probe is requested. + model_refresh_mode = Column(String, nullable=True, default="auto") + model_refresh_interval = Column(Integer, nullable=True, default=None) + model_refresh_timeout = Column(Integer, nullable=True, default=None) # Whether models on this endpoint accept OpenAI-style function # schemas + emit `tool_calls`. Auto-detected at Cookbook auto- # register time from `--enable-auto-tool-choice` in the serve cmd; @@ -809,6 +817,29 @@ def _migrate_add_model_type_column(): except Exception as e: logging.getLogger(__name__).warning(f"model_type migration failed: {e}") +def _migrate_add_model_endpoint_refresh_columns(): + """Add endpoint classification / refresh policy columns if missing.""" + import sqlite3 + db_path = DATABASE_URL.replace("sqlite:///", "") + if not os.path.exists(db_path): + return + try: + conn = sqlite3.connect(db_path) + cursor = conn.execute("PRAGMA table_info(model_endpoints)") + columns = [row[1] for row in cursor.fetchall()] + if columns and "endpoint_kind" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN endpoint_kind TEXT DEFAULT 'auto'") + if columns and "model_refresh_mode" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_mode TEXT DEFAULT 'auto'") + if columns and "model_refresh_interval" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_interval INTEGER") + if columns and "model_refresh_timeout" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER") + conn.commit() + conn.close() + except Exception as e: + logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}") + def _migrate_add_task_run_model_column(): """Add model column to task_runs if it doesn't exist (records which model ran).""" import sqlite3 @@ -1539,6 +1570,7 @@ def init_db(): _migrate_add_pinned_models_column() _migrate_add_notes_sort_order() _migrate_add_model_type_column() + _migrate_add_model_endpoint_refresh_columns() _migrate_add_model_endpoint_owner_column() _migrate_add_supports_tools_column() _migrate_add_task_run_model_column() diff --git a/routes/model_routes.py b/routes/model_routes.py index f4153b0..b56d200 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -11,7 +11,7 @@ import httpx from datetime import datetime from typing import List, Dict, Any, Optional from urllib.parse import urlparse, urlunparse -from fastapi import APIRouter, HTTPException, Form, Query, Body, Request +from fastapi import APIRouter, HTTPException, Form, Query, Body, Request, Response from pydantic import BaseModel from fastapi.responses import StreamingResponse from core.database import SessionLocal, ModelEndpoint, Session as DbSession @@ -335,6 +335,141 @@ def _truthy(value: str | None) -> bool: return (value or "").strip().lower() in ("true", "1", "yes", "on") +_ENDPOINT_KINDS = {"auto", "local", "api", "proxy"} +_REFRESH_MODES = {"auto", "manual", "disabled"} + + +def _normalize_endpoint_kind(value: Any) -> str: + kind = str(value or "auto").strip().lower() + return kind if kind in _ENDPOINT_KINDS else "auto" + + +def _normalize_refresh_mode(value: Any, endpoint_kind: str = "auto") -> str: + mode = str(value or "").strip().lower() + kind = _normalize_endpoint_kind(endpoint_kind) + if mode in ("manual", "disabled"): + return mode + if mode == "auto" and kind != "proxy": + return "auto" + # Proxies default to manual cached-first behavior. Normal local/API + # endpoints keep automatic bounded refreshes. + return "manual" if kind == "proxy" else "auto" + + +def _endpoint_kind(ep: Any) -> str: + return _normalize_endpoint_kind(getattr(ep, "endpoint_kind", None)) + + +def _endpoint_refresh_mode(ep: Any, endpoint_kind: str | None = None) -> str: + return _normalize_refresh_mode(getattr(ep, "model_refresh_mode", None), endpoint_kind or _endpoint_kind(ep)) + + +def _endpoint_refresh_interval(ep: Any, category: str) -> float: + raw = getattr(ep, "model_refresh_interval", None) + try: + val = int(raw) if raw is not None else 0 + except Exception: + val = 0 + if val > 0: + return float(max(30, val)) + return 60.0 if category == "local" else 3600.0 + + +def _endpoint_refresh_timeout(ep: Any, category: str) -> float: + raw = getattr(ep, "model_refresh_timeout", None) + try: + val = int(raw) if raw is not None else 0 + except Exception: + val = 0 + if val > 0: + return float(max(1, min(30, val))) + return 2.5 if category == "local" else 2.0 + + +def _manual_refresh_timeout(ep: Any, category: str, requested: Any = None) -> float: + """Timeout for explicit user-triggered model-list refreshes. + + Background refreshes stay short. A manual refresh is the one path where a + large proxy may legitimately need 15-30s to aggregate its catalog. + """ + requested_val = _parse_positive_int(requested, minimum=1, maximum=60) + if requested_val is not None: + return float(requested_val) + stored = _parse_positive_int(getattr(ep, "model_refresh_timeout", None), minimum=1, maximum=60) + if category == "local": + return float(stored) if stored is not None else _endpoint_refresh_timeout(ep, category) + return float(max(stored or 30, 30)) + + +def _parse_model_list(raw: Any) -> List[str]: + """Return a sanitized list of model ids from JSON/list/comma text.""" + if raw is None: + return [] + value = raw + if isinstance(value, str): + text = value.strip() + if not text: + return [] + try: + parsed = json.loads(text) + if isinstance(parsed, list): + value = parsed + else: + value = re.split(r"[\n,]+", text) + except Exception: + value = re.split(r"[\n,]+", text) + if not isinstance(value, list): + return [] + out = [] + seen = set() + for item in value: + mid = str(item or "").strip() + if not mid or mid in seen: + continue + seen.add(mid) + out.append(mid) + return out + + +def _parse_positive_int(raw: Any, *, minimum: int = 1, maximum: int = 86400) -> Optional[int]: + try: + val = int(str(raw).strip()) + except Exception: + return None + if val < minimum: + return None + return min(val, maximum) + + +def _explicit_model_list_timeout(base_url: str, endpoint_kind: str = "auto", requested: Any = None) -> float: + """Timeout for explicit user-triggered model-list fetches during setup.""" + requested_val = _parse_positive_int(requested, minimum=1, maximum=60) + if requested_val is not None: + return float(requested_val) + kind = _normalize_endpoint_kind(endpoint_kind) + category = _classify_endpoint(base_url, kind) + if kind in ("api", "proxy") or category == "api": + return 30.0 + return 3.0 if _is_ollama_base(base_url) else 2.0 + + +def _cached_model_ids(ep: Any) -> List[str]: + return _parse_model_list(getattr(ep, "cached_models", None)) + + +def _hidden_model_ids(ep: Any) -> set: + return set(_parse_model_list(getattr(ep, "hidden_models", None))) + + +def _is_ollama_base(base_url: str) -> bool: + try: + parsed = urlparse(base_url) + host = (parsed.hostname or "").lower() + return parsed.port == 11434 or "ollama" in host + except Exception: + return "ollama" in (base_url or "").lower() + + # Prefixes/substrings for models that are NOT chat-completions-capable _NON_CHAT_PREFIXES = ( "dall-e", "tts-", "whisper", "text-embedding", "embedding", @@ -441,10 +576,15 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.", _TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.") -def _classify_endpoint(base_url: str) -> str: +def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str: """Return 'local' if the endpoint URL points to a private/local address, else 'api'. Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted servers (e.g. Cookbook serve endpoints) get reachability-probed too.""" + kind = _normalize_endpoint_kind(endpoint_kind) + if kind == "local": + return "local" + if kind in ("api", "proxy"): + return "api" try: host = urlparse(base_url).hostname or "" if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES): @@ -456,6 +596,21 @@ def _classify_endpoint(base_url: str) -> str: return "api" +def _effective_endpoint_kind(ep: Any, base_url: str) -> str: + """Return explicit kind, with a legacy proxy heuristic for keyed /v1 URLs.""" + kind = _endpoint_kind(ep) + if kind != "auto": + return kind + if getattr(ep, "api_key", None) and not _is_ollama_base(base_url): + try: + path = (urlparse(base_url).path or "").rstrip("/") + if path.endswith("/v1") or "/openai" in path: + return "proxy" + except Exception: + pass + return "auto" + + def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]: """Probe a base URL's /models endpoint and return list of model IDs. @@ -546,30 +701,18 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> """Reachability probe that does not require installed/listed models.""" from src.endpoint_resolver import resolve_url base = resolve_url(_normalize_base(base_url)) - headers = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" + headers = build_headers(api_key, base) # Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version, - # /api/tags. The OpenAI-style GET base + "/models" returns 404 when the - # base is the host root or the native /api root (e.g. http://localhost:11434, - # http://localhost:11434/api) because /models lives under /v1 there. Treat - # 4xx on a port-11434 / Ollama-named base as "try the native paths" rather - # than as a definitive offline verdict — Ollama is reachable, it just - # doesn't speak OpenAI on that prefix. Without this gate the quickstart - # marks an alive Ollama as offline whenever cached_models is empty (issue - # #1025): _probe_endpoint() falls through to /api/tags on the same 404, but - # _ping_endpoint() was returning before that fallback could run. + # /api/tags. Probe native paths for Ollama-style endpoints, but avoid using + # /models as a generic health check because large proxy catalogs can be slow. parsed_base = urlparse(base) looks_like_ollama = ( parsed_base.port == 11434 or "ollama" in (parsed_base.hostname or "").lower() ) - url = base + "/models" - last_error: Optional[str] = None - try: - r = httpx.get(url, headers=headers, timeout=timeout) + def _result_from_response(r) -> Dict[str, Any]: if 300 <= r.status_code < 400: loc = r.headers.get("location", "") if loc.startswith("/login") or "/login" in loc: @@ -579,13 +722,15 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> "error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.", } return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"} - if r.status_code < 400: - return {"reachable": True, "status_code": r.status_code, "error": None} - if r.status_code < 500 and not looks_like_ollama: - return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"} - last_error = f"HTTP {r.status_code}" - except Exception as e: - last_error = str(e)[:120] + if 200 <= r.status_code < 300: + return { + "reachable": True, + "status_code": r.status_code, + "error": None, + } + return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"} + + last_error: Optional[str] = None try: if looks_like_ollama: @@ -597,14 +742,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> for path in ("/api/version", "/api/tags"): try: r = httpx.get(root + path, timeout=timeout) - if r.status_code < 400: - return {"reachable": True, "status_code": r.status_code, "error": None} - last_error = f"HTTP {r.status_code}" + result = _result_from_response(r) + if result["reachable"]: + return result + last_error = result.get("error") except Exception as e: last_error = str(e)[:120] except Exception: pass + try: + r = httpx.get(base, headers=headers, timeout=timeout) + return _result_from_response(r) + except Exception as e: + last_error = str(e)[:120] + return {"reachable": False, "status_code": None, "error": last_error} @@ -715,17 +867,71 @@ def setup_model_routes(model_discovery): flip).""" _models_cache.clear() - # Track endpoints that have failed recently so we back off probing dead ones. - _probe_failures = {} # ep_id → (last_fail_ts, consecutive_fails) + # Track model-list refreshes by URL+key. This prevents repeated picker/API + # opens from starting duplicate /models probes, and gives slow/offline + # providers a cooldown after failures. + _refresh_state: Dict[str, Dict[str, Any]] = {} _refresh_inflight = {"v": False} # coarse single-flight guard + _REFRESH_FAILURE_BASE = 300.0 + _REFRESH_FAILURE_MAX = 3600.0 - def _refresh_caches_bg(): - """Background thread: re-probe all endpoints in PARALLEL with a tight - timeout, skipping endpoints that have been failing repeatedly. + def _refresh_key(base: str, api_key: Optional[str]) -> str: + return f"{base.rstrip('/')}\x00{api_key or ''}" - Was the cause of gradual server degradation: sequential 3s-timeout - probes against many endpoints (some offline) tied up the threadpool - for 15-30s every cache cycle, eventually exhausting it.""" + def _ts(value: Any) -> float: + try: + return float(value.timestamp()) if value else 0.0 + except Exception: + return 0.0 + + def _failure_delay(fails: int) -> float: + if fails <= 0: + return 0.0 + return min(_REFRESH_FAILURE_BASE * (2 ** max(0, fails - 1)), _REFRESH_FAILURE_MAX) + + def _should_refresh_endpoint(ep: Any, now: float, force: bool = False) -> tuple[bool, Dict[str, Any]]: + base = _normalize_base(getattr(ep, "base_url", "") or "") + kind = _effective_endpoint_kind(ep, base) + category = _classify_endpoint(base, kind) + mode = _endpoint_refresh_mode(ep, kind) + cached = _cached_model_ids(ep) + key = _refresh_key(base, getattr(ep, "api_key", None)) + state = _refresh_state.get(key, {}) + + info = { + "id": getattr(ep, "id", ""), + "base": base, + "api_key": getattr(ep, "api_key", None), + "kind": kind, + "category": category, + "mode": mode, + "key": key, + "timeout": _endpoint_refresh_timeout(ep, category), + } + if not base: + return False, info + if state.get("inflight"): + return False, info + if mode in ("manual", "disabled") and not force: + return False, info + fails = int(state.get("fail_count") or 0) + if fails and not force: + last_failure = float(state.get("last_failure") or 0.0) + if now - last_failure < _failure_delay(fails): + return False, info + if cached and not force: + interval = _endpoint_refresh_interval(ep, category) + last_good = float(state.get("last_success") or 0.0) or _ts(getattr(ep, "updated_at", None)) or _ts(getattr(ep, "created_at", None)) + if last_good and now - last_good < interval: + return False, info + return True, info + + def _refresh_caches_bg(force: bool = False): + """Background thread: safely refresh model caches with per-base single-flight. + + The public /api/models path stays cached-first. This refresh never clears + a non-empty cached model list on timeout/failure, and proxy/manual + endpoints are skipped unless explicitly forced.""" import threading if _refresh_inflight["v"]: return # already running @@ -735,44 +941,63 @@ def setup_model_routes(model_discovery): try: from concurrent.futures import ThreadPoolExecutor, as_completed db = SessionLocal() + changed = False try: endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() - # Skip endpoints that have failed 3+ times in a row in the last 5 min now = _time.time() - to_probe = [] + groups: Dict[str, Dict[str, Any]] = {} for ep in endpoints: - ts, fails = _probe_failures.get(ep.id, (0, 0)) - if fails >= 3 and (now - ts) < 300: + ok, info = _should_refresh_endpoint(ep, now, force=force) + if not ok: continue - to_probe.append(ep) + groups.setdefault(info["key"], { + "base": info["base"], + "api_key": info["api_key"], + "timeout": info["timeout"], + "endpoint_ids": [], + })["endpoint_ids"].append(info["id"]) - def _probe_one(ep): - base = _normalize_base(ep.base_url) + for key in groups: + st = _refresh_state.setdefault(key, {}) + st["inflight"] = True + st["last_attempt"] = now + + def _probe_one(key: str, data: Dict[str, Any]): try: - ids = _probe_endpoint(base, ep.api_key, timeout=2) - return ep, ids, None + ids = _probe_endpoint(data["base"], data.get("api_key"), timeout=data.get("timeout") or 2) + return key, data["endpoint_ids"], ids, None except Exception as e: - return ep, None, e + return key, data["endpoint_ids"], None, e - if to_probe: - # Bounded parallelism — 8 concurrent probes is plenty - with ThreadPoolExecutor(max_workers=min(8, len(to_probe))) as pool: - futures = [pool.submit(_probe_one, ep) for ep in to_probe] + if groups: + with ThreadPoolExecutor(max_workers=min(4, len(groups))) as pool: + futures = [pool.submit(_probe_one, key, data) for key, data in groups.items()] for fut in as_completed(futures): - ep, ids, err = fut.result() + key, endpoint_ids, ids, err = fut.result() + st = _refresh_state.setdefault(key, {}) if ids: - ep.cached_models = json.dumps(ids) - _probe_failures.pop(ep.id, None) + for ep_id in endpoint_ids: + ep_obj = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() + if ep_obj: + ep_obj.cached_models = json.dumps(ids) + changed = True + st["last_success"] = _time.time() + st["fail_count"] = 0 + st.pop("last_failure", None) else: - prev = _probe_failures.get(ep.id, (0, 0)) - _probe_failures[ep.id] = (_time.time(), prev[1] + 1) + st["last_failure"] = _time.time() + st["fail_count"] = int(st.get("fail_count") or 0) + 1 + st["inflight"] = False db.commit() finally: db.close() - _invalidate_models_cache() + if changed: + _invalidate_models_cache() except Exception: pass finally: + for st in _refresh_state.values(): + st["inflight"] = False _refresh_inflight["v"] = False threading.Thread(target=_do, daemon=True).start() @@ -804,24 +1029,15 @@ def setup_model_routes(model_discovery): base = _normalize_base(ep.base_url) provider = _detect_provider(base) # Use cached models — background refresh keeps them updated - model_ids = [] - if ep.cached_models: - try: - model_ids = json.loads(ep.cached_models) - except Exception: - pass + model_ids = _cached_model_ids(ep) ep_model_type = getattr(ep, "model_type", None) or "llm" # Filter out hidden (probe-failed) models - hidden = set() - if ep.hidden_models: - try: - hidden = set(json.loads(ep.hidden_models)) - except Exception: - pass + hidden = _hidden_model_ids(ep) model_ids = [m for m in model_ids if m not in hidden] # Build correct URL based on provider chat_url = build_chat_url(base) - category = _classify_endpoint(base) + kind = _effective_endpoint_kind(ep, base) + category = _classify_endpoint(base, kind) if model_ids: curated_key = _match_provider_curated(base, None) @@ -837,6 +1053,7 @@ def setup_model_routes(model_discovery): "endpoint_id": ep.id, "endpoint_name": ep.name, "category": category, + "endpoint_kind": kind, "model_type": ep_model_type, }) else: @@ -852,6 +1069,7 @@ def setup_model_routes(model_discovery): "endpoint_id": ep.id, "endpoint_name": ep.name, "category": category, + "endpoint_kind": kind, "model_type": ep_model_type, "offline": True, }) @@ -898,11 +1116,11 @@ def setup_model_routes(model_discovery): result = _fetch_models(owner=owner, is_admin=_is_admin) _models_cache[_cache_key] = {"data": result, "time": now} # Kick off background refresh to update caches from live endpoints - _refresh_caches_bg() + _refresh_caches_bg(force=refresh) return result # Brief cache for local-probe results so picker-open doesn't hammer - # /v1/models every time. 8s TTL — long enough to amortize cost, + # endpoint health checks every time. 8s TTL — long enough to amortize cost, # short enough that a freshly-killed local server shows as offline # within ~8s of the user noticing. _LOCAL_PROBE_TTL = 8.0 @@ -912,7 +1130,7 @@ def setup_model_routes(model_discovery): async def probe_local_endpoints(request: Request): """Fast parallel reachability check for LOCAL endpoints only. Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are - assumed up. Local endpoints get a 1.5s /models probe so the UI + assumed up. Local endpoints get a 1.5s cheap reachability probe so the UI can dim stale entries pointing at dead vLLM servers. Returns {ep_id: {alive, latency_ms, error}}.""" require_admin(request) @@ -924,36 +1142,44 @@ def setup_model_routes(model_discovery): db = SessionLocal() try: endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() - local_eps = [ - (ep.id, _normalize_base(ep.base_url), ep.api_key) - for ep in endpoints - if _classify_endpoint(_normalize_base(ep.base_url)) == "local" - ] + local_eps = [] + for ep in endpoints: + base = _normalize_base(ep.base_url) + kind = _effective_endpoint_kind(ep, base) + if _classify_endpoint(base, kind) == "local": + local_eps.append((ep.id, base, ep.api_key)) finally: db.close() - async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]: + grouped: Dict[str, Dict[str, Any]] = {} + for ep_id, base, api_key in local_eps: + key = _refresh_key(base, api_key) + grouped.setdefault(key, {"base": base, "api_key": api_key, "endpoint_ids": []})["endpoint_ids"].append(ep_id) + + async def _probe_one(data: Dict[str, Any]) -> Dict[str, Any]: t0 = _time.time() try: - models = _probe_endpoint(base, api_key, timeout=2.5) + import asyncio as _asyncio + ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5) lat = round((_time.time() - t0) * 1000) return { - "alive": bool(models), + "alive": bool(ping.get("reachable")), "latency_ms": lat, - "status_code": 200 if models else None, - "error": None if models else "No models found", + "status_code": ping.get("status_code"), + "error": ping.get("error"), } except Exception as e: return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]} import asyncio as _asyncio results_list = await _asyncio.gather( - *[_probe_one(eid, base, key) for eid, base, key in local_eps], + *[_probe_one(data) for data in grouped.values()], return_exceptions=False, ) results: Dict[str, Any] = {} - for (eid, _, _), r in zip(local_eps, results_list): - results[eid] = r + for data, r in zip(grouped.values(), results_list): + for eid in data["endpoint_ids"]: + results[eid] = r _local_probe_cache["data"] = results _local_probe_cache["time"] = now @@ -973,50 +1199,28 @@ def setup_model_routes(model_discovery): for ep in endpoints: base = _normalize_base(ep.base_url) provider = _detect_provider(base) + kind = _effective_endpoint_kind(ep, base) + cached_count = len(_cached_model_ids(ep)) entry = { "id": ep.id, "name": ep.name, "base_url": base, "provider": provider, - "category": _classify_endpoint(base), + "category": _classify_endpoint(base, kind), + "endpoint_kind": kind, } - if provider == "anthropic": - # Anthropic has no /models endpoint; just check connectivity - try: - t0 = _time.time() - r = httpx.get(base.rstrip("/"), timeout=5) - entry["latency_ms"] = round((_time.time() - t0) * 1000) - entry["status"] = "online" - entry["model_count"] = len(ANTHROPIC_MODELS) - except Exception as e: - entry["latency_ms"] = None - entry["status"] = "offline" - entry["error"] = str(e) - entry["model_count"] = 0 - else: - url = build_models_url(base) - headers = build_headers(ep.api_key, base) - try: - t0 = _time.time() - r = httpx.get(url, headers=headers, timeout=5) - entry["latency_ms"] = round((_time.time() - t0) * 1000) - r.raise_for_status() - data = r.json() - models = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not models: - models = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] - entry["status"] = "online" - entry["model_count"] = len(models) - except Exception as e: - if "latency_ms" not in entry: - entry["latency_ms"] = None - entry["status"] = "offline" - entry["error"] = str(e) - entry["model_count"] = 0 + try: + t0 = _time.time() + ping = _ping_endpoint(base, ep.api_key, timeout=1.5) + entry["latency_ms"] = round((_time.time() - t0) * 1000) + entry["status"] = "online" if ping.get("reachable") or cached_count else "offline" + entry["error"] = ping.get("error") + entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0) + except Exception as e: + entry["latency_ms"] = None + entry["status"] = "online" if cached_count else "offline" + entry["error"] = str(e) + entry["model_count"] = cached_count results.append(entry) return {"endpoints": results} @@ -1165,19 +1369,8 @@ def setup_model_routes(model_discovery): rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all() results = [] for r in rows: - # Use cached model list to avoid slow probe on every load - all_models = [] - if r.cached_models: - try: - all_models = json.loads(r.cached_models) - except Exception: - pass - hidden = set() - if r.hidden_models: - try: - hidden = set(json.loads(r.hidden_models)) - except Exception: - pass + all_models = _cached_model_ids(r) + hidden = _hidden_model_ids(r) pinned = _normalize_model_ids(getattr(r, "pinned_models", None)) visible = _visible_models(all_models, r.hidden_models, pinned) # Endpoint counts as reachable if it has any model — including @@ -1188,6 +1381,8 @@ def setup_model_routes(model_discovery): ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) if ping.get("reachable"): status = "empty" + base = _normalize_base(r.base_url) + kind = _effective_endpoint_kind(r, base) results.append({ "id": r.id, "name": r.name, @@ -1202,6 +1397,11 @@ def setup_model_routes(model_discovery): "ping_error": (ping or {}).get("error") if ping else None, "model_type": getattr(r, "model_type", None) or "llm", "supports_tools": getattr(r, "supports_tools", None), + "endpoint_kind": kind, + "category": _classify_endpoint(base, kind), + "model_refresh_mode": _endpoint_refresh_mode(r, kind), + "model_refresh_interval": getattr(r, "model_refresh_interval", None), + "model_refresh_timeout": getattr(r, "model_refresh_timeout", None), }) return results finally: @@ -1216,6 +1416,10 @@ def setup_model_routes(model_discovery): skip_probe: str = Form("false"), require_models: str = Form("false"), model_type: str = Form("llm"), + endpoint_kind: str = Form("auto"), + model_refresh_mode: str = Form(""), + model_refresh_interval: str = Form(""), + model_refresh_timeout: str = Form(""), supports_tools: str = Form(""), # "true"/"false"/"" (unknown) pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline container_local: str = Form("false"), @@ -1240,8 +1444,15 @@ def setup_model_routes(model_discovery): if not name.strip(): name = base_url.replace("http://", "").replace("https://", "").split("/")[0] + requested_kind = _normalize_endpoint_kind(endpoint_kind) + refresh_mode = _normalize_refresh_mode(model_refresh_mode, requested_kind) + refresh_interval = _parse_positive_int(model_refresh_interval, minimum=30, maximum=86400) + refresh_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60) require_model_list = _truthy(require_models) - should_probe = require_model_list or not _truthy(skip_probe) + should_probe = ( + require_model_list or requested_kind in ("api", "proxy") or not _truthy(skip_probe) + ) + explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout) # Dedupe: if an endpoint with the same base_url already exists and # is reachable by the caller (shared or owned by them), return it @@ -1259,6 +1470,7 @@ def setup_model_routes(model_discovery): .first() ) if existing: + changed = False # Persist any incoming pinned IDs onto the existing row. An # empty/omitted form field must not wipe previously pinned IDs. _incoming_pinned = _normalize_model_ids(pinned_models) @@ -1268,15 +1480,45 @@ def setup_model_routes(model_discovery): _incoming_pinned, ) existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None + changed = True + existing_kind_for_probe = requested_kind if requested_kind != "auto" else _effective_endpoint_kind(existing, base_url) + if requested_kind != "auto" and _endpoint_kind(existing) == "auto": + existing.endpoint_kind = requested_kind + changed = True + if model_refresh_mode or (requested_kind == "proxy" and _endpoint_refresh_mode(existing, requested_kind) != refresh_mode): + existing.model_refresh_mode = refresh_mode + changed = True + if refresh_interval is not None: + existing.model_refresh_interval = refresh_interval + changed = True + if refresh_timeout is not None: + existing.model_refresh_timeout = refresh_timeout + changed = True + if api_key.strip() and not existing.api_key: + existing.api_key = api_key.strip() + changed = True + if should_probe: + probed_models = _probe_endpoint( + base_url, + (api_key.strip() or existing.api_key or None), + timeout=_explicit_model_list_timeout(base_url, existing_kind_for_probe, refresh_timeout), + ) + if probed_models: + existing.cached_models = json.dumps(probed_models) + changed = True + if changed: _db_dedup.commit() _invalidate_models_cache() + _local_probe_cache["data"] = None + existing_models = _cached_model_ids(existing) _existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None)) + existing_kind = _effective_endpoint_kind(existing, existing.base_url) return { "id": existing.id, "name": existing.name, "base_url": existing.base_url, "models": _visible_models( - getattr(existing, "cached_models", None), + existing_models, getattr(existing, "hidden_models", None), existing.pinned_models, ), @@ -1284,16 +1526,16 @@ def setup_model_routes(model_discovery): "online": True, "status": "online", "existing": True, + "endpoint_kind": existing_kind, + "category": _classify_endpoint(existing.base_url, existing_kind), } finally: _db_dedup.close() - # Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh) - _probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1 - model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else [] + model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=explicit_timeout) if should_probe else [] ping = {"reachable": False, "error": None} - if should_probe and not model_ids: - ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) + if (should_probe or requested_kind in ("api", "proxy")) and not model_ids: + ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=min(explicit_timeout, 2.0)) if require_model_list and not model_ids: raise HTTPException(400, _model_endpoint_error_message(base_url, ping)) @@ -1317,6 +1559,10 @@ def setup_model_routes(model_discovery): api_key=api_key.strip() or None, is_enabled=True, model_type=model_type.strip() if model_type else "llm", + endpoint_kind=requested_kind, + model_refresh_mode=refresh_mode, + model_refresh_interval=refresh_interval, + model_refresh_timeout=refresh_timeout, cached_models=json.dumps(model_ids) if model_ids else None, pinned_models=json.dumps(_pinned) if _pinned else None, supports_tools=_st, @@ -1349,6 +1595,8 @@ def setup_model_routes(model_discovery): "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")), "status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"), "ping_error": ping.get("error") if ping else None, + "endpoint_kind": requested_kind, + "category": _classify_endpoint(base_url, requested_kind), } @router.post("/model-endpoints/test") @@ -1356,6 +1604,8 @@ def setup_model_routes(model_discovery): request: Request, base_url: str = Form(...), api_key: str = Form(""), + endpoint_kind: str = Form("auto"), + model_refresh_timeout: str = Form(""), ): require_admin(request) base_url = _normalize_base(base_url) @@ -1364,9 +1614,11 @@ def setup_model_routes(model_discovery): from src.endpoint_resolver import resolve_url base_url = resolve_url(base_url) base_url = _rewrite_loopback_for_docker(base_url) - probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2 + requested_kind = _normalize_endpoint_kind(endpoint_kind) + configured_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60) + probe_timeout = _explicit_model_list_timeout(base_url, requested_kind, configured_timeout) models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout) - ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout) + ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=min(probe_timeout, 2.0)) return { "base_url": base_url, "online": bool(models) or bool(ping.get("reachable")), @@ -1374,6 +1626,8 @@ def setup_model_routes(model_discovery): "ping_error": ping.get("error") if ping else None, "models": models, "count": len(models), + "endpoint_kind": requested_kind, + "category": _classify_endpoint(base_url, requested_kind), } @router.get("/model-endpoints/{ep_id}/probe") @@ -1415,7 +1669,8 @@ def setup_model_routes(model_discovery): ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() if ep_obj: ep_obj.hidden_models = json.dumps(failed) if failed else None - ep_obj.cached_models = json.dumps(all_models) if all_models else None + if all_models: + ep_obj.cached_models = json.dumps(all_models) db2.commit() finally: db2.close() @@ -1426,7 +1681,13 @@ def setup_model_routes(model_discovery): return StreamingResponse(_stream(), media_type="text/event-stream") @router.get("/model-endpoints/{ep_id}/models") - def list_endpoint_models(ep_id: str, request: Request): + def list_endpoint_models( + ep_id: str, + request: Request, + response: Response, + refresh: bool = False, + refresh_timeout: Optional[int] = Query(None, ge=1, le=60), + ): """List all discovered models for an endpoint with hidden/visible state.""" require_admin(request) db = SessionLocal() @@ -1434,23 +1695,28 @@ def setup_model_routes(model_discovery): ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() if not ep: raise HTTPException(404, "Endpoint not found") - hidden = set() - if ep.hidden_models: + hidden = _hidden_model_ids(ep) + all_models = _cached_model_ids(ep) + if refresh: + base = _normalize_base(ep.base_url) + kind = _effective_endpoint_kind(ep, base) + category = _classify_endpoint(base, kind) + timeout = _manual_refresh_timeout(ep, category, refresh_timeout) try: - hidden = set(json.loads(ep.hidden_models)) - except Exception: - pass - # Try live probe, fall back to cached. Pinned IDs are admin-entered - # and persist regardless of probe results — never overwritten here. - all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) - if all_models: - ep.cached_models = json.dumps(all_models) - db.commit() - elif ep.cached_models: - try: - all_models = json.loads(ep.cached_models) - except Exception: - pass + probed = _probe_endpoint(base, ep.api_key, timeout=timeout) + except Exception as exc: + logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc) + probed = [] + if probed: + all_models = probed + ep.cached_models = json.dumps(all_models) + db.commit() + _invalidate_models_cache() + response.headers["X-Model-Refresh-Status"] = "refreshed" + response.headers["X-Model-Refresh-Count"] = str(len(probed)) + else: + response.headers["X-Model-Refresh-Status"] = "failed" + response.headers["X-Model-Refresh-Warning"] = "Model refresh failed or returned no models; kept cached models." pinned = _normalize_model_ids(getattr(ep, "pinned_models", None)) pinned_set = set(pinned) return [ @@ -1502,7 +1768,6 @@ def setup_model_routes(model_discovery): @router.get("/default-chat") def get_default_chat(request: Request): - import json as _json # SECURITY: resolve the default endpoint + model from the CALLER's # per-user prefs ONLY. We deliberately do NOT fall back to the # global `default_model` / `default_endpoint_id` in settings.json @@ -1635,6 +1900,16 @@ def setup_model_routes(model_discovery): if "pinned_models" in body: _pinned = _normalize_model_ids(body["pinned_models"]) ep.pinned_models = json.dumps(_pinned) if _pinned else None + if "endpoint_kind" in body: + ep.endpoint_kind = _normalize_endpoint_kind(body.get("endpoint_kind")) + if "model_refresh_mode" in body: + ep.model_refresh_mode = _normalize_refresh_mode(body.get("model_refresh_mode"), _endpoint_kind(ep)) + if "model_refresh_interval" in body: + interval = _parse_positive_int(body.get("model_refresh_interval"), minimum=30, maximum=86400) + ep.model_refresh_interval = interval + if "model_refresh_timeout" in body: + timeout = _parse_positive_int(body.get("model_refresh_timeout"), minimum=1, maximum=60) + ep.model_refresh_timeout = timeout # Rotating an API key used to require DELETE+POST, which wiped # endpoint_url/model from every session referencing the old base # URL. Allow in-place updates so the admin can change the key @@ -1664,6 +1939,10 @@ def setup_model_routes(model_discovery): "model_type": ep.model_type, "base_url": ep.base_url, "pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)), + "endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto", + "model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto", + "model_refresh_interval": getattr(ep, "model_refresh_interval", None), + "model_refresh_timeout": getattr(ep, "model_refresh_timeout", None), } finally: db.close() diff --git a/src/llm_core.py b/src/llm_core.py index eb23057..2d66685 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -743,8 +743,74 @@ def _normalize_anthropic_url(url: str) -> str: return url + "/messages" return url + "/v1/messages" + +def _model_list_base(url: str) -> str: + """Normalize model/chat URLs to the configured endpoint base.""" + base = (url or "").strip().rstrip("/") + for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"): + if base.endswith(suffix): + base = base[: -len(suffix)].rstrip("/") + for suffix in ("/chat", "/tags", "/generate"): + if base.endswith("/api" + suffix): + base = base[: -len(suffix)].rstrip("/") + return base + + +def _parse_model_cache(raw) -> List[str]: + if not raw: + return [] + try: + models = json.loads(raw) if isinstance(raw, str) else raw + except Exception: + return [] + if not isinstance(models, list): + return [] + out = [] + seen = set() + for item in models: + mid = str(item or "").strip() + if not mid or mid in seen: + continue + out.append(mid) + seen.add(mid) + return out + + +def _configured_cached_model_ids(endpoint_url: str) -> List[str]: + """Return cached models for a configured endpoint matching endpoint_url.""" + target = _model_list_base(endpoint_url) + if not target: + return [] + try: + from src.database import SessionLocal, ModelEndpoint + except Exception: + return [] + db = SessionLocal() + try: + rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + for ep in rows: + if _model_list_base(getattr(ep, "base_url", "")) != target: + continue + models = _parse_model_cache(getattr(ep, "cached_models", None) or getattr(ep, "models", None)) + if not models: + continue + hidden = set(_parse_model_cache(getattr(ep, "hidden_models", None))) + return [m for m in models if m not in hidden] + except Exception: + return [] + finally: + try: + db.close() + except Exception: + pass + return [] + + def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]: """List available model IDs from an endpoint.""" + cached = _configured_cached_model_ids(base_chat_url) + if cached: + return cached provider = _detect_provider(base_chat_url) if provider == "anthropic": return list(ANTHROPIC_MODELS) diff --git a/src/model_context.py b/src/model_context.py index 6fdd23e..c985d3d 100644 --- a/src/model_context.py +++ b/src/model_context.py @@ -6,6 +6,7 @@ Provides token estimation for context usage tracking. """ import logging +import sys from typing import Dict, List, Optional from urllib.parse import urlparse @@ -21,8 +22,55 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.", "172.30.", "172.31.", "192.168.", "100.") +def _normalize_base_for_compare(url: str) -> str: + url = (url or "").strip().rstrip("/") + for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"): + if url.endswith(suffix): + url = url[: -len(suffix)].rstrip("/") + return url + + +def _configured_endpoint_kind(url: str) -> Optional[str]: + """Return configured endpoint kind for a chat/base URL when available.""" + target = _normalize_base_for_compare(url) + if not target: + return None + if "core.database" not in sys.modules: + return None + try: + from core.database import SessionLocal, ModelEndpoint + db = SessionLocal() + try: + rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + for ep in rows: + base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "") + if not base: + continue + if target != base and not target.startswith(base + "/"): + continue + kind = (getattr(ep, "endpoint_kind", None) or "auto").strip().lower() + if kind in ("local", "api", "proxy"): + return kind + if getattr(ep, "api_key", None): + parsed = urlparse(base) + host = (parsed.hostname or "").lower() + path = (parsed.path or "").rstrip("/") + if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path): + return "proxy" + return "auto" + finally: + db.close() + except Exception: + return None + + def _is_local_endpoint(url: str) -> bool: """Check if URL points to a local/private/tailscale address.""" + kind = _configured_endpoint_kind(url) + if kind in ("api", "proxy"): + return False + if kind == "local": + return True try: host = urlparse(url).hostname or "" return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) @@ -170,6 +218,7 @@ def get_context_length(endpoint_url: str, model: str) -> int: or context_window fields. Caches result per model ID. Falls back to DEFAULT_CONTEXT if unavailable. """ + configured_kind = _configured_endpoint_kind(endpoint_url) is_local = _is_local_endpoint(endpoint_url) if not is_local and model in _context_cache: return _context_cache[model] @@ -178,7 +227,7 @@ def get_context_length(endpoint_url: str, model: str) -> int: # Only cache non-default values to allow retry on next request. # Local endpoints can restart with a different --max-model-len while keeping # the same model id, so always re-query them instead of serving stale cache. - if not is_local and ctx != DEFAULT_CONTEXT: + if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")): _context_cache[model] = ctx logger.info(f"Context length for {model}: {ctx}") return ctx @@ -207,6 +256,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int: """Query the model API for context length.""" known = _lookup_known(model) api_ctx = None + configured_kind = _configured_endpoint_kind(endpoint_url) + + # Large OpenAI-compatible proxies can make /models expensive. If the + # endpoint is explicitly configured as API/proxy, prefer known context + # metadata (or the default) over downloading the full catalog. + if configured_kind in ("api", "proxy"): + if known: + logger.info(f"Using known context window for {model}: {known}") + return known + return DEFAULT_CONTEXT # Try llama.cpp /slots endpoint first — reports actual serving context if _is_local_endpoint(endpoint_url): diff --git a/static/index.html b/static/index.html index baa471f..72544de 100644 --- a/static/index.html +++ b/static/index.html @@ -2079,6 +2079,10 @@