feat(ai): add OpenRouter and Ollama Cloud providers (#231)

Co-authored-by: Alex Kenley <Alex.Kenley@threatvectorsecurity.com>
This commit is contained in:
Alexander Kenley
2026-06-01 15:26:10 +10:00
committed by GitHub
parent 4dbc0fe73a
commit 2c4b8b57dd
27 changed files with 699 additions and 169 deletions

View File

@@ -450,6 +450,7 @@ _API_HOSTS = frozenset([
"api.deepseek.com", "deepseek.com",
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"ollama.com",
])
_MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email",
"gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"])

View File

@@ -55,7 +55,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
# Model resolution
# ---------------------------------------------------------------------------
from src.endpoint_resolver import normalize_base as _normalize_base
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
@@ -95,9 +95,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
for ep in endpoints:
base = _normalize_base(ep.base_url)
provider = _detect_provider(base)
headers = {}
if ep.api_key:
headers["Authorization"] = f"Bearer {ep.api_key}"
headers = build_headers(ep.api_key, base)
if provider == "anthropic":
# Anthropic: match against hardcoded model list
@@ -107,27 +105,32 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
matched = am
break
if matched:
headers["x-api-key"] = ep.api_key or ""
headers["anthropic-version"] = "2023-06-01"
return base + "/v1/messages", matched, headers
return build_chat_url(base), matched, headers
else:
# OpenAI-compatible: probe /models
# OpenAI-compatible and native Ollama: probe the provider's model list.
try:
r = httpx.get(base + "/models", headers=headers, timeout=5)
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")]
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
except Exception:
model_ids = []
# Exact match first
for mid in model_ids:
if mid.lower() == model_name.lower():
return base + "/chat/completions", mid, headers
return build_chat_url(base), mid, headers
# Partial match
for mid in model_ids:
if model_name.lower() in mid.lower() or mid.lower() in model_name.lower():
return base + "/chat/completions", mid, headers
return build_chat_url(base), mid, headers
raise ValueError(f"Model '{spec}' not found on any configured endpoint")
finally:
@@ -1107,18 +1110,23 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict
for ep in endpoints:
base = _normalize_base(ep.base_url)
provider = _detect_provider(base)
headers = {}
if ep.api_key:
headers["Authorization"] = f"Bearer {ep.api_key}"
headers = build_headers(ep.api_key, base)
model_ids = []
if provider == "anthropic":
model_ids = list(ANTHROPIC_MODELS)
else:
try:
r = httpx.get(base + "/models", headers=headers, timeout=5)
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")]
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
except Exception:
model_ids = ["(endpoint offline)"]

View File

@@ -101,6 +101,9 @@ def normalize_base(url: str) -> str:
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
if url.endswith(suffix):
url = url[: -len(suffix)].rstrip("/")
for suffix in ["/chat", "/tags", "/generate"]:
if url.endswith("/api" + suffix):
url = url[: -len(suffix)].rstrip("/")
return url
@@ -113,6 +116,20 @@ def _anthropic_api_root(base: str) -> str:
return base
def _ollama_api_root(base: str) -> str:
"""Return the native Ollama API root, adding /api for ollama.com hosts."""
base = (base or "").strip().rstrip("/")
parsed = urlparse(base)
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if path.endswith("/api"):
return base
if host.endswith("ollama.com"):
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
return root.rstrip("/") + "/api"
return base
def build_chat_url(base: str) -> str:
"""Return the correct chat endpoint URL for a given base."""
base = resolve_url(base)
@@ -120,9 +137,23 @@ def build_chat_url(base: str) -> str:
host = urlparse(base).hostname or ""
if provider == "anthropic" or host.endswith("anthropic.com"):
return _anthropic_api_root(base) + "/v1/messages"
if provider == "ollama" or host.endswith("ollama.com"):
return _ollama_api_root(base) + "/chat"
return base + "/chat/completions"
def build_models_url(base: str) -> str:
    """Return the model-listing endpoint URL appropriate for *base*.

    The base is first passed through ``resolve_url``. Anthropic endpoints
    (by detected provider or an ``anthropic.com`` hostname suffix) map to
    ``<anthropic root>/v1/models``; Ollama endpoints (by detected provider
    or an ``ollama.com`` hostname suffix) map to ``<ollama root>/tags``;
    everything else gets ``/models`` appended.
    """
    resolved = resolve_url(base)
    kind = _detect_provider(resolved)
    hostname = urlparse(resolved).hostname or ""

    # Anthropic is checked first, matching the chat-URL builder's precedence.
    is_anthropic = kind == "anthropic" or hostname.endswith("anthropic.com")
    if is_anthropic:
        return _anthropic_api_root(resolved) + "/v1/models"

    is_ollama = kind == "ollama" or hostname.endswith("ollama.com")
    if is_ollama:
        return _ollama_api_root(resolved) + "/tags"

    return resolved + "/models"
def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
"""Build auth headers for an endpoint."""
provider = _detect_provider(base)

View File

