feat(ai): add OpenRouter and Ollama Cloud providers (#231)
Co-authored-by: Alex Kenley <Alex.Kenley@threatvectorsecurity.com>
This commit is contained in:
@@ -450,6 +450,7 @@ _API_HOSTS = frozenset([
|
||||
"api.deepseek.com", "deepseek.com",
|
||||
"api.together.xyz", "api.fireworks.ai",
|
||||
"api.perplexity.ai", "api.x.ai",
|
||||
"ollama.com",
|
||||
])
|
||||
_MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email",
|
||||
"gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"])
|
||||
|
||||
@@ -55,7 +55,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||
|
||||
|
||||
def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
@@ -95,9 +95,7 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
headers = {}
|
||||
if ep.api_key:
|
||||
headers["Authorization"] = f"Bearer {ep.api_key}"
|
||||
headers = build_headers(ep.api_key, base)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic: match against hardcoded model list
|
||||
@@ -107,27 +105,32 @@ def _resolve_model(spec: str) -> Tuple[str, str, Dict]:
|
||||
matched = am
|
||||
break
|
||||
if matched:
|
||||
headers["x-api-key"] = ep.api_key or ""
|
||||
headers["anthropic-version"] = "2023-06-01"
|
||||
return base + "/v1/messages", matched, headers
|
||||
return build_chat_url(base), matched, headers
|
||||
else:
|
||||
# OpenAI-compatible: probe /models
|
||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||
try:
|
||||
r = httpx.get(base + "/models", headers=headers, timeout=5)
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")]
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
except Exception:
|
||||
model_ids = []
|
||||
|
||||
# Exact match first
|
||||
for mid in model_ids:
|
||||
if mid.lower() == model_name.lower():
|
||||
return base + "/chat/completions", mid, headers
|
||||
return build_chat_url(base), mid, headers
|
||||
|
||||
# Partial match
|
||||
for mid in model_ids:
|
||||
if model_name.lower() in mid.lower() or mid.lower() in model_name.lower():
|
||||
return base + "/chat/completions", mid, headers
|
||||
return build_chat_url(base), mid, headers
|
||||
|
||||
raise ValueError(f"Model '{spec}' not found on any configured endpoint")
|
||||
finally:
|
||||
@@ -1107,18 +1110,23 @@ async def do_list_models(content: str, session_id: Optional[str] = None) -> Dict
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
headers = {}
|
||||
if ep.api_key:
|
||||
headers["Authorization"] = f"Bearer {ep.api_key}"
|
||||
headers = build_headers(ep.api_key, base)
|
||||
|
||||
model_ids = []
|
||||
if provider == "anthropic":
|
||||
model_ids = list(ANTHROPIC_MODELS)
|
||||
else:
|
||||
try:
|
||||
r = httpx.get(base + "/models", headers=headers, timeout=5)
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
model_ids = [m.get("id") for m in (r.json().get("data") or []) if m.get("id")]
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
except Exception:
|
||||
model_ids = ["(endpoint offline)"]
|
||||
|
||||
|
||||
@@ -101,6 +101,9 @@ def normalize_base(url: str) -> str:
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
if url.endswith(suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
for suffix in ["/chat", "/tags", "/generate"]:
|
||||
if url.endswith("/api" + suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
return url
|
||||
|
||||
|
||||
@@ -113,6 +116,20 @@ def _anthropic_api_root(base: str) -> str:
|
||||
return base
|
||||
|
||||
|
||||
def _ollama_api_root(base: str) -> str:
|
||||
"""Return the native Ollama API root, adding /api for ollama.com hosts."""
|
||||
base = (base or "").strip().rstrip("/")
|
||||
parsed = urlparse(base)
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path.endswith("/api"):
|
||||
return base
|
||||
if host.endswith("ollama.com"):
|
||||
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
|
||||
return root.rstrip("/") + "/api"
|
||||
return base
|
||||
|
||||
|
||||
def build_chat_url(base: str) -> str:
|
||||
"""Return the correct chat endpoint URL for a given base."""
|
||||
base = resolve_url(base)
|
||||
@@ -120,9 +137,23 @@ def build_chat_url(base: str) -> str:
|
||||
host = urlparse(base).hostname or ""
|
||||
if provider == "anthropic" or host.endswith("anthropic.com"):
|
||||
return _anthropic_api_root(base) + "/v1/messages"
|
||||
if provider == "ollama" or host.endswith("ollama.com"):
|
||||
return _ollama_api_root(base) + "/chat"
|
||||
return base + "/chat/completions"
|
||||
|
||||
|
||||
def build_models_url(base: str) -> str:
|
||||
"""Return the provider-specific model-list endpoint URL for a base."""
|
||||
base = resolve_url(base)
|
||||
provider = _detect_provider(base)
|
||||
host = urlparse(base).hostname or ""
|
||||
if provider == "anthropic" or host.endswith("anthropic.com"):
|
||||
return _anthropic_api_root(base) + "/v1/models"
|
||||
if provider == "ollama" or host.endswith("ollama.com"):
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
return base + "/models"
|
||||
|
||||
|
||||
def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
"""Build auth headers for an endpoint."""
|
||||
provider = _detect_provider(base)
|
||||
|
||||
171
src/llm_core.py
171
src/llm_core.py
@@ -7,6 +7,7 @@ import logging
|
||||
import hashlib
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -140,9 +141,82 @@ ANTHROPIC_MODELS = [
|
||||
"claude-haiku-4-20250514", "claude-haiku-4", "claude-haiku-3-5-20241022", "claude-haiku-3-5",
|
||||
]
|
||||
|
||||
|
||||
def _is_ollama_native_url(url: str) -> bool:
|
||||
"""Return True for native Ollama API URLs, including Ollama Cloud."""
|
||||
try:
|
||||
parsed = urlparse(url or "")
|
||||
except Exception:
|
||||
return False
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if host.endswith("ollama.com"):
|
||||
return True
|
||||
local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434
|
||||
return local_ollama_host and (path == "/api" or path.startswith("/api/"))
|
||||
|
||||
|
||||
def _ollama_api_root(url: str) -> str:
|
||||
"""Return a native Ollama API root such as https://ollama.com/api."""
|
||||
url = (url or "").strip().rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path.endswith("/api/chat"):
|
||||
return url[: -len("/chat")]
|
||||
if path.endswith("/api/tags"):
|
||||
return url[: -len("/tags")]
|
||||
if path.endswith("/api/generate"):
|
||||
return url[: -len("/generate")]
|
||||
if path.endswith("/api"):
|
||||
return url
|
||||
if host.endswith("ollama.com"):
|
||||
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
|
||||
return root.rstrip("/") + "/api"
|
||||
return url
|
||||
|
||||
|
||||
def _normalize_ollama_url(url: str) -> str:
|
||||
"""Ensure a native Ollama URL points at /api/chat."""
|
||||
base = _ollama_api_root(url)
|
||||
return base.rstrip("/") + "/chat"
|
||||
|
||||
|
||||
def _build_ollama_payload(
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stream: bool = False,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
) -> Dict:
|
||||
payload: Dict = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
options: Dict = {}
|
||||
if temperature is not None:
|
||||
options["temperature"] = temperature
|
||||
if max_tokens and max_tokens > 0:
|
||||
options["num_predict"] = max_tokens
|
||||
if options:
|
||||
payload["options"] = options
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_ollama_response(data: dict) -> str:
|
||||
message = data.get("message") or {}
|
||||
return message.get("content") or data.get("response") or ""
|
||||
|
||||
|
||||
def _detect_provider(url: str) -> str:
|
||||
"""Detect API provider from URL."""
|
||||
u = (url or "").lower()
|
||||
if _is_ollama_native_url(url):
|
||||
return "ollama"
|
||||
if "anthropic.com" in u:
|
||||
return "anthropic"
|
||||
if "openrouter.ai" in u:
|
||||
@@ -166,6 +240,7 @@ def _provider_label(url: str) -> str:
|
||||
"""Human-friendly provider name for error messages."""
|
||||
u = (url or "").lower()
|
||||
if "anthropic.com" in u: return "Anthropic"
|
||||
if "ollama.com" in u: return "Ollama Cloud"
|
||||
if "api.x.ai" in u or "x.ai/" in u: return "xAI"
|
||||
if "openai.com" in u: return "OpenAI"
|
||||
if "openrouter.ai" in u: return "OpenRouter"
|
||||
@@ -396,19 +471,28 @@ def _normalize_anthropic_url(url: str) -> str:
|
||||
|
||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
||||
"""List available model IDs from an endpoint."""
|
||||
if _detect_provider(base_chat_url) == "anthropic":
|
||||
provider = _detect_provider(base_chat_url)
|
||||
if provider == "anthropic":
|
||||
return list(ANTHROPIC_MODELS)
|
||||
try:
|
||||
h = {}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
r = httpx.get(base_chat_url.replace("/chat/completions", "/models"), headers=h, timeout=timeout)
|
||||
if provider == "ollama":
|
||||
models_url = _ollama_api_root(base_chat_url) + "/tags"
|
||||
else:
|
||||
models_url = base_chat_url.replace("/chat/completions", "/models")
|
||||
r = httpx.get(models_url, headers=h, timeout=timeout)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if ids:
|
||||
return ids
|
||||
return [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
return model_ids
|
||||
except Exception:
|
||||
try:
|
||||
if ":11434" in base_chat_url or "ollama" in base_chat_url.lower():
|
||||
@@ -476,6 +560,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
|
||||
elif provider == "ollama":
|
||||
target_url = _normalize_ollama_url(url)
|
||||
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
@@ -497,6 +584,8 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
|
||||
try:
|
||||
if provider == "anthropic":
|
||||
response = _parse_anthropic_response(data)
|
||||
elif provider == "ollama":
|
||||
response = _parse_ollama_response(data)
|
||||
else:
|
||||
response = data["choices"][0]["message"]["content"]
|
||||
_set_cached_response(cache_key, response)
|
||||
@@ -583,6 +672,12 @@ async def llm_call_async(
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
|
||||
elif provider == "ollama":
|
||||
target_url = _normalize_ollama_url(url)
|
||||
h = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False)
|
||||
else:
|
||||
target_url = url
|
||||
h = _provider_headers(provider, headers)
|
||||
@@ -621,6 +716,8 @@ async def llm_call_async(
|
||||
try:
|
||||
if provider == "anthropic":
|
||||
response = _parse_anthropic_response(data)
|
||||
elif provider == "ollama":
|
||||
response = _parse_ollama_response(data)
|
||||
else:
|
||||
response = data["choices"][0]["message"]["content"]
|
||||
_set_cached_response(cache_key, response)
|
||||
@@ -673,6 +770,12 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
|
||||
elif provider == "ollama":
|
||||
target_url = _normalize_ollama_url(url)
|
||||
h = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
@@ -699,6 +802,62 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
return
|
||||
note_model_activity(target_url, model)
|
||||
|
||||
# ── Native Ollama streaming ──
|
||||
if provider == "ollama":
|
||||
_ollama_tool_calls: List[Dict] = []
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
_clear_host_dead(target_url)
|
||||
if r.status_code != 200:
|
||||
raw = (await r.aread()).decode(errors="replace")
|
||||
friendly = _format_upstream_error(r.status_code, raw, target_url)
|
||||
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||
return
|
||||
async for line in r.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
j = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
message = j.get("message") or {}
|
||||
thinking = message.get("thinking") or ""
|
||||
if thinking:
|
||||
yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n'
|
||||
content = message.get("content") or ""
|
||||
if content:
|
||||
yield f'data: {json.dumps({"delta": content})}\n\n'
|
||||
for tc in message.get("tool_calls") or []:
|
||||
fn = tc.get("function") or {}
|
||||
if fn.get("name"):
|
||||
_ollama_tool_calls.append({
|
||||
"id": tc.get("id") or f"call_{len(_ollama_tool_calls)}",
|
||||
"name": fn.get("name") or "",
|
||||
"arguments": json.dumps(fn.get("arguments") or {}),
|
||||
})
|
||||
if j.get("done"):
|
||||
if _ollama_tool_calls:
|
||||
yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
|
||||
if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
|
||||
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
yield "data: [DONE]\n\n"
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"Ollama stream connect to {target_url} failed: {e}{_tail}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||
except httpx.ReadTimeout:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||
except httpx.NetworkError:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"Ollama stream error: {e}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||
return
|
||||
|
||||
# ── Anthropic streaming ──
|
||||
if provider == "anthropic":
|
||||
_anth_input_tokens = 0
|
||||
|
||||
@@ -42,6 +42,7 @@ _SOTA_HOSTS = frozenset({
|
||||
"api.together.xyz", "api.fireworks.ai",
|
||||
"api.perplexity.ai", "api.x.ai",
|
||||
"generativelanguage.googleapis.com", "api.groq.com",
|
||||
"openrouter.ai", "ollama.com",
|
||||
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user