feat(ai): add OpenRouter and Ollama Cloud providers (#231)

Co-authored-by: Alex Kenley <Alex.Kenley@threatvectorsecurity.com>
This commit is contained in:
Alexander Kenley
2026-06-01 15:26:10 +10:00
committed by GitHub
parent 4dbc0fe73a
commit 2c4b8b57dd
27 changed files with 699 additions and 169 deletions

View File

@@ -7,6 +7,7 @@ import logging
import hashlib
from fastapi import HTTPException
from typing import Optional, Dict, List
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
@@ -140,9 +141,82 @@ ANTHROPIC_MODELS = [
"claude-haiku-4-20250514", "claude-haiku-4", "claude-haiku-3-5-20241022", "claude-haiku-3-5",
]
def _is_ollama_native_url(url: str) -> bool:
"""Return True for native Ollama API URLs, including Ollama Cloud."""
try:
parsed = urlparse(url or "")
except Exception:
return False
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if host.endswith("ollama.com"):
return True
local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434
return local_ollama_host and (path == "/api" or path.startswith("/api/"))
def _ollama_api_root(url: str) -> str:
"""Return a native Ollama API root such as https://ollama.com/api."""
url = (url or "").strip().rstrip("/")
parsed = urlparse(url)
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if path.endswith("/api/chat"):
return url[: -len("/chat")]
if path.endswith("/api/tags"):
return url[: -len("/tags")]
if path.endswith("/api/generate"):
return url[: -len("/generate")]
if path.endswith("/api"):
return url
if host.endswith("ollama.com"):
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
return root.rstrip("/") + "/api"
return url
def _normalize_ollama_url(url: str) -> str:
    """Return the native Ollama chat endpoint for *url*.

    The URL is first reduced to its API root via ``_ollama_api_root``; the
    ``/chat`` endpoint name is then appended to that root.
    """
    api_root = _ollama_api_root(url).rstrip("/")
    return f"{api_root}/chat"
def _build_ollama_payload(
model: str,
messages: List[Dict],
temperature: float,
max_tokens: int,
stream: bool = False,
tools: Optional[List[Dict]] = None,
) -> Dict:
payload: Dict = {
"model": model,
"messages": messages,
"stream": stream,
}
options: Dict = {}
if temperature is not None:
options["temperature"] = temperature
if max_tokens and max_tokens > 0:
options["num_predict"] = max_tokens
if options:
payload["options"] = options
if tools:
payload["tools"] = tools
return payload
def _parse_ollama_response(data: dict) -> str:
message = data.get("message") or {}
return message.get("content") or data.get("response") or ""
def _detect_provider(url: str) -> str:
"""Detect API provider from URL."""
u = (url or "").lower()
if _is_ollama_native_url(url):
return "ollama"
if "anthropic.com" in u:
return "anthropic"
if "openrouter.ai" in u:
@@ -166,6 +240,7 @@ def _provider_label(url: str) -> str:
"""Human-friendly provider name for error messages."""
u = (url or "").lower()
if "anthropic.com" in u: return "Anthropic"
if "ollama.com" in u: return "Ollama Cloud"
if "api.x.ai" in u or "x.ai/" in u: return "xAI"
if "openai.com" in u: return "OpenAI"
if "openrouter.ai" in u: return "OpenRouter"
@@ -396,19 +471,28 @@ def _normalize_anthropic_url(url: str) -> str:
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
"""List available model IDs from an endpoint."""
if _detect_provider(base_chat_url) == "anthropic":
provider = _detect_provider(base_chat_url)
if provider == "anthropic":
return list(ANTHROPIC_MODELS)
try:
h = {}
if headers:
h.update(headers)
r = httpx.get(base_chat_url.replace("/chat/completions", "/models"), headers=h, timeout=timeout)
if provider == "ollama":
models_url = _ollama_api_root(base_chat_url) + "/tags"
else:
models_url = base_chat_url.replace("/chat/completions", "/models")
r = httpx.get(models_url, headers=h, timeout=timeout)
r.raise_for_status()
data = r.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if ids:
return ids
return [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
return model_ids
except Exception:
try:
if ":11434" in base_chat_url or "ollama" in base_chat_url.lower():
@@ -476,6 +560,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
else:
target_url = url
payload = {
@@ -497,6 +584,8 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
try:
if provider == "anthropic":
response = _parse_anthropic_response(data)
elif provider == "ollama":
response = _parse_ollama_response(data)
else:
response = data["choices"][0]["message"]["content"]
_set_cached_response(cache_key, response)
@@ -583,6 +672,12 @@ async def llm_call_async(
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
h = {"Content-Type": "application/json"}
if headers:
h.update(headers)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
else:
target_url = url
h = _provider_headers(provider, headers)
@@ -621,6 +716,8 @@ async def llm_call_async(
try:
if provider == "anthropic":
response = _parse_anthropic_response(data)
elif provider == "ollama":
response = _parse_ollama_response(data)
else:
response = data["choices"][0]["message"]["content"]
_set_cached_response(cache_key, response)
@@ -673,6 +770,12 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
elif provider == "ollama":
target_url = _normalize_ollama_url(url)
h = {"Content-Type": "application/json"}
if headers:
h.update(headers)
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
else:
target_url = url
payload = {
@@ -699,6 +802,62 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
return
note_model_activity(target_url, model)
# ── Native Ollama streaming ──
if provider == "ollama":
_ollama_tool_calls: List[Dict] = []
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
_clear_host_dead(target_url)
if r.status_code != 200:
raw = (await r.aread()).decode(errors="replace")
friendly = _format_upstream_error(r.status_code, raw, target_url)
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
return
async for line in r.aiter_lines():
if not line:
continue
try:
j = json.loads(line)
except json.JSONDecodeError:
continue
message = j.get("message") or {}
thinking = message.get("thinking") or ""
if thinking:
yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n'
content = message.get("content") or ""
if content:
yield f'data: {json.dumps({"delta": content})}\n\n'
for tc in message.get("tool_calls") or []:
fn = tc.get("function") or {}
if fn.get("name"):
_ollama_tool_calls.append({
"id": tc.get("id") or f"call_{len(_ollama_tool_calls)}",
"name": fn.get("name") or "",
"arguments": json.dumps(fn.get("arguments") or {}),
})
if j.get("done"):
if _ollama_tool_calls:
yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
yield "data: [DONE]\n\n"
return
yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url)
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"Ollama stream connect to {target_url} failed: {e}{_tail}")
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
except httpx.ReadTimeout:
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
except httpx.NetworkError:
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
except Exception as e:
logger.error(f"Ollama stream error: {e}")
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
return
# ── Anthropic streaming ──
if provider == "anthropic":
_anth_input_tokens = 0