Improve Ollama setup and model endpoint handling
This commit is contained in:
@@ -482,7 +482,10 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
}
|
||||
return {"ok": False, "message": f"ntfy returned HTTP {r.status_code} from {full_url}: {r.text[:200]}"}
|
||||
except Exception as e:
|
||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}"[:300]}
|
||||
hint = ""
|
||||
if parsed.hostname not in ("127.0.0.1", "localhost"):
|
||||
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
||||
|
||||
# All other presets: GET against a known health endpoint.
|
||||
# Fall back to detecting from name if preset is missing.
|
||||
|
||||
@@ -902,7 +902,8 @@ def setup_calendar_routes() -> APIRouter:
|
||||
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}")
|
||||
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}")
|
||||
if ev.description:
|
||||
lines.append(f"DESCRIPTION:{ev.description.replace(chr(10), '\\n')}")
|
||||
escaped_desc = ev.description.replace(chr(10), "\\n")
|
||||
lines.append(f"DESCRIPTION:{escaped_desc}")
|
||||
if ev.location:
|
||||
lines.append(f"LOCATION:{ev.location}")
|
||||
if ev.rrule:
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, AsyncGenerator, List
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Form, Query
|
||||
@@ -17,6 +18,7 @@ from src.agent_loop import stream_agent_loop
|
||||
from src import agent_runs
|
||||
from src.model_context import estimate_tokens
|
||||
from src.chat_helpers import coerce_message_and_session
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
@@ -87,6 +89,46 @@ def _message_needs_tools(text: str) -> bool:
|
||||
return any(p.search(text) for p in _TOOL_INTENT_PATTERNS)
|
||||
|
||||
|
||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
if not session_url or not endpoint_base:
|
||||
return False
|
||||
sess = session_url.rstrip("/")
|
||||
base = _normalize_base(endpoint_base).rstrip("/")
|
||||
variants = {
|
||||
base,
|
||||
base + "/chat/completions",
|
||||
build_chat_url(base).rstrip("/"),
|
||||
}
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
|
||||
def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
||||
if not getattr(sess, "endpoint_url", ""):
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
||||
return False
|
||||
db_session = db.query(DBSession).filter(DBSession.id == sess.id).first()
|
||||
if db_session:
|
||||
db_session.endpoint_url = ""
|
||||
db_session.model = ""
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.endpoint_url = ""
|
||||
sess.model = ""
|
||||
sess.headers = {}
|
||||
return True
|
||||
except Exception:
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def setup_chat_routes(
|
||||
session_manager,
|
||||
chat_handler,
|
||||
@@ -121,6 +163,8 @@ def setup_chat_routes(
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session}' not found")
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
|
||||
# Same allowed_models + daily-cap gate as chat_stream (mirror so the
|
||||
# non-streaming path can't be used to bypass).
|
||||
@@ -259,6 +303,8 @@ def setup_chat_routes(
|
||||
# but BEFORE loading. Prevents cross-user session hijack.
|
||||
_verify_session_owner(request, session)
|
||||
sess = session_manager.get_session(session)
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
except SessionNotFoundError as e:
|
||||
raise HTTPException(404, str(e))
|
||||
except (ValueError, ValidationError):
|
||||
|
||||
@@ -6,12 +6,13 @@ import json
|
||||
import time as _time
|
||||
import logging
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
|
||||
from core.middleware import require_admin
|
||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||
from src.settings import load_settings as _load_settings, save_settings as _save_settings
|
||||
@@ -301,6 +302,21 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
logger.warning(f"Failed to probe {url} with API key: {e}")
|
||||
return []
|
||||
logger.warning(f"Failed to probe {url}: {e}")
|
||||
|
||||
# Older Ollama builds and some proxies expose native /api/tags even when
|
||||
# the OpenAI-compatible /v1/models path is unavailable.
|
||||
try:
|
||||
parsed = urlparse(base)
|
||||
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
|
||||
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
|
||||
r = httpx.get(root + "/api/tags", timeout=timeout)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
|
||||
if models:
|
||||
return models
|
||||
except Exception as e:
|
||||
logger.debug(f"Ollama /api/tags probe failed for {base}: {e}")
|
||||
# Fall back to curated list if the provider has a URL-based match (e.g. z.ai has no /models endpoint)
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||
@@ -310,6 +326,51 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return []
|
||||
|
||||
|
||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
url = base + "/models"
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout)
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
if loc.startswith("/login") or "/login" in loc:
|
||||
return {
|
||||
"reachable": False,
|
||||
"status_code": r.status_code,
|
||||
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
|
||||
if r.status_code < 500:
|
||||
return {"reachable": r.status_code < 400, "status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
else:
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
|
||||
try:
|
||||
parsed = urlparse(base)
|
||||
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
|
||||
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
|
||||
for path in ("/api/version", "/api/tags"):
|
||||
try:
|
||||
r = httpx.get(root + path, timeout=timeout)
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
def setup_model_routes(model_discovery):
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
@@ -549,15 +610,16 @@ def setup_model_routes(model_discovery):
|
||||
db.close()
|
||||
|
||||
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]:
|
||||
url = base.rstrip("/") + "/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
t0 = _time.time()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1.5) as client:
|
||||
r = await client.get(url, headers=headers)
|
||||
models = _probe_endpoint(base, api_key, timeout=2.5)
|
||||
lat = round((_time.time() - t0) * 1000)
|
||||
return {"alive": r.status_code < 400, "latency_ms": lat,
|
||||
"status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
|
||||
return {
|
||||
"alive": bool(models),
|
||||
"latency_ms": lat,
|
||||
"status_code": 200 if models else None,
|
||||
"error": None if models else "No models found",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
|
||||
|
||||
@@ -789,6 +851,12 @@ def setup_model_routes(model_discovery):
|
||||
except Exception:
|
||||
pass
|
||||
visible = [m for m in all_models if m not in hidden]
|
||||
status = "online" if all_models else "offline"
|
||||
ping = None
|
||||
if not all_models and r.is_enabled:
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
results.append({
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
@@ -797,7 +865,9 @@ def setup_model_routes(model_discovery):
|
||||
"is_enabled": r.is_enabled,
|
||||
"models": visible,
|
||||
"hidden_count": len(hidden),
|
||||
"online": len(all_models) > 0,
|
||||
"online": status != "offline",
|
||||
"status": status,
|
||||
"ping_error": (ping or {}).get("error") if ping else None,
|
||||
"model_type": getattr(r, "model_type", None) or "llm",
|
||||
"supports_tools": getattr(r, "supports_tools", None),
|
||||
})
|
||||
@@ -840,7 +910,11 @@ def setup_model_routes(model_discovery):
|
||||
should_probe = require_model_list or not _truthy(skip_probe)
|
||||
|
||||
# Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh)
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=1) if should_probe else []
|
||||
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
|
||||
ping = {"reachable": False, "error": None}
|
||||
if should_probe and not model_ids:
|
||||
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout)
|
||||
if require_model_list and not model_ids:
|
||||
raise HTTPException(400, "No models found for that provider/key")
|
||||
|
||||
@@ -876,6 +950,7 @@ def setup_model_routes(model_discovery):
|
||||
settings["default_model"] = model_ids[0] if model_ids else ""
|
||||
_save_settings(settings)
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -883,8 +958,38 @@ def setup_model_routes(model_discovery):
|
||||
return {
|
||||
"id": ep_id,
|
||||
"name": name.strip(),
|
||||
"base_url": base_url,
|
||||
"models": model_ids,
|
||||
"online": len(model_ids) > 0,
|
||||
"online": bool(model_ids) or bool(ping.get("reachable")),
|
||||
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"),
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
}
|
||||
|
||||
@router.post("/model-endpoints/test")
|
||||
def test_model_endpoint(
|
||||
request: Request,
|
||||
base_url: str = Form(...),
|
||||
api_key: str = Form(""),
|
||||
):
|
||||
require_admin(request)
|
||||
base_url = base_url.strip().rstrip("/")
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
if base_url.endswith(suffix):
|
||||
base_url = base_url[:-len(suffix)].rstrip("/")
|
||||
if not base_url:
|
||||
raise HTTPException(400, "Base URL is required")
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base_url = resolve_url(base_url)
|
||||
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
|
||||
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
return {
|
||||
"base_url": base_url,
|
||||
"online": bool(models) or bool(ping.get("reachable")),
|
||||
"status": "online" if models else ("empty" if ping.get("reachable") else "offline"),
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/probe")
|
||||
@@ -1175,6 +1280,49 @@ def setup_model_routes(model_discovery):
|
||||
_save_settings(settings)
|
||||
return cleared
|
||||
|
||||
def _session_uses_endpoint_url(session_url: str, base_url: str) -> bool:
|
||||
if not session_url or not base_url:
|
||||
return False
|
||||
sess = session_url.rstrip("/")
|
||||
base = _normalize_base(base_url).rstrip("/")
|
||||
variants = {
|
||||
base,
|
||||
base + "/chat/completions",
|
||||
build_chat_url(base).rstrip("/"),
|
||||
}
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
def _clear_sessions_for_endpoint(db, base_url: str) -> int:
|
||||
cleared = 0
|
||||
rows = db.query(DbSession).filter(DbSession.endpoint_url.isnot(None)).all()
|
||||
for row in rows:
|
||||
if _session_uses_endpoint_url(row.endpoint_url or "", base_url):
|
||||
row.endpoint_url = ""
|
||||
row.model = ""
|
||||
row.updated_at = datetime.utcnow()
|
||||
cleared += 1
|
||||
return cleared
|
||||
|
||||
def _clear_loaded_sessions_for_endpoint(base_url: str) -> int:
|
||||
try:
|
||||
from src.ai_interaction import get_session_manager
|
||||
manager = get_session_manager()
|
||||
except Exception:
|
||||
manager = None
|
||||
if not manager:
|
||||
return 0
|
||||
cleared = 0
|
||||
try:
|
||||
for sess in list(getattr(manager, "sessions", {}).values()):
|
||||
if _session_uses_endpoint_url(getattr(sess, "endpoint_url", "") or "", base_url):
|
||||
sess.endpoint_url = ""
|
||||
sess.model = ""
|
||||
sess.headers = {}
|
||||
cleared += 1
|
||||
except Exception:
|
||||
return cleared
|
||||
return cleared
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/dependents")
|
||||
def get_endpoint_dependents(ep_id: str, request: Request):
|
||||
"""Check which settings depend on this endpoint."""
|
||||
@@ -1191,10 +1339,18 @@ def setup_model_routes(model_discovery):
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
# Clean up any settings that reference this endpoint
|
||||
cleared = _clear_settings_for_endpoint(ep_id)
|
||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||
db.delete(ep)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
return {"deleted": True, "cleared_settings": cleared}
|
||||
_local_probe_cache["data"] = None
|
||||
return {
|
||||
"deleted": True,
|
||||
"cleared_settings": cleared,
|
||||
"cleared_sessions": cleared_sessions,
|
||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -284,11 +284,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.close()
|
||||
# Switch model/endpoint mid-session
|
||||
if model is not None and endpoint_url is not None:
|
||||
if endpoint_id:
|
||||
from core.database import ModelEndpoint
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(400, "Model endpoint no longer exists")
|
||||
finally:
|
||||
_db.close()
|
||||
session.model = model
|
||||
session.endpoint_url = endpoint_url
|
||||
# Update auth headers from the endpoint's stored API key
|
||||
if endpoint_id:
|
||||
from core.database import ModelEndpoint
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
|
||||
@@ -4,8 +4,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pty
|
||||
import fcntl
|
||||
import shlex
|
||||
import shutil
|
||||
import uuid
|
||||
@@ -13,6 +11,13 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
import pty
|
||||
except ImportError:
|
||||
fcntl = None
|
||||
pty = None
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -97,6 +102,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
||||
|
||||
async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
"""Run command in a pseudo-TTY so tqdm/progress bars work natively."""
|
||||
if pty is None or fcntl is None:
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'PTY streaming is not available on Windows'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user