diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py index ce2e0cf..7e7a764 100644 --- a/routes/chat_helpers.py +++ b/routes/chat_helpers.py @@ -188,7 +188,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None: Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None. """ import requests as _req - from src.endpoint_resolver import build_chat_url, build_headers, normalize_base + from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base current_url = sess.endpoint_url or "" db = SessionLocal() @@ -205,15 +205,19 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None: if current_url and base in current_url: continue # Quick ping - ping_url = base + "/models" - headers = {} - if ep.api_key: - headers["Authorization"] = f"Bearer {ep.api_key}" + ping_url = build_models_url(base) + headers = build_headers(ep.api_key, base) try: r = _req.get(ping_url, headers=headers, timeout=5) r.raise_for_status() data = r.json() models = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not models: + models = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] if not models: continue # Found a working endpoint — update session diff --git a/routes/compare_routes.py b/routes/compare_routes.py index 18b2165..2d06e95 100644 --- a/routes/compare_routes.py +++ b/routes/compare_routes.py @@ -62,14 +62,16 @@ def setup_compare_routes(session_manager: SessionManager): db = SessionLocal() try: from core.database import ModelEndpoint + from src.endpoint_resolver import build_headers, normalize_base # Find matching endpoint by URL + base = normalize_base(endpoint) ep = db.query(ModelEndpoint).filter( - ModelEndpoint.base_url == endpoint.replace('/chat/completions', '') + ModelEndpoint.base_url == base ).first() if ep and ep.api_key: s = session_manager.sessions.get(sid) if s: - s.headers = {"Authorization": f"Bearer {ep.api_key}"} + s.headers = build_headers(ep.api_key, ep.base_url) finally: db.close() diff --git a/routes/model_routes.py b/routes/model_routes.py index bd209db..3f4f2f1 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -16,12 +16,60 @@ from core.database import SessionLocal, ModelEndpoint, Session as DbSession from core.middleware import require_admin from src.llm_core import _detect_provider, ANTHROPIC_MODELS from src.settings import load_settings as _load_settings, save_settings as _save_settings -from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, _anthropic_api_root +from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url from src.auth_helpers import owner_filter logger = logging.getLogger(__name__) +def _anthropic_api_root(base: str) -> str: + """Return Anthropic's API root without duplicating /v1.""" + base = (base or "").strip().rstrip("/") + host = urlparse(base).hostname or "" + if host.endswith("anthropic.com") and base.endswith("/v1"): + return base[:-3].rstrip("/") + return base + + +def _ollama_api_root(base: str) -> str: + """Return Ollama's native API root without depending on deferred imports.""" + base = (base or "").strip().rstrip("/") + parsed = urlparse(base) + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if path.endswith("/api"): + return base + if host.endswith("ollama.com"): + root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com" + return root.rstrip("/") + "/api" + return base + + +def _models_url(base: str) -> str: + """Return provider-specific model-list URL for route-local probing.""" + provider = _detect_provider(base) + host = urlparse(base).hostname or "" + if provider == "anthropic" or host.endswith("anthropic.com"): + return _anthropic_api_root(base) + "/v1/models" + if provider == "ollama" or host.endswith("ollama.com"): + return _ollama_api_root(base) + "/tags" + return base.rstrip("/") + "/models" + + +def _provider_headers(api_key: Optional[str], base: str) -> Dict[str, str]: + """Build provider auth headers without depending on import-time stubs.""" + if not api_key: + return {} + provider = _detect_provider(base) + host = urlparse(base).hostname or "" + if provider == "anthropic" or host.endswith("anthropic.com"): + return { + "x-api-key": api_key, + "anthropic-version": "2023-06-01", + } + return {"Authorization": f"Bearer {api_key}"} + + # ── Curated model lists per provider ── # For cloud providers that return 100+ models, only show these by default. # A model ID matches if it starts with or equals a curated entry. @@ -87,6 +135,7 @@ _URL_TO_CURATED = { "generativelanguage.googleapis.com": "google", "api.x.ai": "xai", "openrouter.ai": "openrouter", + "ollama.com": "ollama", } @@ -183,9 +232,15 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1 payload = _build_anthropic_payload(model_id, messages, 0.0, 5) if _test_tools: payload["tools"] = [{"name": "test", "description": "Test tool", "input_schema": {"type": "object", "properties": {}}}] + elif provider == "ollama": + from src.llm_core import _build_ollama_payload + target_url = build_chat_url(base) + h = _provider_headers(api_key, base) + h["Content-Type"] = "application/json" + payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools) else: target_url = build_chat_url(base) - h = build_headers(api_key, base) + h = _provider_headers(api_key, base) h["Content-Type"] = "application/json" from src.llm_core import _uses_max_completion_tokens _max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens" @@ -276,10 +331,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis return [] logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}") return list(ANTHROPIC_MODELS) - url = base + "/models" - headers = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" + url = _models_url(base) + headers = _provider_headers(api_key, base) try: r = httpx.get(url, headers=headers, timeout=timeout) r.raise_for_status() @@ -494,10 +547,7 @@ def setup_model_routes(model_discovery): pass model_ids = [m for m in model_ids if m not in hidden] # Build correct URL based on provider - if provider == "anthropic": - chat_url = build_chat_url(base) - else: - chat_url = base + "/chat/completions" + chat_url = build_chat_url(base) category = _classify_endpoint(base) if model_ids: @@ -671,10 +721,8 @@ def setup_model_routes(model_discovery): entry["error"] = str(e) entry["model_count"] = 0 else: - url = base + "/models" - headers = {} - if ep.api_key: - headers["Authorization"] = f"Bearer {ep.api_key}" + url = _models_url(base) + headers = _provider_headers(ep.api_key, base) try: t0 = _time.time() r = httpx.get(url, headers=headers, timeout=5) @@ -682,6 +730,12 @@ def setup_model_routes(model_discovery): r.raise_for_status() data = r.json() models = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not models: + models = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] entry["status"] = "online" entry["model_count"] = len(models) except Exception as e: @@ -896,6 +950,7 @@ def setup_model_routes(model_discovery): for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: if base_url.endswith(suffix): base_url = base_url[:-len(suffix)].rstrip("/") + base_url = _normalize_base(base_url) if not base_url: raise HTTPException(400, "Base URL is required") # Resolve hostname via Tailscale if DNS fails diff --git a/routes/session_routes.py b/routes/session_routes.py index 18e0b18..7dd875e 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -227,6 +227,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ ) # Set auth headers for custom API-key endpoints resolved_key = api_key.strip() if api_key else "" + resolved_base = endpoint_url if not resolved_key and endpoint_id and endpoint_id.strip(): from core.database import ModelEndpoint _db = SessionLocal() @@ -234,10 +235,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first() if ep and ep.api_key: resolved_key = ep.api_key + resolved_base = ep.base_url finally: _db.close() if resolved_key: - session.headers = {"Authorization": f"Bearer {resolved_key}"} + from src.endpoint_resolver import build_headers + session.headers = build_headers(resolved_key, resolved_base) session_manager.save_sessions() # Fire webhook (sync-safe) if webhook_manager: diff --git a/routes/webhook_routes.py b/routes/webhook_routes.py index 8fc88fe..7eead00 100644 --- a/routes/webhook_routes.py +++ b/routes/webhook_routes.py @@ -157,6 +157,7 @@ def setup_webhook_routes( "groq": "https://api.groq.com/openai/v1", "together": "https://api.together.xyz/v1", "openrouter": "https://openrouter.ai/api/v1", + "ollama": "https://ollama.com/api", "fireworks": "https://api.fireworks.ai/inference/v1", } @@ -203,6 +204,7 @@ def setup_webhook_routes( from core.models import ChatMessage from src.llm_core import llm_call_async from core.database import ModelEndpoint + from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base message = body.message.strip() if not message: @@ -244,7 +246,8 @@ def setup_webhook_routes( "Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') " "or provider ('deepseek', 'openai', 'groq', etc.)") - endpoint_url = base_url + "/chat/completions" + base_url = normalize_base(base_url) + endpoint_url = build_chat_url(base_url) if not session_manager: raise HTTPException(500, "Session manager not available") @@ -254,7 +257,7 @@ def setup_webhook_routes( session_id=sid, name="API Chat", endpoint_url=endpoint_url, model=model, owner=token_owner, ) - sess.headers = {"Authorization": f"Bearer {api_key}"} + sess.headers = build_headers(api_key, base_url) session_manager.save_sessions() session_id = sid @@ -271,18 +274,26 @@ def setup_webhook_routes( "No session, api_key, or configured endpoints. " "Pass api_key + model, or configure an endpoint in Admin.") - endpoint_url = ep.base_url.rstrip("/") + "/chat/completions" + base_url = normalize_base(ep.base_url) + endpoint_url = build_chat_url(base_url) model = body.model or "auto" api_key = ep.api_key if model == "auto": try: async with httpx.AsyncClient(timeout=5) as client: - models_url = ep.base_url.rstrip("/") + "/models" - hdrs = {"Authorization": f"Bearer {api_key}"} if api_key else {} + models_url = build_models_url(base_url) + hdrs = build_headers(api_key, base_url) resp = await client.get(models_url, headers=hdrs) resp.raise_for_status() - ids = [m.get("id") for m in (resp.json().get("data") or []) if m.get("id")] + data = resp.json() + ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not ids: + ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] model = ids[0] if ids else "auto" except Exception: raise HTTPException(500, "Could not discover models from endpoint") @@ -296,7 +307,7 @@ def setup_webhook_routes( model=model, owner=token_owner, ) if api_key: - sess.headers = {"Authorization": f"Bearer {api_key}"} + sess.headers = build_headers(api_key, base_url) session_manager.save_sessions() session_id = sid diff --git a/src/agent_loop.py b/src/agent_loop.py index 2c42e9d..6b7d982 100644 --- a/src/agent_loop.py +++ b/src/agent_loop.py @@ -450,6 +450,7 @@ _API_HOSTS = frozenset([ "api.deepseek.com", "deepseek.com", "api.together.xyz", "api.fireworks.ai", "api.perplexity.ai", "api.x.ai", + "ollama.com", ]) _MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email", "gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"]) diff --git a/src/ai_interaction.py b/src/ai_interaction.py index 2db291a..9063ced 100644 --- a/src/ai_interaction.py +++ b/src/ai_interaction.py @@ -55,7 +55,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None): # Model resolution # --------------------------------------------------------------------------- -from src.endpoint_resolver import normalize_base as _normalize_base +from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url def _resolve_model(spec: str) -> Tuple[str, str, Dict]: @@ -95,9 +95,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]: for ep in endpoints: base = _normalize_base(ep.base_url) provider = _detect_provider(base) - headers = {} - if ep.api_key: - headers["Authorization"] = f"Bearer {ep.api_key}" + headers = build_headers(ep.api_key, base) if provider == "anthropic": # Anthropic: match against hardcoded model list @@ -107,27 +105,32 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]: matched = am break if matched: - headers["x-api-key"] = ep.api_key or "" - headers["anthropic-version"] = "2023-06-01" - return base + "/v1/messages", matched, headers + return build_chat_url(base), matched, headers else: - # OpenAI-compatible: probe /models + # OpenAI-compatible and native Ollama: probe the provider's model list. try: - r = httpx.get(base + "/models", headers=headers, timeout=5) + r = httpx.get(build_models_url(base), headers=headers, timeout=5) r.raise_for_status() - model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")] + data = r.json() + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] except Exception: model_ids = [] # Exact match first for mid in model_ids: if mid.lower() == model_name.lower(): - return base + "/chat/completions", mid, headers + return build_chat_url(base), mid, headers # Partial match for mid in model_ids: if model_name.lower() in mid.lower() or mid.lower() in model_name.lower(): - return base + "/chat/completions", mid, headers + return build_chat_url(base), mid, headers raise ValueError(f"Model '{spec}' not found on any configured endpoint") finally: @@ -1107,18 +1110,23 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict for ep in endpoints: base = _normalize_base(ep.base_url) provider = _detect_provider(base) - headers = {} - if ep.api_key: - headers["Authorization"] = f"Bearer {ep.api_key}" + headers = build_headers(ep.api_key, base) model_ids = [] if provider == "anthropic": model_ids = list(ANTHROPIC_MODELS) else: try: - r = httpx.get(base + "/models", headers=headers, timeout=5) + r = httpx.get(build_models_url(base), headers=headers, timeout=5) r.raise_for_status() - model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")] + data = r.json() + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] except Exception: model_ids = ["(endpoint offline)"] diff --git a/src/endpoint_resolver.py b/src/endpoint_resolver.py index df5eb7c..b204c7c 100644 --- a/src/endpoint_resolver.py +++ b/src/endpoint_resolver.py @@ -101,6 +101,9 @@ def normalize_base(url: str) -> str: for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: if url.endswith(suffix): url = url[: -len(suffix)].rstrip("/") + for suffix in ["/chat", "/tags", "/generate"]: + if url.endswith("/api" + suffix): + url = url[: -len(suffix)].rstrip("/") return url @@ -113,6 +116,20 @@ def _anthropic_api_root(base: str) -> str: return base +def _ollama_api_root(base: str) -> str: + """Return the native Ollama API root, adding /api for ollama.com hosts.""" + base = (base or "").strip().rstrip("/") + parsed = urlparse(base) + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if path.endswith("/api"): + return base + if host.endswith("ollama.com"): + root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com" + return root.rstrip("/") + "/api" + return base + + def build_chat_url(base: str) -> str: """Return the correct chat endpoint URL for a given base.""" base = resolve_url(base) @@ -120,9 +137,23 @@ def build_chat_url(base: str) -> str: host = urlparse(base).hostname or "" if provider == "anthropic" or host.endswith("anthropic.com"): return _anthropic_api_root(base) + "/v1/messages" + if provider == "ollama" or host.endswith("ollama.com"): + return _ollama_api_root(base) + "/chat" return base + "/chat/completions" +def build_models_url(base: str) -> str: + """Return the provider-specific model-list endpoint URL for a base.""" + base = resolve_url(base) + provider = _detect_provider(base) + host = urlparse(base).hostname or "" + if provider == "anthropic" or host.endswith("anthropic.com"): + return _anthropic_api_root(base) + "/v1/models" + if provider == "ollama" or host.endswith("ollama.com"): + return _ollama_api_root(base) + "/tags" + return base + "/models" + + def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]: """Build auth headers for an endpoint.""" provider = _detect_provider(base) diff --git a/src/llm_core.py b/src/llm_core.py index 60b17b2..55af620 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -7,6 +7,7 @@ import logging import hashlib from fastapi import HTTPException from typing import Optional, Dict, List +from urllib.parse import urlparse logger = logging.getLogger(__name__) @@ -140,9 +141,82 @@ ANTHROPIC_MODELS = [ "claude-haiku-4-20250514", "claude-haiku-4", "claude-haiku-3-5-20241022", "claude-haiku-3-5", ] + +def _is_ollama_native_url(url: str) -> bool: + """Return True for native Ollama API URLs, including Ollama Cloud.""" + try: + parsed = urlparse(url or "") + except Exception: + return False + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if host.endswith("ollama.com"): + return True + local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434 + return local_ollama_host and (path == "/api" or path.startswith("/api/")) + + +def _ollama_api_root(url: str) -> str: + """Return a native Ollama API root such as https://ollama.com/api.""" + url = (url or "").strip().rstrip("/") + parsed = urlparse(url) + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if path.endswith("/api/chat"): + return url[: -len("/chat")] + if path.endswith("/api/tags"): + return url[: -len("/tags")] + if path.endswith("/api/generate"): + return url[: -len("/generate")] + if path.endswith("/api"): + return url + if host.endswith("ollama.com"): + root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com" + return root.rstrip("/") + "/api" + return url + + +def _normalize_ollama_url(url: str) -> str: + """Ensure a native Ollama URL points at /api/chat.""" + base = _ollama_api_root(url) + return base.rstrip("/") + "/chat" + + +def _build_ollama_payload( + model: str, + messages: List[Dict], + temperature: float, + max_tokens: int, + stream: bool = False, + tools: Optional[List[Dict]] = None, +) -> Dict: + payload: Dict = { + "model": model, + "messages": messages, + "stream": stream, + } + options: Dict = {} + if temperature is not None: + options["temperature"] = temperature + if max_tokens and max_tokens > 0: + options["num_predict"] = max_tokens + if options: + payload["options"] = options + if tools: + payload["tools"] = tools + return payload + + +def _parse_ollama_response(data: dict) -> str: + message = data.get("message") or {} + return message.get("content") or data.get("response") or "" + + def _detect_provider(url: str) -> str: """Detect API provider from URL.""" u = (url or "").lower() + if _is_ollama_native_url(url): + return "ollama" if "anthropic.com" in u: return "anthropic" if "openrouter.ai" in u: @@ -166,6 +240,7 @@ def _provider_label(url: str) -> str: """Human-friendly provider name for error messages.""" u = (url or "").lower() if "anthropic.com" in u: return "Anthropic" + if "ollama.com" in u: return "Ollama Cloud" if "api.x.ai" in u or "x.ai/" in u: return "xAI" if "openai.com" in u: return "OpenAI" if "openrouter.ai" in u: return "OpenRouter" @@ -396,19 +471,28 @@ def _normalize_anthropic_url(url: str) -> str: def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]: """List available model IDs from an endpoint.""" - if _detect_provider(base_chat_url) == "anthropic": + provider = _detect_provider(base_chat_url) + if provider == "anthropic": return list(ANTHROPIC_MODELS) try: h = {} if headers: h.update(headers) - r = httpx.get(base_chat_url.replace("/chat/completions", "/models"), headers=h, timeout=timeout) + if provider == "ollama": + models_url = _ollama_api_root(base_chat_url) + "/tags" + else: + models_url = base_chat_url.replace("/chat/completions", "/models") + r = httpx.get(models_url, headers=h, timeout=timeout) r.raise_for_status() data = r.json() - ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if ids: - return ids - return [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")] + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + return model_ids except Exception: try: if ":11434" in base_chat_url or "ollama" in base_chat_url.lower(): @@ -476,6 +560,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL target_url = _normalize_anthropic_url(url) h = _build_anthropic_headers(headers) payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens) + elif provider == "ollama": + target_url = _normalize_ollama_url(url) + payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False) else: target_url = url payload = { @@ -497,6 +584,8 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL try: if provider == "anthropic": response = _parse_anthropic_response(data) + elif provider == "ollama": + response = _parse_ollama_response(data) else: response = data["choices"][0]["message"]["content"] _set_cached_response(cache_key, response) @@ -583,6 +672,12 @@ async def llm_call_async( target_url = _normalize_anthropic_url(url) h = _build_anthropic_headers(headers) payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens) + elif provider == "ollama": + target_url = _normalize_ollama_url(url) + h = {"Content-Type": "application/json"} + if headers: + h.update(headers) + payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False) else: target_url = url h = _provider_headers(provider, headers) @@ -621,6 +716,8 @@ async def llm_call_async( try: if provider == "anthropic": response = _parse_anthropic_response(data) + elif provider == "ollama": + response = _parse_ollama_response(data) else: response = data["choices"][0]["message"]["content"] _set_cached_response(cache_key, response) @@ -673,6 +770,12 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl target_url = _normalize_anthropic_url(url) h = _build_anthropic_headers(headers) payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools) + elif provider == "ollama": + target_url = _normalize_ollama_url(url) + h = {"Content-Type": "application/json"} + if headers: + h.update(headers) + payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools) else: target_url = url payload = { @@ -699,6 +802,62 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl return note_model_activity(target_url, model) + # ── Native Ollama streaming ── + if provider == "ollama": + _ollama_tool_calls: List[Dict] = [] + try: + client = _get_http_client() + async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r: + _clear_host_dead(target_url) + if r.status_code != 200: + raw = (await r.aread()).decode(errors="replace") + friendly = _format_upstream_error(r.status_code, raw, target_url) + yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n' + return + async for line in r.aiter_lines(): + if not line: + continue + try: + j = json.loads(line) + except json.JSONDecodeError: + continue + message = j.get("message") or {} + thinking = message.get("thinking") or "" + if thinking: + yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n' + content = message.get("content") or "" + if content: + yield f'data: {json.dumps({"delta": content})}\n\n' + for tc in message.get("tool_calls") or []: + fn = tc.get("function") or {} + if fn.get("name"): + _ollama_tool_calls.append({ + "id": tc.get("id") or f"call_{len(_ollama_tool_calls)}", + "name": fn.get("name") or "", + "arguments": json.dumps(fn.get("arguments") or {}), + }) + if j.get("done"): + if _ollama_tool_calls: + yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n' + if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None: + yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n' + yield "data: [DONE]\n\n" + return + yield "data: [DONE]\n\n" + except (httpx.ConnectError, httpx.ConnectTimeout) as e: + _cooled = _mark_host_dead(target_url) + _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry" + logger.warning(f"Ollama stream connect to {target_url} failed: {e}{_tail}") + yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n' + except httpx.ReadTimeout: + yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n' + except httpx.NetworkError: + yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n' + except Exception as e: + logger.error(f"Ollama stream error: {e}") + yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n' + return + # ── Anthropic streaming ── if provider == "anthropic": _anth_input_tokens = 0 diff --git a/src/teacher_escalation.py b/src/teacher_escalation.py index c93b709..4587c00 100644 --- a/src/teacher_escalation.py +++ b/src/teacher_escalation.py @@ -42,6 +42,7 @@ _SOTA_HOSTS = frozenset({ "api.together.xyz", "api.fireworks.ai", "api.perplexity.ai", "api.x.ai", "generativelanguage.googleapis.com", "api.groq.com", + "openrouter.ai", "ollama.com", }) diff --git a/static/index.html b/static/index.html index ab1607c..9d44cbb 100644 --- a/static/index.html +++ b/static/index.html @@ -2036,6 +2036,7 @@ DeepSeek OpenAI OpenRouter + Ollama Cloud Groq Mistral Together AI diff --git a/static/js/admin.js b/static/js/admin.js index 10947fb..4d15a4f 100644 --- a/static/js/admin.js +++ b/static/js/admin.js @@ -4,6 +4,7 @@ import uiModule from './ui.js'; import settingsModule from './settings.js'; import { providerLogo } from './providers.js'; +import { sortModelObjects } from './modelSort.js'; let initialized = false; let modalEl = null; @@ -216,7 +217,7 @@ async function _loadModelsForUser(username, allowedSet, privPanel) { return; } const allEmpty = allowedSet.size === 0; - listEl.innerHTML = allModels.map(m => { + listEl.innerHTML = sortModelObjects(allModels).map(m => { const checked = allEmpty || allowedSet.has(m.mid) ? 'checked' : ''; return ` @@ -377,6 +378,9 @@ async function loadEndpoints() { } }, 1500); } + if (settingsModule && typeof settingsModule.refreshAiModelEndpoints === 'function') { + settingsModule.refreshAiModelEndpoints(); + } try { const res = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); // Treat a non-OK response (e.g. 401/403 for non-admins, or backend @@ -552,17 +556,18 @@ async function loadEndpoints() { const res = await fetch(`/api/model-endpoints/${epId}/models`, { credentials: 'same-origin' }); const models = await res.json(); _stopSpin(); - if (!models.length) { panel.innerHTML = 'No models'; return; } - const hiddenSet = new Set(models.filter(m => m.is_hidden).map(m => m.id)); - const showSearch = models.length >= 8; + const sortedModels = sortModelObjects(models); + if (!sortedModels.length) { panel.innerHTML = 'No models'; return; } + const hiddenSet = new Set(sortedModels.filter(m => m.is_hidden).map(m => m.id)); + const showSearch = sortedModels.length >= 8; panel.innerHTML = ` Models - ${models.length - hiddenSet.size}/${models.length} enabled + ${sortedModels.length - hiddenSet.size}/${sortedModels.length} enabled All None - ${showSearch ? `` : ''}` + models.map(m => + ${showSearch ? `` : ''}` + sortedModels.map(m => ` @@ -623,6 +628,9 @@ async function _saveEpModelState(epId, panel) { const badge = row.querySelector('.admin-badge'); if (badge && !badge.classList.contains('admin-badge-off')) badge.textContent = `${total - hidden.length}/${total} models enabled`; } + if (settingsModule && typeof settingsModule.refreshAiModelEndpoints === 'function') { + settingsModule.refreshAiModelEndpoints(); + } } catch (e) { /* silent */ } } @@ -702,12 +710,19 @@ function initEndpointForm() { // Strip trailing paths that shouldn't be in a base URL u = u.replace(/\/v1\/(models|chat\/completions|completions|messages)\/?$/i, '/v1'); u = u.replace(/\/(models|chat\/completions|completions|v1\/messages)\/?$/i, ''); + u = u.replace(/\/api\/(chat|tags|generate)\/?$/i, '/api'); // Fix double /v1/v1 u = u.replace(/\/v1\/v1$/, '/v1'); // Strip query params and fragments u = u.split('?')[0].split('#')[0]; + try { + const parsed = new URL(u); + if (parsed.hostname.endsWith('ollama.com')) { + u = 'https://ollama.com/api'; + } + } catch(e) {} // Ensure /v1 suffix for bare host:port URLs (not cloud providers) - if (!u.includes('api.') && !u.includes('openrouter') && !u.endsWith('/v1')) { + if (!u.includes('api.') && !u.includes('openrouter') && !u.includes('ollama.com') && !u.endsWith('/v1')) { try { const parsed = new URL(u); if (!parsed.pathname || parsed.pathname === '/') { @@ -814,9 +829,13 @@ function initEndpointForm() { const fd = new FormData(); fd.append('base_url', url); if (apiKey) fd.append('api_key', apiKey); + if (provider.value && provider.selectedOptions && provider.selectedOptions[0]) { + fd.append('name', provider.selectedOptions[0].textContent.trim()); + } const epType = el('adm-epType'); if (epType) fd.append('model_type', epType.value); - fd.append('skip_probe', 'false'); + if (provider.value && /openrouter\.ai|ollama\.com/i.test(provider.value)) fd.append('require_models', 'true'); + else fd.append('skip_probe', 'false'); const res = await fetch('/api/model-endpoints', { method: 'POST', body: fd, credentials: 'same-origin' }); const d = await res.json(); if (res.ok) { diff --git a/static/js/assistant.js b/static/js/assistant.js index e3bcbe0..00ab90e 100644 --- a/static/js/assistant.js +++ b/static/js/assistant.js @@ -7,6 +7,7 @@ import uiModule from './ui.js'; import { selectSession } from './sessions.js'; +import { sortModelIds } from './modelSort.js'; const API = '/api/assistant'; @@ -250,9 +251,8 @@ function _renderSettingsBody(body, data, tzList) { try { const models = await _fetchJSON(`/api/model-endpoints/${ep.id}/models`); let mHTML = ''; - for (const m of (models.models || models || [])) { - const mid = typeof m === 'string' ? m : (m.id || m.name || ''); - if (!mid) continue; + const modelIds = (models.models || models || []).map(m => typeof m === 'string' ? m : (m.id || m.name || '')).filter(Boolean); + for (const mid of sortModelIds(modelIds)) { const sel = mid === crew.model ? ' selected' : ''; mHTML += `${_esc(mid.split('/').pop())}`; } diff --git a/static/js/compare/models.js b/static/js/compare/models.js index c4caf39..081b530 100644 --- a/static/js/compare/models.js +++ b/static/js/compare/models.js @@ -2,6 +2,7 @@ import Storage from '../storage.js'; import state from './state.js'; import uiModule from '../ui.js'; +import { sortModelObjects } from '../modelSort.js'; var escapeHtml = uiModule.esc; @@ -84,9 +85,9 @@ async function fetchModels() { }); }); } - state._fetchModelsCache = models; + state._fetchModelsCache = sortModelObjects(models); state._fetchModelsCacheTime = now; - return models; + return state._fetchModelsCache; } // ── Shuffle pool persistence ── diff --git a/static/js/editor/ai-models.js b/static/js/editor/ai-models.js index a300430..e4ec949 100644 --- a/static/js/editor/ai-models.js +++ b/static/js/editor/ai-models.js @@ -25,6 +25,7 @@ * }} deps */ import { state } from './state.js'; +import { sortModelIds } from '../modelSort.js'; // Heuristic classifier on a model id + endpoint name. A model can be: // - gen: text-to-image generation @@ -106,7 +107,7 @@ export function wireAIModelSelectors({ container, apiBase, openCookbookForImg2im for (const ep of endpoints) { if (!ep.is_enabled) continue; const hasListedModels = Array.isArray(ep.models) && ep.models.length; - const models = hasListedModels ? ep.models : ['']; + const models = hasListedModels ? sortModelIds(ep.models) : ['']; const isImageEndpoint = (ep.model_type || '').toLowerCase() === 'image'; // Image/inpaint endpoints can be called by URL even when their // /models cache is still empty, so don't strand a freshly served diff --git a/static/js/group.js b/static/js/group.js index ed98872..4445928 100644 --- a/static/js/group.js +++ b/static/js/group.js @@ -7,6 +7,7 @@ import chatRenderer from './chatRenderer.js'; import spinnerModule from './spinner.js'; import { providerLogo } from './providers.js'; import { PROMPT_TEMPLATES, getAllPresets } from './presets.js'; +import { sortModelObjects } from './modelSort.js'; let API_BASE = ''; let _active = false; @@ -55,7 +56,7 @@ function _initGroupTab() { result.push({ mid, display: display.split('/').pop(), url: item.url, endpointId: item.endpoint_id }); }); }); - _modelsCache = result; + _modelsCache = sortModelObjects(result); return result; } @@ -412,7 +413,7 @@ export async function showModelPicker() { result.push({ mid, display: display.split('/').pop(), url: item.url, endpointId: item.endpoint_id, epName: item.endpoint_name || '' }); }); }); - _cachedModels = result; + _cachedModels = sortModelObjects(result); return result; } diff --git a/static/js/modelPicker.js b/static/js/modelPicker.js index 725e35a..e0cd0b2 100644 --- a/static/js/modelPicker.js +++ b/static/js/modelPicker.js @@ -4,6 +4,7 @@ import { providerLogo } from './providers.js'; import uiModule from './ui.js'; import settingsModule from './settings.js'; +import { sortModelObjects } from './modelSort.js'; const API_BASE = window.location.origin; @@ -156,7 +157,7 @@ function _initModelPickerDropdown() { }); }); }); - return result; + return sortModelObjects(result); } function _populate(filter) { @@ -184,6 +185,8 @@ function _initModelPickerDropdown() { if (favs.includes(m.mid)) favModels.push(m); else restModels.push(m); }); + sortModelObjects(favModels).forEach(function(m, i) { favModels[i] = m; }); + sortModelObjects(restModels).forEach(function(m, i) { restModels[i] = m; }); function _addSection(label) { const el = document.createElement('div'); diff --git a/static/js/modelSort.js b/static/js/modelSort.js new file mode 100644 index 0000000..5d078d4 --- /dev/null +++ b/static/js/modelSort.js @@ -0,0 +1,29 @@ +// Shared alphabetical sorting for model pickers and dropdowns. + +function _sortText(value) { + return String(value || '').split('/').pop().trim() || String(value || ''); +} + +function _compareText(a, b) { + return _sortText(a).localeCompare(_sortText(b), undefined, { + numeric: true, + sensitivity: 'base', + }) || String(a || '').localeCompare(String(b || ''), undefined, { + numeric: true, + sensitivity: 'base', + }); +} + +export function sortModelIds(models) { + return (models || []).slice().sort(_compareText); +} + +export function compareModelObjects(a, b) { + const aLabel = a && (a.display || a.displayName || a.name || a.mid || a.id || a.model); + const bLabel = b && (b.display || b.displayName || b.name || b.mid || b.id || b.model); + return _compareText(aLabel, bLabel); +} + +export function sortModelObjects(models) { + return (models || []).slice().sort(compareModelObjects); +} diff --git a/static/js/models.js b/static/js/models.js index 1049a2c..3ed0ad0 100644 --- a/static/js/models.js +++ b/static/js/models.js @@ -11,6 +11,7 @@ import dragSortModule from './dragSort.js'; import spinnerModule from './spinner.js'; import { modelColor } from './chatRenderer.js'; import { providerLogo } from './providers.js'; +import { sortModelIds } from './modelSort.js'; let API_BASE = ''; let _cachedItems = []; // cached /api/models items for model-switch dropdown @@ -603,7 +604,7 @@ export async function refreshProviders() { if (openai) { const models = (openai.items?.[0]?.models) || []; - models.forEach(m => { + sortModelIds(models).forEach(m => { const opt = document.createElement('option'); opt.value = m; opt.textContent = m; diff --git a/static/js/providers.js b/static/js/providers.js index 53890bf..832bfc1 100644 --- a/static/js/providers.js +++ b/static/js/providers.js @@ -11,6 +11,14 @@ const _PROVIDERS = [ [/openai|gpt-|^o[13]-|chatgpt|dall-e/i, ''], + // OpenRouter + [/openrouter|open router/i, + ''], + + // Ollama / Ollama Cloud + [/ollama/i, + ''], + // Anthropic — Claude (official Simple Icons) [/anthropic|claude/i, ''], diff --git a/static/js/research/panel.js b/static/js/research/panel.js index 5777b6f..6893ec2 100644 --- a/static/js/research/panel.js +++ b/static/js/research/panel.js @@ -5,6 +5,7 @@ import * as jobs from './jobs.js'; import themeModule from '../theme.js'; import createResearchSynapse from '../researchSynapse.js'; import spinnerModule from '../spinner.js'; +import { sortModelIds } from '../modelSort.js'; // jobId -> { synapse, status } — survives across _renderJobs() rebuilds so // the SVG keeps its accumulated nodes/edges between progress events. @@ -637,7 +638,7 @@ function _populateModels(endpointId) { if (!endpointId) return; const ep = _endpoints.find(e => e.id === endpointId); if (!ep || !ep.models) return; - ep.models.forEach(m => { + sortModelIds(ep.models).forEach(m => { const opt = document.createElement('option'); opt.value = m; opt.textContent = m; diff --git a/static/js/settings.js b/static/js/settings.js index a43c9ec..c4e48be 100644 --- a/static/js/settings.js +++ b/static/js/settings.js @@ -5,6 +5,7 @@ import uiModule from './ui.js'; import searchModule from './search.js'; import { makeWindowDraggable } from './windowDrag.js'; import { clearDockSide } from './modalSnap.js'; +import { sortModelIds } from './modelSort.js'; let initialized = false; let modalEl = null; @@ -31,6 +32,7 @@ function initTabs() { // they flip toggles instead of having to close + reopen the modal. document.body.classList.toggle('settings-appearance-open', tab === 'appearance'); syncAppearanceOpacity(tab === 'appearance'); + if (tab === 'ai') refreshAiModelEndpoints(); }); }); } @@ -160,6 +162,93 @@ function initOpacityToggle() { AI TAB ═══════════════════════════════════════════ */ +const _aiEndpointRefreshers = new Set(); +let _aiEndpointRefreshInFlight = null; + +async function _fetchModelEndpoints() { + const epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); + const endpoints = await epRes.json(); + return Array.isArray(endpoints) ? endpoints : []; +} + +function _endpointLabel(ep) { + return ep.name + (ep.online ? '' : ' (offline)'); +} + +function _fillEndpointSelect(selectEl, endpoints, selected, keepBlank) { + if (!selectEl) return; + const previous = selected !== undefined ? selected : selectEl.value; + const blankText = keepBlank && selectEl.options[0] && selectEl.options[0].value === '' + ? selectEl.options[0].textContent + : null; + while (selectEl.options.length) selectEl.remove(0); + if (blankText !== null) { + const blank = document.createElement('option'); + blank.value = ''; + blank.textContent = blankText; + selectEl.appendChild(blank); + } + (endpoints || []).forEach(function(ep) { + if (!ep.is_enabled) return; + const opt = document.createElement('option'); + opt.value = ep.id; + opt.textContent = _endpointLabel(ep); + selectEl.appendChild(opt); + }); + if (previous && Array.from(selectEl.options).some(function(o) { return o.value === previous; })) { + selectEl.value = previous; + } else if (blankText !== null) { + selectEl.value = ''; + } +} + +function _fillModelSelect(selectEl, models, selected, keepBlank) { + if (!selectEl) return; + const previous = selected !== undefined ? selected : selectEl.value; + const blankText = keepBlank && selectEl.options[0] && selectEl.options[0].value === '' + ? selectEl.options[0].textContent + : null; + while (selectEl.options.length) selectEl.remove(0); + if (blankText !== null) { + const blank = document.createElement('option'); + blank.value = ''; + blank.textContent = blankText; + selectEl.appendChild(blank); + } + sortModelIds(models).forEach(function(m) { + const opt = document.createElement('option'); + opt.value = m; + opt.textContent = String(m).split('/').pop(); + selectEl.appendChild(opt); + }); + if (previous && Array.from(selectEl.options).some(function(o) { return o.value === previous; })) { + selectEl.value = previous; + } else if (blankText !== null) { + selectEl.value = ''; + } +} + +function _registerAiEndpointRefresh(fn) { + _aiEndpointRefreshers.add(fn); +} + +export async function refreshAiModelEndpoints() { + if (_aiEndpointRefreshInFlight) return _aiEndpointRefreshInFlight; + _aiEndpointRefreshInFlight = (async function() { + try { + const endpoints = await _fetchModelEndpoints(); + _aiEndpointRefreshers.forEach(function(fn) { + try { fn(endpoints); } catch (e) { console.warn('[settings] endpoint refresh handler failed', e); } + }); + } catch (e) { + console.warn('[settings] failed to refresh model endpoints', e); + } finally { + _aiEndpointRefreshInFlight = null; + } + })(); + return _aiEndpointRefreshInFlight; +} + /* Shared fallback-chain widget — mirrors the Default Chat Model fallback UI * for other model cards (Utility, Vision, …). Pass in the container/button * IDs, the endpoints list, the settings key to persist under, and the @@ -181,7 +270,7 @@ function _bindFallbackWidget(opts) { while (selectEl.options.length) selectEl.remove(0); var ep = (endpointsRef() || []).find(function(e) { return e.id === epId; }); if (ep && ep.models) { - ep.models.forEach(function(m) { + sortModelIds(ep.models).forEach(function(m) { if (!modelsFilter(m, ep)) return; var o = document.createElement('option'); o.value = m; @@ -270,6 +359,7 @@ function _bindFallbackWidget(opts) { return { setInitial: function(list) { current = (list || []).slice(); render(); }, + refresh: render, }; } @@ -289,31 +379,21 @@ async function initDefaultChat() { // Fill any with the models for a given endpoint id. function fillModels(selectEl, epId, selected) { - while (selectEl.options.length) selectEl.remove(0); var ep = _endpoints.find(function(e) { return e.id === epId; }); - if (ep && ep.models) { - ep.models.forEach(function(m) { - var opt = document.createElement('option'); - opt.value = m; - opt.textContent = m.split('/').pop(); - selectEl.appendChild(opt); - }); - } - if (selected) selectEl.value = selected; + _fillModelSelect(selectEl, ep ? ep.models : [], selected, false); } try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - _endpoints = await epRes.json(); - enabledEndpoints().forEach(function(ep) { - var opt = document.createElement('option'); - opt.value = ep.id; - opt.textContent = ep.name + (ep.online ? '' : ' (offline)'); - epSel.appendChild(opt); - }); + _endpoints = await _fetchModelEndpoints(); + _fillEndpointSelect(epSel, _endpoints, epSel.value, false); } catch (e) { console.warn('Failed to load endpoints for default chat', e); } function refreshModels(selectedModel) { fillModels(modelSel, epSel.value, selectedModel); } + function refreshEndpointOptions(selectedEndpoint, selectedModel) { + _fillEndpointSelect(epSel, _endpoints, selectedEndpoint !== undefined ? selectedEndpoint : epSel.value, false); + refreshModels(selectedModel !== undefined ? selectedModel : modelSel.value); + renderFallbacks(); + } // Render the fallback chain. Each row is endpoint + model + remove. function renderFallbacks() { @@ -409,6 +489,11 @@ async function initDefaultChat() { renderFallbacks(); saveDefault(); }); + + _registerAiEndpointRefresh(function(endpoints) { + _endpoints = endpoints; + refreshEndpointOptions(epSel.value, modelSel.value); + }); } /* ── Utility Model ── */ @@ -417,35 +502,19 @@ async function initUtilityModel() { var modelSel = el('set-utilityModelSelect'); var msg = el('set-utilityChatMsg'); var _endpoints = []; + var fallbackWidget = null; if (epSel && epSel.options[0]) epSel.options[0].textContent = 'Same as chat'; if (modelSel && modelSel.options[0]) modelSel.options[0].textContent = 'Same as chat'; try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - _endpoints = await epRes.json(); - _endpoints.forEach(function(ep) { - if (!ep.is_enabled) return; - var opt = document.createElement('option'); - opt.value = ep.id; - opt.textContent = ep.name + (ep.online ? '' : ' (offline)'); - epSel.appendChild(opt); - }); + _endpoints = await _fetchModelEndpoints(); + _fillEndpointSelect(epSel, _endpoints, epSel.value, true); } catch (e) { console.warn('Failed to load endpoints for utility model', e); } function refreshModels(selectedModel) { - while (modelSel.options.length > 1) modelSel.remove(1); var epId = epSel.value; - if (!epId) { modelSel.value = ''; return; } var ep = _endpoints.find(function(e) { return e.id === epId; }); - if (ep && ep.models) { - ep.models.forEach(function(m) { - var opt = document.createElement('option'); - opt.value = m; - opt.textContent = m.split('/').pop(); - modelSel.appendChild(opt); - }); - } - if (selectedModel) modelSel.value = selectedModel; + _fillModelSelect(modelSel, ep ? ep.models : [], selectedModel, true); } try { @@ -453,7 +522,7 @@ async function initUtilityModel() { var settings = await res.json(); if (settings.utility_endpoint_id) epSel.value = settings.utility_endpoint_id; refreshModels(settings.utility_model || ''); - _bindFallbackWidget({ + fallbackWidget = _bindFallbackWidget({ containerId: 'set-utilityFallbacks', addBtnId: 'set-utilityAddFallback', endpoints: function() { return _endpoints; }, @@ -483,6 +552,13 @@ async function initUtilityModel() { epSel.addEventListener('change', function() { refreshModels(''); saveUtility(); }); modelSel.addEventListener('change', saveUtility); + + _registerAiEndpointRefresh(function(endpoints) { + _endpoints = endpoints; + _fillEndpointSelect(epSel, _endpoints, epSel.value, true); + refreshModels(modelSel.value); + if (fallbackWidget && fallbackWidget.refresh) fallbackWidget.refresh(); + }); } /* ── Teacher Model ── */ @@ -501,31 +577,14 @@ async function initTeacherModel() { var _endpoints = []; try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - _endpoints = await epRes.json(); - _endpoints.forEach(function(ep) { - if (!ep.is_enabled) return; - var opt = document.createElement('option'); - opt.value = ep.id; - opt.textContent = ep.name + (ep.online ? '' : ' (offline)'); - epSel.appendChild(opt); - }); + _endpoints = await _fetchModelEndpoints(); + _fillEndpointSelect(epSel, _endpoints, epSel.value, true); } catch (e) { console.warn('Failed to load endpoints for teacher model', e); } function refreshModels(selectedModel) { - while (modelSel.options.length > 1) modelSel.remove(1); var epId = epSel.value; - if (!epId) { modelSel.value = ''; return; } var ep = _endpoints.find(function(e) { return e.id === epId; }); - if (ep && ep.models) { - ep.models.forEach(function(m) { - var opt = document.createElement('option'); - opt.value = m; - opt.textContent = m.split('/').pop(); - modelSel.appendChild(opt); - }); - } - if (selectedModel) modelSel.value = selectedModel; + _fillModelSelect(modelSel, ep ? ep.models : [], selectedModel, true); } // Disable / enable the endpoint+model dropdowns based on the @@ -595,6 +654,12 @@ async function initTeacherModel() { } epSel.addEventListener('change', function() { refreshModels(''); saveTeacher(); }); modelSel.addEventListener('change', saveTeacher); + + _registerAiEndpointRefresh(function(endpoints) { + _endpoints = endpoints; + _fillEndpointSelect(epSel, _endpoints, epSel.value, true); + refreshModels(modelSel.value); + }); } /* ── Image Generation ── */ @@ -624,7 +689,7 @@ async function initImageSettings() { if (_isInpaintModel(mid)) imageModels.push(mid); }); }); - imageModels.forEach(mid => { const opt = document.createElement('option'); opt.value = mid; opt.textContent = mid; modelSel.appendChild(opt); }); + sortModelIds(imageModels).forEach(mid => { const opt = document.createElement('option'); opt.value = mid; opt.textContent = mid; modelSel.appendChild(opt); }); // Hardcoded fallbacks shown as "(not detected)" so users know what to // download/serve to enable inpaint here. ['stable-diffusion-3.5-medium', 'stable-diffusion-inpainting'].forEach(mid => { @@ -666,6 +731,7 @@ async function initVisionSettings() { const enabledToggle = el('set-visionEnabledToggle'); const configWrap = vlSel ? vlSel.closest('div[style*="flex-direction"]') : null; var _visionEndpoints = []; + var visionFallbackWidget = null; var _vlExclude = ['audio', 'realtime', 'tts', 'dall-e', 'embedding', 'search', 'whisper']; function _isVisionModel(mid) { var lower = String(mid || '').toLowerCase(); @@ -674,27 +740,30 @@ async function initVisionSettings() { try { const modelsRes = await fetch('/api/models', { credentials: 'same-origin' }); const modelsData = await modelsRes.json(); + const visionModels = []; (modelsData.items || []).forEach(item => { if (item.offline) return; (item.models || []).forEach(mid => { if (_isVisionModel(mid)) { - var opt = document.createElement('option'); opt.value = mid; opt.textContent = mid; vlSel.appendChild(opt); + visionModels.push(mid); } }); }); + sortModelIds(visionModels).forEach(mid => { + var opt = document.createElement('option'); opt.value = mid; opt.textContent = mid; vlSel.appendChild(opt); + }); } catch (e) { console.warn('Failed to load models for vision settings', e); } // Also pull the raw endpoint list so the fallback widget can resolve // endpoint-id → models the same way the other cards do. try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - _visionEndpoints = await epRes.json(); + _visionEndpoints = await _fetchModelEndpoints(); } catch (e) { console.warn('Failed to load endpoints for vision fallback', e); } try { const settingsRes = await fetch('/api/auth/settings', { credentials: 'same-origin' }); const settings = await settingsRes.json(); if (settings.vision_model) vlSel.value = settings.vision_model; if (enabledToggle) enabledToggle.checked = settings.vision_enabled !== false; - _bindFallbackWidget({ + visionFallbackWidget = _bindFallbackWidget({ containerId: 'set-visionFallbacks', addBtnId: 'set-visionAddFallback', endpoints: function() { return _visionEndpoints; }, @@ -725,6 +794,11 @@ async function initVisionSettings() { } vlSel.addEventListener('change', saveSettings); if (enabledToggle) enabledToggle.addEventListener('change', function() { syncVisionDisabled(); saveSettings(); }); + + _registerAiEndpointRefresh(function(endpoints) { + _visionEndpoints = endpoints; + if (visionFallbackWidget && visionFallbackWidget.refresh) visionFallbackWidget.refresh(); + }); } /* ── Face Recognition ── */ @@ -1292,44 +1366,24 @@ async function initResearchSettings() { var modelSel = el('set-researchModel'); var tokensInput = el('set-researchMaxTokens'); var msg = el('set-researchMsg'); + var endpoints = []; try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - var endpoints = await epRes.json(); - endpoints.forEach(function(ep) { - if (!ep.is_enabled) return; - var opt = document.createElement('option'); - opt.value = ep.id; - opt.textContent = ep.name + (ep.online ? '' : ' (offline)'); - epSel.appendChild(opt); - }); + endpoints = await _fetchModelEndpoints(); + _fillEndpointSelect(epSel, endpoints, epSel.value, true); } catch (e) { console.warn('Failed to load endpoints for research', e); } - async function refreshModels(selectedModel) { + function refreshModels(selectedModel) { var epId = epSel.value; - while (modelSel.options.length > 1) modelSel.remove(1); - if (!epId) { modelSel.value = ''; return; } - try { - var epRes = await fetch('/api/model-endpoints', { credentials: 'same-origin' }); - var endpoints = await epRes.json(); - var ep = endpoints.find(function(e) { return e.id === epId; }); - if (ep && ep.models) { - ep.models.forEach(function(m) { - var opt = document.createElement('option'); - opt.value = m; - opt.textContent = m.split('/').pop(); - modelSel.appendChild(opt); - }); - } - if (selectedModel) modelSel.value = selectedModel; - } catch (e) { /* ignore */ } + var ep = endpoints.find(function(e) { return e.id === epId; }); + _fillModelSelect(modelSel, ep ? ep.models : [], selectedModel, true); } try { var res = await fetch('/api/auth/settings', { credentials: 'same-origin' }); var settings = await res.json(); if (settings.research_endpoint_id) epSel.value = settings.research_endpoint_id; - await refreshModels(settings.research_model || ''); + refreshModels(settings.research_model || ''); if (settings.research_max_tokens) tokensInput.value = settings.research_max_tokens; } catch (e) { console.warn('Failed to load research settings', e); } @@ -1371,11 +1425,17 @@ async function initResearchSettings() { } epSel.addEventListener('change', async function() { - await refreshModels(''); + refreshModels(''); saveResearch(); }); modelSel.addEventListener('change', saveResearch); tokensInput.addEventListener('change', saveResearch); + + _registerAiEndpointRefresh(function(nextEndpoints) { + endpoints = nextEndpoints; + _fillEndpointSelect(epSel, endpoints, epSel.value, true); + refreshModels(modelSel.value); + }); } /* ── Deep Research Search (Search tab) ── */ @@ -4202,6 +4262,7 @@ export function open(tab) { const activeTab = tab || (modalEl.querySelector('[data-settings-tab].active') || {}).dataset?.settingsTab || 'services'; document.body.classList.toggle('settings-appearance-open', activeTab === 'appearance'); syncAppearanceOpacity(activeTab === 'appearance'); + if (activeTab === 'ai') refreshAiModelEndpoints(); if (ADMIN_TABS.has(activeTab) && window.adminModule && !window.adminModule._initialized) { window.adminModule._initData(); } @@ -4226,7 +4287,7 @@ export function close() { } } -const settingsModule = { open, close, initIntegrations, initUnifiedIntegrations, syncAdminVisibility }; +const settingsModule = { open, close, initIntegrations, initUnifiedIntegrations, syncAdminVisibility, refreshAiModelEndpoints }; export default settingsModule; diff --git a/static/js/slashCommands.js b/static/js/slashCommands.js index 6485c29..81bb159 100644 --- a/static/js/slashCommands.js +++ b/static/js/slashCommands.js @@ -47,13 +47,14 @@ const SETUP_PROVIDER_URLS = { deepseek: { name: 'DeepSeek', url: 'https://api.deepseek.com/v1' }, openai: { name: 'OpenAI', url: 'https://api.openai.com/v1' }, openrouter: { name: 'OpenRouter', url: 'https://openrouter.ai/api/v1' }, + ollama: { name: 'Ollama Cloud', url: 'https://ollama.com/api' }, xai: { name: 'xAI', url: 'https://api.x.ai/v1' }, anthropic: { name: 'Anthropic', url: 'https://api.anthropic.com/v1' }, groq: { name: 'Groq', url: 'https://api.groq.com/openai/v1' }, gemini: { name: 'Gemini', url: 'https://generativelanguage.googleapis.com/v1beta/openai' }, google: { name: 'Gemini', url: 'https://generativelanguage.googleapis.com/v1beta/openai' }, }; -const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'xai', 'anthropic', 'groq', 'gemini']; +const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini']; const SETUP_PROVIDER_HINT = SETUP_PROVIDER_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_NAMES[SETUP_PROVIDER_NAMES.length - 1]; const SETUP_LOCAL_ICON = ''; const SETUP_API_ICON = ''; @@ -67,6 +68,8 @@ function _setupProviderFromInput(input) { openai: 'openai', chatgpt: 'openai', openrouter: 'openrouter', + ollama: 'ollama', + ollamacloud: 'ollama', anthropic: 'anthropic', claude: 'anthropic', groq: 'groq', @@ -84,6 +87,7 @@ function _extractSetupProviderCredential(input) { const providerAliases = [ ['deepseek ai', 'deepseek'], ['deepseek', 'deepseek'], ['open router', 'openrouter'], ['openrouter', 'openrouter'], + ['ollama cloud', 'ollama'], ['ollama', 'ollama'], ['open ai', 'openai'], ['openai', 'openai'], ['chatgpt', 'openai'], ['anthropic', 'anthropic'], ['claude', 'anthropic'], ['groq', 'groq'], @@ -488,8 +492,13 @@ function detectProvider(input) { for (const suffix of ['/models', '/chat/completions', '/completions', '/v1/messages']) { if (url.endsWith(suffix)) url = url.slice(0, -suffix.length).replace(/\/+$/, ''); } + url = url.replace(/\/api\/(chat|tags|generate)\/?$/i, '/api'); + try { + const parsed = new URL(url); + if (parsed.hostname.endsWith('ollama.com')) url = 'https://ollama.com/api'; + } catch(e) {} // Add /v1 if bare host:port - if (/^https?:\/\/[^/]+$/.test(url) && !url.includes('api.')) url += '/v1'; + if (/^https?:\/\/[^/]+$/.test(url) && !url.includes('api.') && !url.includes('ollama.com')) url += '/v1'; return { base_url: url, api_key: '', name: '' }; } // Known key patterns @@ -507,6 +516,13 @@ function detectProvider(input) { return null; } +function setupChatUrlForEndpoint(detected) { + const base = (detected.base_url || '').replace(/\/+$/, ''); + if (detected.name === 'Anthropic') return base.replace(/\/v1$/, '') + '/v1/messages'; + if (base.includes('ollama.com')) return 'https://ollama.com/api/chat'; + return base + '/chat/completions'; +} + async function connectDetectedSetupEndpoint(detected) { const providerLabel = detected.name || 'custom endpoint'; const chatBox = document.getElementById('chat-history'); @@ -555,7 +571,7 @@ async function connectDetectedSetupEndpoint(detected) { await typewriterReply(`Found ${count} model${count > 1 ? 's' : ''} on ${providerLabel}. Starting a chat...`); if (modelsModule) await modelsModule.refreshModels(true); const firstModel = data.models[0]; - const chatUrl = detected.base_url + (detected.name === 'Anthropic' ? '/v1/messages' : '/chat/completions'); + const chatUrl = setupChatUrlForEndpoint(detected); if (sessionModule) { await sessionModule.createDirectChat(chatUrl, firstModel, data.id); } diff --git a/static/js/tasks.js b/static/js/tasks.js index 119f341..5576c45 100644 --- a/static/js/tasks.js +++ b/static/js/tasks.js @@ -6,6 +6,7 @@ import uiModule from './ui.js'; import markdownModule from './markdown.js'; import * as spinnerModule from './spinner.js'; import { makeWindowDraggable } from './windowDrag.js'; +import { sortModelIds } from './modelSort.js'; const API_BASE = window.location.origin; let _open = false; @@ -1259,7 +1260,7 @@ function _showForm(existing, initTaskType, initTriggerType) { if (it.offline || !it.models || it.models.length === 0) continue; const group = document.createElement('optgroup'); group.label = it.endpoint_name || it.host || 'endpoint'; - const all = [...(it.models || []), ...(it.models_extra || [])]; + const all = sortModelIds([...(it.models || []), ...(it.models_extra || [])]); for (const m of all) { const opt = document.createElement('option'); opt.value = `${it.url}::${m}`; diff --git a/tests/test_endpoint_resolver.py b/tests/test_endpoint_resolver.py index 926f8b8..447aecd 100644 --- a/tests/test_endpoint_resolver.py +++ b/tests/test_endpoint_resolver.py @@ -11,15 +11,35 @@ def normalize_base(url: str) -> str: for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: if url.endswith(suffix): url = url[: -len(suffix)].rstrip("/") + for suffix in ["/chat", "/tags", "/generate"]: + if url.endswith("/api" + suffix): + url = url[: -len(suffix)].rstrip("/") return url def _detect_provider(url: str) -> str: + parsed = urlparse(url or "") + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if host.endswith("ollama.com") or (parsed.port == 11434 and (path == "/api" or path.startswith("/api/"))): + return "ollama" if "anthropic.com" in (url or ""): return "anthropic" return "openai" +def _ollama_api_root(base: str) -> str: + base = (base or "").strip().rstrip("/") + parsed = urlparse(base) + host = parsed.hostname or "" + path = (parsed.path or "").rstrip("/") + if path.endswith("/api"): + return base + if host.endswith("ollama.com"): + return f"{parsed.scheme}://{parsed.netloc}/api" + return base + + def build_chat_url(base: str) -> str: provider = _detect_provider(base) if provider == "anthropic": @@ -27,9 +47,18 @@ def build_chat_url(base: str) -> str: if host.endswith("anthropic.com") and base.rstrip("/").endswith("/v1"): base = base.rstrip("/")[:-3].rstrip("/") return base + "/v1/messages" + if provider == "ollama": + return _ollama_api_root(base) + "/chat" return base + "/chat/completions" +def build_models_url(base: str) -> str: + provider = _detect_provider(base) + if provider == "ollama": + return _ollama_api_root(base) + "/tags" + return base + "/models" + + def build_headers(api_key, base: str) -> dict: if not api_key: return {} @@ -52,6 +81,9 @@ class TestNormalizeBase: def test_strips_v1_messages(self): assert normalize_base("https://api.anthropic.com/v1/messages") == "https://api.anthropic.com" + def test_strips_ollama_native_chat(self): + assert normalize_base("https://ollama.com/api/chat") == "https://ollama.com/api" + def test_trailing_slash(self): assert normalize_base("https://api.openai.com/v1/") == "https://api.openai.com/v1" @@ -78,6 +110,20 @@ class TestBuildChatUrl: def test_local_endpoint(self): assert build_chat_url("http://localhost:8000/v1") == "http://localhost:8000/v1/chat/completions" + def test_ollama_cloud_native_api(self): + assert build_chat_url("https://ollama.com/api") == "https://ollama.com/api/chat" + + def test_ollama_cloud_root_adds_api(self): + assert build_chat_url("https://ollama.com") == "https://ollama.com/api/chat" + + +class TestBuildModelsUrl: + def test_openai_models(self): + assert build_models_url("https://api.openai.com/v1") == "https://api.openai.com/v1/models" + + def test_ollama_tags(self): + assert build_models_url("https://ollama.com/api") == "https://ollama.com/api/tags" + class TestBuildHeaders: def test_no_key(self): diff --git a/tests/test_llm_core_ollama.py b/tests/test_llm_core_ollama.py new file mode 100644 index 0000000..18b9819 --- /dev/null +++ b/tests/test_llm_core_ollama.py @@ -0,0 +1,43 @@ +"""Regression tests for native Ollama Cloud provider handling.""" +import httpx + +from src import llm_core + + +def test_detects_ollama_cloud_native_provider(): + assert llm_core._detect_provider("https://ollama.com/api") == "ollama" + assert llm_core._detect_provider("https://ollama.com/api/chat") == "ollama" + + +def test_llm_call_posts_native_ollama_payload(monkeypatch): + seen = {} + + def fake_post(url, headers=None, json=None, timeout=None): + seen["url"] = url + seen["headers"] = headers + seen["json"] = json + seen["timeout"] = timeout + request = httpx.Request("POST", url) + return httpx.Response( + 200, + request=request, + json={"message": {"content": "OK"}, "done": True}, + ) + + monkeypatch.setattr(llm_core.httpx, "post", fake_post) + + result = llm_core.llm_call( + "https://ollama.com/api", + "gpt-oss:120b-test", + [{"role": "user", "content": "Say OK"}], + temperature=0.2, + max_tokens=7, + headers={"Authorization": "Bearer ollama-key"}, + timeout=11, + ) + + assert result == "OK" + assert seen["url"] == "https://ollama.com/api/chat" + assert seen["headers"]["Authorization"] == "Bearer ollama-key" + assert seen["json"]["stream"] is False + assert seen["json"]["options"] == {"temperature": 0.2, "num_predict": 7} diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index e4c1405..f6b276d 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -65,6 +65,9 @@ class TestMatchProviderCurated: def test_xai_url(self): assert _match_provider_curated("https://api.x.ai/v1", "openai") == "xai" + def test_ollama_url(self): + assert _match_provider_curated("https://ollama.com/api", "openai") == "ollama" + def test_no_url_match_returns_provider(self): assert _match_provider_curated("https://localhost:1234", "openai") == "openai" @@ -263,6 +266,26 @@ class TestSetupProbeSafety: assert _probe_endpoint("https://api.anthropic.com/v1", "good-key") == ["claude-sonnet-4-5"] assert seen == ["https://api.anthropic.com/v1/models"] + def test_ollama_cloud_probe_uses_native_tags_endpoint(self, monkeypatch): + monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda url: url, raising=False) + monkeypatch.setattr(model_routes, "_normalize_base", lambda url: url.rstrip("/")) + seen = [] + + def fake_get(url, headers=None, timeout=None): + seen.append((url, headers)) + request = httpx.Request("GET", url) + response = httpx.Response( + 200, + request=request, + json={"models": [{"name": "gpt-oss:120b"}, {"model": "qwen3:235b"}]}, + ) + return response + + monkeypatch.setattr(model_routes.httpx, "get", fake_get) + + assert _probe_endpoint("https://ollama.com/api", "ollama-key") == ["gpt-oss:120b", "qwen3:235b"] + assert seen == [("https://ollama.com/api/tags", {"Authorization": "Bearer ollama-key"})] + def test_unkeyed_anthropic_probe_can_use_curated_fallback(self, monkeypatch): monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda url: url, raising=False) monkeypatch.setattr(model_routes, "_normalize_base", lambda url: url.rstrip("/"))