@@ -7,6 +7,7 @@ import logging
import hashlib
from fastapi import HTTPException
from typing import Optional, Dict, List
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
@@ -140,9 +141,82 @@ ANTHROPIC_MODELS = [
"claude-haiku-4-20250514", "claude-haiku-4", "claude-haiku-3-5-20241022", "claude-haiku-3-5",
]
def _is_ollama_native_url(url: str) -> bool:
"""Return True for native Ollama API URLs, including Ollama Cloud."""
try:
parsed = urlparse(url or "")
except Exception:
return False
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if host.endswith("ollama.com"):
return True
local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434
return local_ollama_host and (path == "/api" or path.startswith("/api/"))
def _ollama_api_root(url: str) -> str:
"""Return a native Ollama API root such as https://ollama.com/api."""
url = (url or "").strip().rstrip("/")
parsed = urlparse(url)
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if path.endswith("/api/chat"):
return url[: -len("/chat")]
if path.endswith("/api/tags"):
return url[: -len("/tags")]
if path.endswith("/api/generate"):
return url[: -len("/generate")]
if path.endswith("/api"):
return url
if host.endswith("ollama.com"):
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
return root.rstrip("/") + "/api"
return url
def _normalize_ollama_url(url: str) -> str:
    """Return the native Ollama chat endpoint (``<api root>/chat``) for *url*.

    The API root comes from ``_ollama_api_root``; any trailing slash on it
    is removed before ``/chat`` is appended.
    """
    api_root = _ollama_api_root(url).rstrip("/")
    return api_root + "/chat"
def _build_ollama_payload(
model: str,
messages: List[Dict],
temperature: float,
max_tokens: int,
stream: bool = False,
tools: Optional[List[Dict]] = None,
) -> Dict:
payload: Dict = {
"model": model,
"messages": messages,
"stream": stream,
}
options: Dict = {}
if temperature is not None:
options["temperature"] = temperature
if max_tokens and max_tokens > 0:
options["num_predict"] = max_tokens
if options:
payload["options"] = options
if tools:
payload["tools"] = tools
return payload
def _parse_ollama_response(data: dict) -> str:
message = data.get("message") or {}
return message.get("content") or data.get("response") or ""
def _detect_provider(url: str) -> str:
"""Detect API provider from URL."""
u = (url or "").lower()
if _is_ollama_native_url(url):
return "ollama"
if "anthropic.com" in u:
return "anthropic"
if "openrouter.ai" in u:
@@ -166,6 +240,7 @@ def _provider_label(url: str) -> str:
"""Human-friendly provider name for error messages."""
u = (url or "").lower()
if "anthropic.com" in u: return "Anthropic"
if "ollama.com" in u: return "Ollama Cloud"
if "api.x.ai" in u or "x.ai/" in u: return "xAI"
if "openai.com" in u: return "OpenAI"
if "openrouter.ai" in u: return "OpenRouter"
@@ -396,19 +471,28 @@ def _normalize_anthropic_url(url: str) -> str:
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
"""List available model IDs from an endpoint."""
if _detect_provider(base_chat_url) == "anthropic":
provider = _detect_provider(base_chat_url)
if provider == "anthropic":
return list(ANTHROPIC_MODELS)
try:
h = {}
if headers:
h.update(headers)
r = httpx.get(base_chat_url.replace("/chat/completions", "/models"), headers=h, timeout=timeout)
if provider == "ollama":
models_url = _ollama_api_root(base_chat_url) + "/tags"
else:
models_url = base_chat_url.replace("/chat/completions", "/models")
r = httpx.get(models_url, headers=h, timeout=timeout)
r.raise_for_status()
data = r.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if ids:
return ids
return [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
return model_ids
except Exception:
try:
if ":11434" in base_chat_url or "ollama" in base_chat_url.lower():
@@ -476,6 +560,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
else:
target_url = url
payload = {
@@ -497,6 +584,8 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
try:
if provider == "anthropic":
response = _parse_anthropic_response(data)
elif provider == "ollama":
response = _parse_ollama_response(data)
else:
response = data["choices"][0]["message"]["content"]
_set_cached_response(cache_key, response)
@@ -583,6 +672,12 @@ async def llm_call_async(
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
h = {"Content-Type": "application/json"}
if headers:
h.update(headers)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
else:
target_url = url
h = _provider_headers(provider, headers)
@@ -621,6 +716,8 @@ async def llm_call_async(
try:
if provider == "anthropic":
response = _parse_anthropic_response(data)
elif provider == "ollama":
response = _parse_ollama_response(data)
else:
response = data["choices"][0]["message"]["content"]
_set_cached_response(cache_key, response)
@@ -673,6 +770,12 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
h = {"Content-Type": "application/json"}
if headers:
h.update(headers)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
else:
target_url = url
payload = {
@@ -699,6 +802,62 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
return
note_model_activity(target_url, model)
# ── Native Ollama streaming ──
if provider == "ollama":
_ollama_tool_calls: List[Dict] = []
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
_clear_host_dead(target_url)
if r.status_code != 200:
raw = (await r.aread()).decode(errors="replace")
friendly = _format_upstream_error(r.status_code, raw, target_url)
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
return
async for line in r.aiter_lines():
if not line:
continue
try:
j = json.loads(line)
except json.JSONDecodeError:
continue
message = j.get("message") or {}
thinking = message.get("thinking") or ""
if thinking:
yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n'
content = message.get("content") or ""
if content:
yield f'data: {json.dumps({"delta": content})}\n\n'
for tc in message.get("tool_calls") or []:
fn = tc.get("function") or {}
if fn.get("name"):
_ollama_tool_calls.append({
"id": tc.get("id") or f"call_{len(_ollama_tool_calls)}",
"name": fn.get("name") or "",
"arguments": json.dumps(fn.get("arguments") or {}),
})
if j.get("done"):
if _ollama_tool_calls:
yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
yield "data: [DONE]\n\n"
return
yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url)
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"Ollama stream connect to {target_url} failed: {e}{_tail}")
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
except httpx.ReadTimeout:
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
except httpx.NetworkError:
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
except Exception as e:
logger.error(f"Ollama stream error: {e}")
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
return
# ── Anthropic streaming ──
if provider == "anthropic":
_anth_input_tokens = 0

View File

@@ -42,6 +42,7 @@ _SOTA_HOSTS = frozenset({
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"generativelanguage.googleapis.com", "api.groq.com",
"openrouter.ai", "ollama.com",
})