diff --git a/routes/model_routes.py b/routes/model_routes.py index 4282285..935f3b9 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -14,62 +14,19 @@ from pydantic import BaseModel from fastapi.responses import StreamingResponse from core.database import SessionLocal, ModelEndpoint, Session as DbSession from core.middleware import require_admin -from src.llm_core import _detect_provider, ANTHROPIC_MODELS +from src.llm_core import _detect_provider, _host_match, ANTHROPIC_MODELS from src.settings import load_settings as _load_settings, save_settings as _save_settings -from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url +from src.endpoint_resolver import ( + normalize_base as _normalize_base, + build_chat_url, + build_models_url, + build_headers, +) from src.auth_helpers import owner_filter logger = logging.getLogger(__name__) -def _anthropic_api_root(base: str) -> str: - """Return Anthropic's API root without duplicating /v1.""" - base = (base or "").strip().rstrip("/") - host = urlparse(base).hostname or "" - if host.endswith("anthropic.com") and base.endswith("/v1"): - return base[:-3].rstrip("/") - return base - - -def _ollama_api_root(base: str) -> str: - """Return Ollama's native API root without depending on deferred imports.""" - base = (base or "").strip().rstrip("/") - parsed = urlparse(base) - host = parsed.hostname or "" - path = (parsed.path or "").rstrip("/") - if path.endswith("/api"): - return base - if host.endswith("ollama.com"): - root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com" - return root.rstrip("/") + "/api" - return base - - -def _models_url(base: str) -> str: - """Return provider-specific model-list URL for route-local probing.""" - provider = _detect_provider(base) - host = urlparse(base).hostname or "" - if provider == "anthropic" or host.endswith("anthropic.com"): - return _anthropic_api_root(base) + "/v1/models" - if provider == 
"ollama" or host.endswith("ollama.com"): - return _ollama_api_root(base) + "/tags" - return base.rstrip("/") + "/models" - - -def _provider_headers(api_key: Optional[str], base: str) -> Dict[str, str]: - """Build provider auth headers without depending on import-time stubs.""" - if not api_key: - return {} - provider = _detect_provider(base) - host = urlparse(base).hostname or "" - if provider == "anthropic" or host.endswith("anthropic.com"): - return { - "x-api-key": api_key, - "anthropic-version": "2023-06-01", - } - return {"Authorization": f"Bearer {api_key}"} - - # ── Curated model lists per provider ── # For cloud providers that return 100+ models, only show these by default. # A model ID matches if it starts with or equals a curated entry. @@ -122,31 +79,35 @@ _PROVIDER_CURATED = { ], } -# Map URL substrings → curated-list keys for providers whose _detect_provider() +# Map hostnames → curated-list keys for providers whose _detect_provider() # returns a generic value (e.g. "openai") but deserve their own curated list. # "openrouter" is a sentinel meaning "no curation — show all models as curated". -_URL_TO_CURATED = { - "z.ai": "zai", - "api.deepseek.com": "deepseek", - "api.groq.com": "groq", - "api.mistral.ai": "mistral", - "api.together.xyz": "together", - "api.fireworks.ai": "fireworks", - "generativelanguage.googleapis.com": "google", - "api.x.ai": "xai", - "openrouter.ai": "openrouter", - "ollama.com": "ollama", -} +# Entries are matched by hostname equality or subdomain suffix (via _host_match), +# so e.g. "deepseek.com" covers api.deepseek.com without matching the substring +# inside an unrelated URL. 
_HOST_TO_CURATED = (
    ("z.ai", "zai"),
    ("deepseek.com", "deepseek"),
    ("groq.com", "groq"),
    ("mistral.ai", "mistral"),
    ("together.xyz", "together"),
    ("together.ai", "together"),
    ("fireworks.ai", "fireworks"),
    ("googleapis.com", "google"),
    ("x.ai", "xai"),
    ("openrouter.ai", "openrouter"),
    ("ollama.com", "ollama"),
)


def _match_provider_curated(base_url: str, provider: str) -> str:
    """Pick the curated-list key that applies to an endpoint.

    The endpoint's URL is tested against each ``(domain, key)`` pair of
    ``_HOST_TO_CURATED`` in declaration order using ``_host_match``; the key
    of the first pair that matches is returned.

    Args:
        base_url: Base URL of the configured endpoint.
        provider: Provider string already computed by ``_detect_provider()``;
            returned unchanged when no known domain matches.

    Returns:
        The curated-list key for the endpoint.
    """
    # Lazy generator: stops probing as soon as the first domain matches.
    matching_keys = (
        curated_key
        for known_domain, curated_key in _HOST_TO_CURATED
        if _host_match(base_url, known_domain)
    )
    return next(matching_keys, provider)
headers["x-api-key"] = api_key @@ -331,8 +292,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis return [] logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}") return list(ANTHROPIC_MODELS) - url = _models_url(base) - headers = _provider_headers(api_key, base) + url = build_models_url(base) + headers = build_headers(api_key, base) try: r = httpx.get(url, headers=headers, timeout=timeout) r.raise_for_status() @@ -746,8 +707,8 @@ def setup_model_routes(model_discovery): entry["error"] = str(e) entry["model_count"] = 0 else: - url = _models_url(base) - headers = _provider_headers(ep.api_key, base) + url = build_models_url(base) + headers = build_headers(ep.api_key, base) try: t0 = _time.time() r = httpx.get(url, headers=headers, timeout=5) @@ -971,11 +932,6 @@ def setup_model_routes(model_discovery): shared: str = Form("true"), ): require_admin(request) - base_url = base_url.strip().rstrip("/") - # Normalize: strip trailing /models, /chat/completions, /v1/messages etc to get clean base - for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: - if base_url.endswith(suffix): - base_url = base_url[:-len(suffix)].rstrip("/") base_url = _normalize_base(base_url) if not base_url: raise HTTPException(400, "Base URL is required") @@ -1085,10 +1041,7 @@ def setup_model_routes(model_discovery): api_key: str = Form(""), ): require_admin(request) - base_url = base_url.strip().rstrip("/") - for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: - if base_url.endswith(suffix): - base_url = base_url[:-len(suffix)].rstrip("/") + base_url = _normalize_base(base_url) if not base_url: raise HTTPException(400, "Base URL is required") from src.endpoint_resolver import resolve_url diff --git a/src/endpoint_resolver.py b/src/endpoint_resolver.py index aec81a8..f0cd163 100644 --- a/src/endpoint_resolver.py +++ b/src/endpoint_resolver.py @@ -12,7 +12,7 @@ from typing import Optional, 
def build_chat_url(base: str) -> str:
    """Return the chat endpoint URL for a given base.

    The base is first passed through ``resolve_url`` and then classified with
    ``_detect_provider``:

    * ``"anthropic"`` -> ``<api root>/v1/messages``
    * ``"ollama"``    -> ``<ollama api root>/chat``
    * anything else   -> ``<base>/chat/completions`` (OpenAI-compatible)

    Args:
        base: Configured endpoint base URL.

    Returns:
        The full URL to POST chat requests to.
    """
    base = resolve_url(base)
    provider = _detect_provider(base)
    if provider == "anthropic":
        return _anthropic_api_root(base) + "/v1/messages"
    if provider == "ollama":
        return _ollama_api_root(base) + "/chat"
    # The provider-specific roots above already drop a trailing slash; do the
    # same here so a base like ".../v1/" does not become ".../v1//chat/completions".
    return base.rstrip("/") + "/chat/completions"


def build_models_url(base: str) -> str:
    """Return the provider-specific model-list endpoint URL for a base.

    Mirrors ``build_chat_url``: the base is resolved, classified with
    ``_detect_provider`` and mapped to

    * ``"anthropic"`` -> ``<api root>/v1/models``
    * ``"ollama"``    -> ``<ollama api root>/tags``
    * anything else   -> ``<base>/models`` (OpenAI-compatible)

    Args:
        base: Configured endpoint base URL.

    Returns:
        The full URL to GET the model list from.
    """
    base = resolve_url(base)
    provider = _detect_provider(base)
    if provider == "anthropic":
        return _anthropic_api_root(base) + "/v1/models"
    if provider == "ollama":
        return _ollama_api_root(base) + "/tags"
    # Strip a trailing slash so the join never produces "//models".
    return base.rstrip("/") + "/models"
if not ep_id and setting_prefix != "utility": - ep_id = (get_user_setting("utility_endpoint_id", owner or "", settings.get("utility_endpoint_id", "")) or "").strip() - model = (get_user_setting("utility_model", owner or "", settings.get("utility_model", "")) or "").strip() + ep_id = _stg("utility_endpoint_id") + model = _stg("utility_model") if not ep_id: - ep_id = (get_user_setting("default_endpoint_id", owner or "", settings.get("default_endpoint_id", "")) or "").strip() - model = (get_user_setting("default_model", owner or "", settings.get("default_model", "")) or "").strip() + ep_id = _stg("default_endpoint_id") + model = _stg("default_model") if not ep_id: return fallback_url, fallback_model, fallback_headers @@ -342,7 +342,8 @@ def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list: try: from src.settings import get_user_setting, load_settings settings = load_settings() - if not (get_user_setting("utility_endpoint_id", owner or "", settings.get("utility_endpoint_id", "")) or "").strip(): + utility_ep = (get_user_setting("utility_endpoint_id", owner or "", settings.get("utility_endpoint_id", "")) or "").strip() + if not utility_ep: return _resolve_fallback_candidates("default_model_fallbacks", owner=owner) except Exception: pass diff --git a/src/llm_core.py b/src/llm_core.py index 0d4ddc5..f77f3bb 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -163,7 +163,7 @@ def _is_ollama_native_url(url: str) -> bool: return False host = parsed.hostname or "" path = (parsed.path or "").rstrip("/") - if host.endswith("ollama.com"): + if _host_match(url, "ollama.com"): return True local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434 return local_ollama_host and (path == "/api" or path.startswith("/api/")) @@ -173,7 +173,6 @@ def _ollama_api_root(url: str) -> str: """Return a native Ollama API root such as https://ollama.com/api.""" url = (url or "").strip().rstrip("/") parsed = urlparse(url) - host = 
parsed.hostname or "" path = (parsed.path or "").rstrip("/") if path.endswith("/api/chat"): return url[: -len("/chat")] @@ -183,7 +182,7 @@ def _ollama_api_root(url: str) -> str: return url[: -len("/generate")] if path.endswith("/api"): return url - if host.endswith("ollama.com"): + if _host_match(url, "ollama.com"): root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com" return root.rstrip("/") + "/api" return url @@ -225,16 +224,43 @@ def _parse_ollama_response(data: dict) -> str: return message.get("content") or data.get("response") or "" +def _host_match(url: str, *domains: str) -> bool: + """Return True if url's hostname equals any of `domains` or is a subdomain of one. + + Used by helpers that want "is this Anthropic?" / "is this OpenRouter?" + style checks. Prefer this over substring matching on the URL: the + substring form gives wrong answers for unrelated paths or query strings + that happen to contain the domain text. + """ + if not url: + return False + try: + # rstrip(".") so a fully-qualified host with a trailing dot + # ("api.anthropic.com.") still matches "anthropic.com". + host = (urlparse(url).hostname or "").lower().rstrip(".") + except Exception: + return False + if not host: + return False + return any(host == d or host.endswith("." + d) for d in domains) + + def _detect_provider(url: str) -> str: - """Detect API provider from URL.""" - u = (url or "").lower() + """Detect the API provider from a configured endpoint URL. + + Matches on hostname (exact or subdomain) rather than substring, so a URL + that merely contains a provider's domain in its path or query — or a + look-alike host such as ``anthropic.com.example`` — is not misclassified. + Unknown hosts fall back to the OpenAI-compatible default, which the + majority of providers implement. 
+ """ if _is_ollama_native_url(url): return "ollama" - if "anthropic.com" in u: + if _host_match(url, "anthropic.com"): return "anthropic" - if "openrouter.ai" in u: + if _host_match(url, "openrouter.ai"): return "openrouter" - if "groq.com" in u: + if _host_match(url, "groq.com"): return "groq" return "openai" @@ -251,26 +277,27 @@ def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str def _provider_label(url: str) -> str: """Human-friendly provider name for error messages.""" - u = (url or "").lower() - if "anthropic.com" in u: return "Anthropic" - if "ollama.com" in u: return "Ollama Cloud" - if "api.x.ai" in u or "x.ai/" in u: return "xAI" - if "openai.com" in u: return "OpenAI" - if "openrouter.ai" in u: return "OpenRouter" - if "groq.com" in u: return "Groq" - if "mistral.ai" in u: return "Mistral" - if "deepseek.com" in u: return "DeepSeek" - if "googleapis.com" in u or "generativelanguage" in u: return "Google" - if "together.xyz" in u or "together.ai" in u: return "Together" - if "fireworks.ai" in u: return "Fireworks" - if "ollama" in u or ":11434" in u: return "Ollama" - if "localhost" in u or "127.0.0.1" in u: return "local endpoint" + if not url: + return "provider" + if _host_match(url, "anthropic.com"): return "Anthropic" + if _host_match(url, "ollama.com"): return "Ollama Cloud" + if _host_match(url, "x.ai"): return "xAI" + if _host_match(url, "openai.com"): return "OpenAI" + if _host_match(url, "openrouter.ai"): return "OpenRouter" + if _host_match(url, "groq.com"): return "Groq" + if _host_match(url, "mistral.ai"): return "Mistral" + if _host_match(url, "deepseek.com"): return "DeepSeek" + if _host_match(url, "googleapis.com"): return "Google" + if _host_match(url, "together.xyz", "together.ai"): return "Together" + if _host_match(url, "fireworks.ai"): return "Fireworks" + if _is_ollama_native_url(url): return "Ollama" try: - from urllib.parse import urlparse - host = urlparse(url).hostname or "provider" - return host + 
host = (urlparse(url).hostname or "").lower() except Exception: return "provider" + if host in {"localhost", "127.0.0.1", "::1", "0.0.0.0"}: + return "local endpoint" + return host or "provider" def _format_upstream_error(status: int, body: bytes | str, url: str) -> str: diff --git a/tests/test_provider_detection.py b/tests/test_provider_detection.py new file mode 100644 index 0000000..fb53291 --- /dev/null +++ b/tests/test_provider_detection.py @@ -0,0 +1,136 @@ +"""Provider detection tests (re: #768). + +These import the *real* helpers from ``src.llm_core`` (not local copies) so a +regression in hostname matching is actually caught. The point of the change +under test is that provider detection keys off the URL's *hostname*, not a +substring of the whole URL — so a domain appearing in a path/query, or a +look-alike host, must not be misclassified. +""" +import pytest + +from src import llm_core +from src import endpoint_resolver +from src.endpoint_resolver import build_chat_url, build_models_url + + +class TestHostMatch: + def test_exact_host(self): + assert llm_core._host_match("https://anthropic.com/v1", "anthropic.com") + + def test_subdomain(self): + assert llm_core._host_match("https://api.anthropic.com/v1", "anthropic.com") + + def test_multiple_domains(self): + assert llm_core._host_match("https://api.together.ai/v1", "together.xyz", "together.ai") + + def test_trailing_dot_fqdn(self): + # A fully-qualified host with a trailing dot is legal and resolvable. 
+ assert llm_core._host_match("https://api.anthropic.com./v1", "anthropic.com") + + def test_domain_in_path_does_not_match(self): + assert not llm_core._host_match("https://myproxy.internal/anthropic.com/v1", "anthropic.com") + + def test_domain_in_query_does_not_match(self): + assert not llm_core._host_match("https://example.com/v1?ref=anthropic.com", "anthropic.com") + + def test_lookalike_host_does_not_match(self): + assert not llm_core._host_match("https://anthropic.com.example/v1", "anthropic.com") + + def test_none_and_empty_safe(self): + assert not llm_core._host_match(None, "anthropic.com") + assert not llm_core._host_match("", "anthropic.com") + + +class TestDetectProviderRealHosts: + def test_anthropic(self): + assert llm_core._detect_provider("https://api.anthropic.com") == "anthropic" + + def test_openrouter(self): + assert llm_core._detect_provider("https://openrouter.ai/api/v1") == "openrouter" + + def test_groq_openai_compat_path(self): + # Groq's base carries an /openai/v1 path; detection must still see the host. 
+ assert llm_core._detect_provider("https://api.groq.com/openai/v1") == "groq" + + def test_ollama_native_unchanged(self): + assert llm_core._detect_provider("https://ollama.com/api") == "ollama" + + def test_unknown_host_defaults_to_openai(self): + assert llm_core._detect_provider("https://api.example.com/v1") == "openai" + + +class TestDetectProviderRejectsSubstringFalsePositives: + """The regression that motivated #768: substring matching mislabeled these.""" + + def test_provider_domain_in_path(self): + assert llm_core._detect_provider("https://myproxy.internal/anthropic.com/v1") == "openai" + + def test_provider_domain_in_query(self): + assert llm_core._detect_provider("https://example.com/v1?ref=anthropic.com") == "openai" + + def test_lookalike_host(self): + assert llm_core._detect_provider("https://anthropic.com.example/v1") == "openai" + + def test_none_safe(self): + assert llm_core._detect_provider(None) == "openai" + + +class TestBuildersRejectLookalikeHosts: + """build_chat_url / build_models_url must route look-alike and + domain-in-path hosts to the OpenAI-compatible default, not the + anthropic/ollama branches. Before #815's follow-up these builders still + fell back to ``host.endswith("anthropic.com")`` style checks, so + ``notanthropic.com`` was misrouted to the Anthropic messages endpoint. + """ + + @pytest.fixture(autouse=True) + def _stub_dns(self, monkeypatch): + # build_* call resolve_url(), which does real DNS + tailscale lookups. + # Provider routing is independent of name resolution, so stub it out to + # keep these deterministic and offline. 
+ monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u) + + def test_real_anthropic_chat(self): + assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages" + + def test_lookalike_anthropic_chat_is_openai(self): + assert build_chat_url("https://notanthropic.com") == "https://notanthropic.com/chat/completions" + + def test_lookalike_anthropic_models_is_openai(self): + assert build_models_url("https://anthropic.com.evil.com") == "https://anthropic.com.evil.com/models" + + def test_anthropic_domain_in_path_is_openai(self): + assert build_chat_url("https://myproxy.internal/anthropic.com/v1") == "https://myproxy.internal/anthropic.com/v1/chat/completions" + + def test_real_ollama_chat(self): + assert build_chat_url("https://ollama.com") == "https://ollama.com/api/chat" + + def test_lookalike_ollama_chat_is_openai(self): + assert build_chat_url("https://notollama.com") == "https://notollama.com/chat/completions" + + def test_lookalike_ollama_models_is_openai(self): + assert build_models_url("https://notollama.com") == "https://notollama.com/models" + + +class TestBuildersLocalAndDockerEndpoints: + """Local and docker endpoints must keep working after the hostname change: + a local ``/v1`` base stays OpenAI-compatible, and a native Ollama ``/api`` + path is still detected by path even on a non-ollama.com host such as + host.docker.internal. 
+ """ + + @pytest.fixture(autouse=True) + def _stub_dns(self, monkeypatch): + monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u) + + def test_local_v1_chat_is_openai_compatible(self): + assert build_chat_url("http://localhost:8000/v1") == "http://localhost:8000/v1/chat/completions" + + def test_local_v1_models_is_openai_compatible(self): + assert build_models_url("http://127.0.0.1:1234/v1") == "http://127.0.0.1:1234/v1/models" + + def test_docker_internal_ollama_api_path_is_native_chat(self): + assert build_chat_url("http://host.docker.internal:11434/api") == "http://host.docker.internal:11434/api/chat" + + def test_docker_internal_ollama_api_path_is_native_models(self): + assert build_models_url("http://host.docker.internal:11434/api") == "http://host.docker.internal:11434/api/tags"