* feat(provider): add GitHub Copilot provider with device-flow auth
Adds GitHub Copilot as a model provider, so Copilot models (gpt-4o/4.1/5,
Claude, Gemini, …) work through the normal chat + agent loop, incl. native
tool calling and vision.
Auth is one-click via the GitHub OAuth device flow; the access token is stored
as the endpoint's (encrypted) api_key and sent directly as `Authorization:
Bearer` (no Copilot-token exchange, no refresh — matching how editors talk to
the Copilot API). Copilot is a normal ModelEndpoint detected by host; the only
provider-specific behaviour is a small set of required request headers,
injected centrally.
Sign-in is available from Settings → model endpoints ("Connect GitHub
Copilot") and from chat via `/setup copilot`.
- src/copilot.py (new), routes/copilot_routes.py (new): constants, header
builders, device-flow start/poll, model discovery, owner-scoped endpoint
provisioning.
- src/llm_core.py, src/endpoint_resolver.py: detect `copilot`, inject headers,
per-request x-initiator/vision.
- src/agent_loop.py: allowlist api.githubcopilot.com for native tool schemas.
- src/model_context.py: known context windows for Copilot (no unauthenticated
/models probe).
- static/, README, tests/test_copilot*.py.
* Tidy copilot_routes: clarify supports_tools, note _PENDING is per-process
1672 lines
78 KiB
Python
1672 lines
78 KiB
Python
# src/llm_core.py
|
|
import httpx
|
|
import asyncio
|
|
import time
|
|
import json
|
|
import logging
|
|
import hashlib
|
|
import threading
|
|
from fastapi import HTTPException
|
|
from typing import Optional, Dict, List
|
|
from src.model_context import get_context_length, DEFAULT_CONTEXT
|
|
from urllib.parse import urlparse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class LLMConfig:
    """Default tuning values shared by the LLM call helpers in this module."""

    # Default request timeout (used as the ``timeout`` default of the
    # model-listing helpers; presumably seconds — confirm against callers).
    DEFAULT_TIMEOUT = 30
    # Sampling temperature used when the caller does not supply one.
    DEFAULT_TEMPERATURE = 1.0
    # 0 is treated as "not set" by the payload builders in this module.
    DEFAULT_MAX_TOKENS = 0
    # Retry policy knobs (attempt count and delay between attempts).
    MAX_RETRIES = 3
    RETRY_DELAY = 0.5
    # Timeout budget for streaming responses.
    STREAM_TIMEOUT = 300
|
|
|
|
|
|
# Cache for LLM responses
|
|
def _get_cache_key(url: str, model: str, messages: List[Dict],
|
|
temperature: float, max_tokens: int) -> str:
|
|
"""Generate cache key for LLM requests."""
|
|
hashable_messages = []
|
|
for msg in messages:
|
|
sorted_items = tuple(sorted(msg.items()))
|
|
hashable_messages.append(sorted_items)
|
|
|
|
content = json.dumps({
|
|
'url': url,
|
|
'model': model,
|
|
'messages': hashable_messages,
|
|
'temp': temperature,
|
|
'max_tokens': max_tokens
|
|
}, sort_keys=True)
|
|
return hashlib.sha256(content.encode()).hexdigest()
|
|
|
|
_response_cache = {}
|
|
|
|
# Dead-host cooldown: maps host (scheme://host:port) -> unix ts when cooldown expires.
|
|
# When a connect to a host fails, we mark it dead for DEAD_HOST_COOLDOWN seconds so
|
|
# subsequent calls fail instantly instead of waiting on the connect timeout. Keeps
|
|
# one unreachable upstream from jamming chat across the rest of the app.
|
|
#
|
|
# But a SINGLE transient blip (local model briefly busy, a momentary
|
|
# Tailscale hiccup) used to trip a full 60s lockout — the user saw a
|
|
# 503 and thought the model died when it was fine a second later. So:
|
|
# - require FAIL_THRESHOLD consecutive failures before cooling
|
|
# - shorter cooldown so recovery is quick
|
|
# - any success resets the failure counter immediately
|
|
DEAD_HOST_COOLDOWN = 20.0
|
|
_HOST_FAIL_THRESHOLD = 2
|
|
_dead_hosts: Dict[str, float] = {}
|
|
_host_fails: Dict[str, int] = {}
|
|
# Guards the two maps above. The synchronous llm_call() runs inside FastAPI's
|
|
# threadpool (sync routes such as /sessions/auto-sort) while llm_call_async()
|
|
# runs on the event loop, so these maps are mutated from multiple OS threads.
|
|
# Without the lock the get()+1+set on _host_fails is a read-modify-write that
|
|
# loses failure counts under concurrent connect errors (issue #659).
|
|
_host_health_lock = threading.Lock()
|
|
_model_activity: Dict[str, float] = {}
|
|
|
|
def _model_activity_key(url: str, model: str) -> str:
|
|
return f"{(url or '').strip()}|{(model or '').strip()}"
|
|
|
|
def note_model_activity(url: str, model: str):
    """Stamp the current time for an endpoint/model pair that served a real request.

    Calls with an empty ``url`` or ``model`` are ignored.
    """
    if url and model:
        _model_activity[_model_activity_key(url, model)] = time.time()
|
|
|
|
def seconds_since_model_activity(url: str, model: str) -> Optional[float]:
    """Return how long ago (in seconds) this process last used the endpoint/model.

    Returns ``None`` when no activity has been recorded for the pair. The
    result is clamped at 0.0 so a backwards clock step never yields a
    negative age.
    """
    last_seen = _model_activity.get(_model_activity_key(url, model))
    if last_seen:
        elapsed = time.time() - last_seen
        return max(0.0, elapsed)
    return None
|
|
|
|
def _host_key(url: str) -> str:
|
|
from urllib.parse import urlsplit
|
|
s = urlsplit(url)
|
|
return f"{s.scheme}://{s.netloc}" if s.scheme and s.netloc else url
|
|
|
|
def _is_host_dead(url: str) -> bool:
    """Return True while the URL's host is inside its dead-host cooldown.

    An expired cooldown entry is removed as a side effect of the check.
    """
    host = _host_key(url)
    with _host_health_lock:
        expires_at = _dead_hosts.get(host)
        if expires_at is None:
            return False
        if time.time() < expires_at:
            return True
        # Cooldown elapsed: forget the entry so the host is tried again.
        _dead_hosts.pop(host, None)
        return False
|
|
|
|
def _mark_host_dead(url: str) -> bool:
    """Count one connect failure for the URL's host.

    The host only enters cooldown once ``_HOST_FAIL_THRESHOLD`` consecutive
    failures have been recorded. Returns True when the host is now cooled
    (so callers can log accurately) and False while it is still within its
    allowed-failure grace.
    """
    host = _host_key(url)
    with _host_health_lock:
        failures = _host_fails.get(host, 0) + 1
        _host_fails[host] = failures
        cooled = failures >= _HOST_FAIL_THRESHOLD
        if cooled:
            _dead_hosts[host] = time.time() + DEAD_HOST_COOLDOWN
        return cooled
|
|
|
|
def _clear_host_dead(url: str) -> None:
    """Drop both the cooldown entry and the failure count for the URL's host."""
    host = _host_key(url)
    with _host_health_lock:
        for table in (_dead_hosts, _host_fails):
            table.pop(host, None)
|
|
|
|
|
|
# One AsyncClient is shared by the whole process so connections stay warm:
# repeat calls to api.anthropic.com / api.openai.com / openrouter skip the
# 100-500ms TCP+TLS handshake. It is created lazily so it binds to the
# running event loop.
_http_client: Optional[httpx.AsyncClient] = None
_http_limits = httpx.Limits(max_connections=100, max_keepalive_connections=30, keepalive_expiry=30.0)


def _get_http_client() -> httpx.AsyncClient:
    """Return the process-wide AsyncClient, (re)creating it if missing or closed.

    No timeout is configured here; the per-request timeout is passed at call time.
    """
    global _http_client
    client = _http_client
    if client is not None and not client.is_closed:
        return client
    from src.tls_overrides import llm_verify
    _http_client = httpx.AsyncClient(
        limits=_http_limits,
        http2=False,
        verify=llm_verify(),
    )
    return _http_client
|
|
|
|
def _get_cached_response(cache_key: str) -> Optional[str]:
    """Look up a cached response; returns ``None`` on a miss."""
    hit = _response_cache.get(cache_key)
    return hit
|
|
|
|
def _set_cached_response(cache_key: str, response: str) -> None:
    """Cache a response, evicting the 64 oldest entries once more than 128 are held."""
    if len(_response_cache) > 128:
        stale = list(_response_cache)[:64]
        for old_key in stale:
            # pop(), not del: another thread (the sync llm_call runs in
            # FastAPI's threadpool) may already have evicted the same
            # snapshotted key, and del would raise KeyError mid-eviction
            # (issue #659).
            _response_cache.pop(old_key, None)
    _response_cache[cache_key] = response
|
|
|
|
# ── Anthropic native API adapter ──
|
|
|
|
# Static model list returned by list_model_ids() for Anthropic endpoints.
ANTHROPIC_MODELS = [
    # Opus
    "claude-opus-4-20250514",
    "claude-opus-4",
    # Sonnet
    "claude-sonnet-4-20250514",
    "claude-sonnet-4",
    "claude-sonnet-4-5-20250929",
    "claude-sonnet-4-5",
    # Haiku
    "claude-haiku-4-20250514",
    "claude-haiku-4",
    "claude-haiku-3-5-20241022",
    "claude-haiku-3-5",
]
|
|
|
|
|
|
def _is_ollama_native_url(url: str) -> bool:
    """Return True for native Ollama API URLs, including Ollama Cloud.

    Any ``ollama.com`` host counts. Otherwise the URL must target a loopback
    host (or port 11434) AND have a path of ``/api`` or below.

    A URL that cannot be parsed — including one with a malformed port such as
    ``http://localhost:abc/api`` — yields False instead of raising, so
    provider detection never crashes on bad user input.
    """
    try:
        parsed = urlparse(url or "")
    except Exception:
        return False
    if _host_match(url, "ollama.com"):
        return True
    try:
        # Reading .port raises ValueError for a non-numeric or out-of-range
        # port, so it has to be read inside a guard.
        port = parsed.port
    except ValueError:
        return False
    host = parsed.hostname or ""
    path = (parsed.path or "").rstrip("/")
    local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or port == 11434
    return local_ollama_host and (path == "/api" or path.startswith("/api/"))
|
|
|
|
|
|
def _ollama_api_root(url: str) -> str:
|
|
"""Return a native Ollama API root such as https://ollama.com/api."""
|
|
url = (url or "").strip().rstrip("/")
|
|
parsed = urlparse(url)
|
|
path = (parsed.path or "").rstrip("/")
|
|
if path.endswith("/api/chat"):
|
|
return url[: -len("/chat")]
|
|
if path.endswith("/api/tags"):
|
|
return url[: -len("/tags")]
|
|
if path.endswith("/api/generate"):
|
|
return url[: -len("/generate")]
|
|
if path.endswith("/api"):
|
|
return url
|
|
if _host_match(url, "ollama.com"):
|
|
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
|
|
return root.rstrip("/") + "/api"
|
|
return url
|
|
|
|
|
|
def _normalize_ollama_url(url: str) -> str:
    """Return the ``/chat`` URL under the native Ollama API root for ``url``."""
    api_root = _ollama_api_root(url).rstrip("/")
    return f"{api_root}/chat"
|
|
|
|
|
|
def _ollama_normalize_tool_messages(messages: List[Dict]) -> List[Dict]:
|
|
"""Adapt Odysseus' canonical OpenAI-style messages to native Ollama /api/chat.
|
|
|
|
Odysseus carries assistant tool calls in the OpenAI shape, where
|
|
`function.arguments` is a JSON *string*. Native Ollama expects it to be a
|
|
JSON *object*; given the string it fails the whole request with HTTP 400
|
|
"Value looks like object, but can't find closing '}' symbol", which aborts
|
|
every follow-up (tool-result) round. Parse the arguments back into an object
|
|
here, on a shallow copy, leaving non-tool messages untouched. The opaque
|
|
Gemini `extra_content` (thought_signature) is dropped — it is meaningless to
|
|
Ollama and only matters when the conversation is replayed to Gemini.
|
|
"""
|
|
out: List[Dict] = []
|
|
for m in messages or []:
|
|
tcs = m.get("tool_calls") if isinstance(m, dict) else None
|
|
if not tcs:
|
|
out.append(m)
|
|
continue
|
|
new_calls = []
|
|
for tc in tcs:
|
|
fn = tc.get("function") or {}
|
|
args = fn.get("arguments")
|
|
if isinstance(args, str):
|
|
try:
|
|
args = json.loads(args) if args.strip() else {}
|
|
except (json.JSONDecodeError, TypeError):
|
|
args = {}
|
|
call: Dict = {"function": {"name": fn.get("name", ""), "arguments": args or {}}}
|
|
if tc.get("id"):
|
|
call["id"] = tc["id"]
|
|
new_calls.append(call)
|
|
nm = dict(m)
|
|
nm["tool_calls"] = new_calls
|
|
out.append(nm)
|
|
return out
|
|
|
|
|
|
def _build_ollama_payload(
    model: str,
    messages: List[Dict],
    temperature: float,
    max_tokens: int,
    stream: bool = False,
    tools: Optional[List[Dict]] = None,
    num_ctx: Optional[int] = None,
) -> Dict:
    """Assemble the JSON body for Ollama's /api/chat endpoint.

    ``num_ctx`` is the input context window. Ollama falls back to 2048 when
    the option is omitted, which silently truncates models that advertise a
    larger window and over-allocates for models with a smaller one. The
    discovered context length should be passed via ``num_ctx``; it is only
    emitted when trusted (positive and not the ``DEFAULT_CONTEXT`` fallback),
    so unknown models are not guessed at while known ones get their real
    window — even when that is below 2048.

    ``options`` and ``tools`` are only included when non-empty.
    """
    request: Dict = {
        "model": model,
        "messages": _ollama_normalize_tool_messages(messages),
        "stream": stream,
    }

    tuning: Dict = {}
    if temperature is not None:
        tuning["temperature"] = temperature
    if max_tokens and max_tokens > 0:
        tuning["num_predict"] = max_tokens
    trusted_ctx = num_ctx is not None and num_ctx > 0 and num_ctx != DEFAULT_CONTEXT
    if trusted_ctx:
        tuning["num_ctx"] = num_ctx

    if tuning:
        request["options"] = tuning
    if tools:
        request["tools"] = tools
    return request
|
|
|
|
|
|
def _parse_ollama_response(data: dict) -> str:
|
|
message = data.get("message") or {}
|
|
return message.get("content") or data.get("response") or ""
|
|
|
|
|
|
def _host_match(url: str, *domains: str) -> bool:
|
|
"""Return True if url's hostname equals any of `domains` or is a subdomain of one.
|
|
|
|
Used by helpers that want "is this Anthropic?" / "is this OpenRouter?"
|
|
style checks. Prefer this over substring matching on the URL: the
|
|
substring form gives wrong answers for unrelated paths or query strings
|
|
that happen to contain the domain text.
|
|
"""
|
|
if not url:
|
|
return False
|
|
try:
|
|
# rstrip(".") so a fully-qualified host with a trailing dot
|
|
# ("api.anthropic.com.") still matches "anthropic.com".
|
|
host = (urlparse(url).hostname or "").lower().rstrip(".")
|
|
except Exception:
|
|
return False
|
|
if not host:
|
|
return False
|
|
return any(host == d or host.endswith("." + d) for d in domains)
|
|
|
|
|
|
def _detect_provider(url: str) -> str:
    """Classify a configured endpoint URL by API provider.

    Matching is done on the hostname (exact or subdomain), never on a
    substring, so a URL that merely mentions a provider's domain in its path
    or query — or a look-alike host such as ``anthropic.com.example`` — is
    not misclassified. Hosts that match nothing are treated as
    OpenAI-compatible, which is what the majority of providers implement.

    Returns one of "ollama", "anthropic", "openrouter", "groq", "copilot",
    "openai".
    """
    if _is_ollama_native_url(url):
        return "ollama"
    by_domain = (
        ("anthropic.com", "anthropic"),
        ("openrouter.ai", "openrouter"),
        ("groq.com", "groq"),
    )
    for domain, provider in by_domain:
        if _host_match(url, domain):
            return provider
    from src.copilot import is_copilot_base
    return "copilot" if is_copilot_base(url) else "openai"
|
|
|
|
|
|
def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str, str]:
|
|
h = {"Content-Type": "application/json"}
|
|
if isinstance(headers, dict):
|
|
h.update(headers)
|
|
if provider == "openrouter":
|
|
h.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
|
|
h.setdefault("X-OpenRouter-Title", "Odysseus")
|
|
if provider == "copilot":
|
|
# Ensure the Copilot-required headers are present even when the caller
|
|
# didn't pass pre-built headers (e.g. model listing). build_headers()
|
|
# already injects these for the live chat path; setdefault keeps any
|
|
# request-specific values (x-initiator/vision) the caller set.
|
|
from src.copilot import copilot_headers
|
|
for k, v in copilot_headers(None).items():
|
|
h.setdefault(k, v)
|
|
return h
|
|
|
|
|
|
def _provider_label(url: str) -> str:
    """Return a display name for the provider behind ``url``, for error messages.

    Known hosts map to a brand name; loopback hosts become "local endpoint";
    any other URL yields its lowercased hostname; an empty or unparseable URL
    yields the generic "provider".
    """
    if not url:
        return "provider"

    hosted_first = (
        (("anthropic.com",), "Anthropic"),
        (("ollama.com",), "Ollama Cloud"),
        (("x.ai",), "xAI"),
        (("openai.com",), "OpenAI"),
        (("openrouter.ai",), "OpenRouter"),
        (("groq.com",), "Groq"),
    )
    for domains, label in hosted_first:
        if _host_match(url, *domains):
            return label

    from src.copilot import is_copilot_base
    if is_copilot_base(url):
        return "GitHub Copilot"

    hosted_rest = (
        (("mistral.ai",), "Mistral"),
        (("deepseek.com",), "DeepSeek"),
        (("googleapis.com",), "Google"),
        (("together.xyz", "together.ai"), "Together"),
        (("fireworks.ai",), "Fireworks"),
    )
    for domains, label in hosted_rest:
        if _host_match(url, *domains):
            return label

    if _is_ollama_native_url(url):
        return "Ollama"

    try:
        hostname = (urlparse(url).hostname or "").lower()
    except Exception:
        return "provider"
    if hostname in {"localhost", "127.0.0.1", "::1", "0.0.0.0"}:
        return "local endpoint"
    return hostname or "provider"
|
|
|
|
|
|
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
    """Render an upstream HTTP failure as a sentence a user can act on.

    Auth failures (401/403) read like 'xAI rejected the API key', so the UI
    no longer shows raw JSON such as '{"error":{"message":"User not found."}}'.
    404, 429 and 5xx get their own wording; everything else falls back to
    '<provider> returned HTTP <status>'.
    """
    if isinstance(body, bytes):
        try:
            body = body.decode("utf-8", errors="replace")
        except Exception:
            body = str(body)

    provider = _provider_label(url)

    # Best effort: extract a human-readable message from a JSON error body.
    detail = ""
    try:
        parsed = json.loads(body) if body else {}
        if isinstance(parsed, dict):
            err = parsed.get("error") or parsed
            if isinstance(err, dict):
                detail = (err.get("message") or err.get("detail") or "").strip()
            elif isinstance(err, str):
                detail = err.strip()
    except Exception:
        # Not JSON (or an unexpected shape): show a bounded slice of the raw body.
        detail = (body or "").strip()[:240]

    if status in (401, 403):
        if status == 403:
            headline = f"{provider} denied access (403)"
        else:
            headline = f"{provider} rejected the API key"
        reason = f" — {detail}" if detail else ""
        return f"{headline}{reason}. Check Model Endpoints → {provider} and re-paste the key."

    if status == 404:
        extra = f" ({detail})" if detail else ""
        return f"{provider} returned 404 — check the base URL and model name." + extra

    trailing = f" {detail}" if detail else ""
    if status == 429:
        return f"{provider} rate-limited the request (429)." + trailing
    if status >= 500:
        return f"{provider} is having an outage (HTTP {status})." + trailing

    return f"{provider} returned HTTP {status}" + (f": {detail}" if detail else "")
|
|
|
|
# Models that require max_completion_tokens instead of max_tokens
|
|
_MAX_COMPLETION_TOKENS_MODELS = {"o1", "o3", "o4", "gpt-4.5", "gpt-5"}
|
|
|
|
def _uses_max_completion_tokens(model: str) -> bool:
|
|
"""Check if a model requires max_completion_tokens instead of max_tokens."""
|
|
if not model:
|
|
return False
|
|
m = model.lower()
|
|
return any(m.startswith(p) or f"/{p}" in m for p in _MAX_COMPLETION_TOKENS_MODELS)
|
|
|
|
# OpenAI reasoning models (o1, o3, o4, gpt-5 families) only accept the default
|
|
# temperature. Sending any explicit value — even 0.0 — returns HTTP 400
|
|
# ("Only the default (1) value is supported"). That otherwise breaks chat when a
|
|
# preset sets a non-default temperature, and makes endpoint probing report a
|
|
# perfectly good model as failing. For these models we omit the field and let
|
|
# the API use its required default. (gpt-4.5 is intentionally excluded — it is
|
|
# not a reasoning model and accepts temperature normally.)
|
|
_FIXED_TEMPERATURE_MODELS = ("o1", "o3", "o4", "gpt-5")
|
|
|
|
def _restricts_temperature(model: str) -> bool:
|
|
"""Check if a model rejects any non-default temperature."""
|
|
if not model:
|
|
return False
|
|
m = model.lower()
|
|
return any(m.startswith(p) or f"/{p}" in m for p in _FIXED_TEMPERATURE_MODELS)
|
|
|
|
# Models that support structured thinking — may output </think> without opening tag
|
|
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap", "gemma")
|
|
|
|
def _supports_thinking(model: str) -> bool:
|
|
"""Check if model supports structured thinking output."""
|
|
if not model:
|
|
return False
|
|
m = model.lower()
|
|
return any(p in m for p in _THINKING_MODEL_PATTERNS)
|
|
|
|
def _convert_openai_content_to_anthropic(content):
|
|
"""Convert OpenAI multimodal content blocks to Anthropic format.
|
|
|
|
Converts image_url blocks (data URI) → Anthropic image blocks.
|
|
Passes text blocks through unchanged.
|
|
"""
|
|
if not isinstance(content, list):
|
|
return content
|
|
converted = []
|
|
for block in content:
|
|
if not isinstance(block, dict):
|
|
converted.append(block)
|
|
continue
|
|
if block.get("type") == "image_url":
|
|
url = (block.get("image_url") or {}).get("url", "")
|
|
# Parse data URI: data:image/<fmt>;base64,<data>
|
|
if url.startswith("data:"):
|
|
try:
|
|
header, b64_data = url.split(",", 1)
|
|
media_type = header.split(";")[0].replace("data:", "")
|
|
except (ValueError, IndexError):
|
|
continue
|
|
converted.append({
|
|
"type": "image",
|
|
"source": {
|
|
"type": "base64",
|
|
"media_type": media_type,
|
|
"data": b64_data,
|
|
},
|
|
})
|
|
else:
|
|
# External URL — use Anthropic's URL source
|
|
converted.append({
|
|
"type": "image",
|
|
"source": {"type": "url", "url": url},
|
|
})
|
|
elif block.get("type") == "text":
|
|
converted.append(block)
|
|
else:
|
|
converted.append(block)
|
|
return converted
|
|
|
|
|
|
def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=False, tools=None):
|
|
"""Convert OpenAI-style messages to Anthropic format."""
|
|
system_parts = []
|
|
chat_messages = []
|
|
for m in messages:
|
|
if m.get("role") == "system":
|
|
system_parts.append(m["content"])
|
|
elif m.get("role") == "tool":
|
|
# Convert OpenAI tool result to Anthropic format
|
|
chat_messages.append({
|
|
"role": "user",
|
|
"content": [{
|
|
"type": "tool_result",
|
|
"tool_use_id": m.get("tool_call_id", ""),
|
|
"content": m.get("content", ""),
|
|
}],
|
|
})
|
|
elif m.get("role") == "assistant" and isinstance(m.get("tool_calls"), list):
|
|
# Convert OpenAI assistant tool_calls to Anthropic format
|
|
content = []
|
|
if m.get("content"):
|
|
content.append({"type": "text", "text": m["content"]})
|
|
for tc in m["tool_calls"]:
|
|
fn = tc.get("function") or {}
|
|
args_str = fn.get("arguments") or "{}"
|
|
try:
|
|
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
|
except (json.JSONDecodeError, TypeError):
|
|
args = {}
|
|
content.append({
|
|
"type": "tool_use",
|
|
"id": tc.get("id", ""),
|
|
"name": fn.get("name", ""),
|
|
"input": args,
|
|
})
|
|
chat_messages.append({"role": "assistant", "content": content})
|
|
else:
|
|
# Convert multimodal content (image_url → image) for Anthropic
|
|
content = _convert_openai_content_to_anthropic(m["content"])
|
|
chat_messages.append({"role": m["role"], "content": content})
|
|
# Anthropic only accepts temperature in [0.0, 1.0] and 400s on anything above
|
|
# 1.0. Clamp here (in the Anthropic builder only) so presets/sliders that use
|
|
# the wider OpenAI 0.0-2.0 range — e.g. the shipped "Nietzsche" preset at 1.2
|
|
# — don't hard-break every Claude request. OpenAI's own path is left untouched.
|
|
if temperature is not None:
|
|
temperature = max(0.0, min(temperature, 1.0))
|
|
payload = {
|
|
"model": model,
|
|
"messages": chat_messages,
|
|
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
|
|
"temperature": temperature,
|
|
}
|
|
if system_parts:
|
|
system_text = "\n\n".join(system_parts)
|
|
# Send `system` as a structured text block so we can attach a prompt-cache
|
|
# breakpoint. The agent loop re-sends this same large prefix every round;
|
|
# caching it makes Anthropic re-read it from cache (~90% cheaper, lower TTFB)
|
|
# instead of re-billing it. Skip caching tiny one-off prompts, where the
|
|
# cache-WRITE premium wouldn't pay back (no reuse). Presence of `tools`
|
|
# means an agentic/multi-round call, where the prefix is always reused.
|
|
system_block = {"type": "text", "text": system_text}
|
|
if tools or len(system_text) > 4000:
|
|
system_block["cache_control"] = {"type": "ephemeral"}
|
|
payload["system"] = [system_block]
|
|
if stream:
|
|
payload["stream"] = True
|
|
# Convert OpenAI-format tools to Anthropic format
|
|
if tools:
|
|
anthropic_tools = []
|
|
for t in tools:
|
|
if t.get("type") == "function":
|
|
fn = t["function"]
|
|
anthropic_tools.append({
|
|
"name": fn["name"],
|
|
"description": fn.get("description", ""),
|
|
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
|
})
|
|
if anthropic_tools:
|
|
# Cache the tool schemas too — they're stable for the whole agent run.
|
|
# The breakpoint caches all tool defs preceding it in the request.
|
|
anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"}
|
|
payload["tools"] = anthropic_tools
|
|
return payload
|
|
|
|
def _build_anthropic_headers(headers):
|
|
"""Convert Bearer auth to x-api-key for Anthropic."""
|
|
h = {"Content-Type": "application/json", "anthropic-version": "2023-06-01"}
|
|
if headers:
|
|
for k, v in headers.items():
|
|
if k.lower() == "authorization" and isinstance(v, str) and v.startswith("Bearer "):
|
|
h["x-api-key"] = v[7:]
|
|
else:
|
|
h[k] = v
|
|
return h
|
|
|
|
def _parse_anthropic_response(data: dict) -> str:
|
|
"""Extract text from an Anthropic response.
|
|
|
|
The Messages API `content` is an array that can hold more than one text
|
|
block (e.g. text split around a tool_use block, or citation-segmented
|
|
text). Concatenate them all instead of returning only the first, which
|
|
silently dropped the rest of the reply.
|
|
"""
|
|
return "".join(
|
|
block.get("text", "")
|
|
for block in data.get("content", [])
|
|
if isinstance(block, dict) and block.get("type") == "text"
|
|
)
|
|
|
|
|
|
def _as_content_blocks(content) -> List[Dict]:
|
|
"""Coerce a message `content` into a list of content blocks.
|
|
|
|
A list (multimodal: text + image parts) passes through; a non-empty string
|
|
becomes a single text block; None/empty yields no blocks. Used when merging
|
|
consecutive user messages so multimodal content isn't str()-ed away.
|
|
"""
|
|
if isinstance(content, list):
|
|
return content
|
|
if content:
|
|
return [{"type": "text", "text": str(content)}]
|
|
return []
|
|
|
|
|
|
def _sanitize_llm_messages(messages: List[Dict]) -> List[Dict]:
|
|
"""Strip Odysseus-only metadata before sending messages to providers.
|
|
|
|
Per the OpenAI chat format: user/system messages must have content; a tool
|
|
message needs content + tool_call_id; an assistant message may carry content,
|
|
tool_calls, or both. The old guard required content on every message, which
|
|
dropped a valid assistant message that has only tool_calls — e.g. the
|
|
follow-up message _append_tool_results builds for a no-prose native tool call
|
|
(content=None, since Gemini/Ollama reject tool_calls alongside ""). Dropping
|
|
it leaves the tool result dangling and breaks the next round.
|
|
"""
|
|
allowed = {"role", "content", "name", "tool_call_id", "tool_calls", "function_call"}
|
|
cleaned = []
|
|
for msg in messages or []:
|
|
if not isinstance(msg, dict):
|
|
continue
|
|
item = {k: v for k, v in msg.items() if k in allowed and v is not None}
|
|
role = item.get("role")
|
|
if not role:
|
|
continue
|
|
if role == "assistant":
|
|
# Re-add an explicit content=None when the message is tool-calls-only
|
|
# (the None was stripped above) so the provider gets the spec-correct
|
|
# `content: null`, not an omitted key.
|
|
if "content" not in item and item.get("tool_calls"):
|
|
item["content"] = None
|
|
if "content" in item or item.get("tool_calls"):
|
|
cleaned.append(item)
|
|
elif role == "tool":
|
|
if "content" in item and "tool_call_id" in item:
|
|
cleaned.append(item)
|
|
elif "content" in item:
|
|
cleaned.append(item)
|
|
|
|
# Repair tool-call adjacency before sending to any OpenAI-compatible
|
|
# provider. Trimming/compaction/retries can leave `role:"tool"` messages
|
|
# without their immediately-preceding assistant `tool_calls` parent, which
|
|
# DeepSeek rejects with:
|
|
# "Messages with role 'tool' must be a response to a preceding message with
|
|
# 'tool_calls'". Also strip unanswered assistant tool_calls; some providers
|
|
# reject those as incomplete conversations.
|
|
repaired: List[Dict] = []
|
|
i = 0
|
|
while i < len(cleaned):
|
|
msg = cleaned[i]
|
|
role = msg.get("role")
|
|
|
|
if role == "tool":
|
|
# Orphan tool result. There is no valid assistant tool_calls parent
|
|
# immediately before this batch, so it cannot be sent.
|
|
logger.debug("Dropping orphan tool message before provider request")
|
|
i += 1
|
|
continue
|
|
|
|
tool_calls = msg.get("tool_calls") if role == "assistant" else None
|
|
if not tool_calls:
|
|
repaired.append(msg)
|
|
i += 1
|
|
continue
|
|
|
|
call_ids = [
|
|
str(tc.get("id"))
|
|
for tc in tool_calls
|
|
if isinstance(tc, dict) and tc.get("id")
|
|
]
|
|
expected = set(call_ids)
|
|
answered_ids = []
|
|
tool_batch = []
|
|
j = i + 1
|
|
while j < len(cleaned) and cleaned[j].get("role") == "tool":
|
|
tid = str(cleaned[j].get("tool_call_id") or "")
|
|
if tid in expected and tid not in answered_ids:
|
|
answered_ids.append(tid)
|
|
tool_batch.append(cleaned[j])
|
|
else:
|
|
logger.debug("Dropping unmatched/duplicate tool message before provider request")
|
|
j += 1
|
|
|
|
if not tool_batch:
|
|
plain = {k: v for k, v in msg.items() if k != "tool_calls"}
|
|
if (plain.get("content") or "").strip():
|
|
repaired.append(plain)
|
|
else:
|
|
logger.debug("Dropping unanswered assistant tool_calls before provider request")
|
|
i = j
|
|
continue
|
|
|
|
answered = set(answered_ids)
|
|
pruned_calls = [
|
|
tc for tc in tool_calls
|
|
if isinstance(tc, dict) and str(tc.get("id")) in answered
|
|
]
|
|
fixed = dict(msg)
|
|
fixed["tool_calls"] = pruned_calls
|
|
if "content" not in fixed:
|
|
fixed["content"] = None
|
|
repaired.append(fixed)
|
|
repaired.extend(tool_batch)
|
|
if len(pruned_calls) != len(tool_calls):
|
|
logger.debug("Pruned unanswered assistant tool_calls before provider request")
|
|
i = j
|
|
|
|
# Merge consecutive user messages to satisfy strict role alternation
|
|
# requirements after invalid tool-call fragments have been removed.
|
|
merged: List[Dict] = []
|
|
for item in repaired:
|
|
if not merged:
|
|
merged.append(item)
|
|
continue
|
|
|
|
last = merged[-1]
|
|
if last.get("role") == "user" and item.get("role") == "user":
|
|
last_copy = dict(last)
|
|
lc = last_copy.get("content")
|
|
ic = item.get("content")
|
|
if isinstance(lc, list) or isinstance(ic, list):
|
|
# Preserve multimodal content blocks (e.g. an image part) by
|
|
# concatenating the block lists. str()-ing a list turned an
|
|
# image message into its Python repr and dropped the image.
|
|
merged_blocks = _as_content_blocks(lc) + _as_content_blocks(ic)
|
|
if merged_blocks:
|
|
last_copy["content"] = merged_blocks
|
|
else:
|
|
last_copy.pop("content", None)
|
|
else:
|
|
last_str = str(lc) if lc is not None else ""
|
|
item_str = str(ic) if ic is not None else ""
|
|
new_content = "\n\n".join(part for part in (last_str, item_str) if part)
|
|
if new_content:
|
|
last_copy["content"] = new_content
|
|
else:
|
|
last_copy.pop("content", None)
|
|
merged[-1] = last_copy
|
|
else:
|
|
merged.append(item)
|
|
|
|
return merged
|
|
|
|
def _normalize_anthropic_url(url: str) -> str:
|
|
"""Ensure Anthropic URL points to /v1/messages."""
|
|
url = url.rstrip("/")
|
|
if url.endswith("/v1/messages"):
|
|
return url
|
|
if url.endswith("/v1"):
|
|
return url + "/messages"
|
|
return url + "/v1/messages"
|
|
|
|
|
|
def _model_list_base(url: str) -> str:
|
|
"""Normalize model/chat URLs to the configured endpoint base."""
|
|
base = (url or "").strip().rstrip("/")
|
|
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
|
if base.endswith(suffix):
|
|
base = base[: -len(suffix)].rstrip("/")
|
|
for suffix in ("/chat", "/tags", "/generate"):
|
|
if base.endswith("/api" + suffix):
|
|
base = base[: -len(suffix)].rstrip("/")
|
|
return base
|
|
|
|
|
|
def _parse_model_cache(raw) -> List[str]:
|
|
if not raw:
|
|
return []
|
|
try:
|
|
models = json.loads(raw) if isinstance(raw, str) else raw
|
|
except Exception:
|
|
return []
|
|
if not isinstance(models, list):
|
|
return []
|
|
out = []
|
|
seen = set()
|
|
for item in models:
|
|
mid = str(item or "").strip()
|
|
if not mid or mid in seen:
|
|
continue
|
|
out.append(mid)
|
|
seen.add(mid)
|
|
return out
|
|
|
|
|
|
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
    """Return the cached, non-hidden model ids of the configured endpoint for ``endpoint_url``.

    Enabled ``ModelEndpoint`` rows are compared by normalized base URL; the
    first match that has a non-empty model cache wins. Every failure mode
    (database layer not importable, query error, no match) yields ``[]``.
    """
    wanted_base = _model_list_base(endpoint_url)
    if not wanted_base:
        return []
    try:
        from src.database import SessionLocal, ModelEndpoint
    except Exception:
        return []

    session = SessionLocal()
    try:
        enabled = session.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
        for endpoint in enabled:
            if _model_list_base(getattr(endpoint, "base_url", "")) != wanted_base:
                continue
            raw_models = getattr(endpoint, "cached_models", None) or getattr(endpoint, "models", None)
            visible = _parse_model_cache(raw_models)
            if not visible:
                continue
            hidden = set(_parse_model_cache(getattr(endpoint, "hidden_models", None)))
            return [model_id for model_id in visible if model_id not in hidden]
    except Exception:
        return []
    finally:
        # Closing is best effort; a failure here must not mask the result.
        try:
            session.close()
        except Exception:
            pass
    return []
|
|
|
|
|
|
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
    """List available model IDs from an endpoint.

    Resolution order:
      1. the cached list of a matching configured endpoint, if non-empty;
      2. the static ``ANTHROPIC_MODELS`` list for Anthropic URLs;
      3. a live GET of the provider's model listing (``/tags`` under the
         Ollama API root, otherwise ``/chat/completions`` rewritten to
         ``/models``);
      4. for URLs that look like Ollama, a last-resort GET of ``/api/tags``.

    Args:
        base_chat_url: Chat-completions (or equivalent) URL of the endpoint.
        timeout: Timeout passed to each HTTP probe.
        headers: Extra request headers. A JSON-encoded string is decoded
            first, mirroring what ``llm_call`` tolerates.

    Returns:
        The model IDs found, or ``[]`` if every probe failed. Never raises for
        network or schema errors.
    """
    cached = _configured_cached_model_ids(base_chat_url)
    if cached:
        return cached

    provider = _detect_provider(base_chat_url)
    if provider == "anthropic":
        return list(ANTHROPIC_MODELS)

    # Headers sometimes arrive double-encoded as a JSON string; previously
    # that made dict.update() raise and the whole probe silently returned [].
    if isinstance(headers, str):
        try:
            headers = json.loads(headers)
        except ValueError:
            headers = None
    request_headers: Dict = {}
    if isinstance(headers, dict):
        request_headers.update(headers)

    try:
        if provider == "ollama":
            models_url = _ollama_api_root(base_chat_url) + "/tags"
        else:
            models_url = base_chat_url.replace("/chat/completions", "/models")
        r = httpx.get(models_url, headers=request_headers, timeout=timeout)
        r.raise_for_status()
        data = r.json()
        # OpenAI-style listing: {"data": [{"id": ...}, ...]}
        model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
        if not model_ids:
            # Ollama-style listing: {"models": [{"name"|"model": ...}, ...]}
            model_ids = [
                m.get("name") or m.get("model")
                for m in (data.get("models") or [])
                if m.get("name") or m.get("model")
            ]
        return model_ids
    except Exception:
        # Best-effort discovery: fall through to the Ollama heuristic below.
        pass

    try:
        if ":11434" in base_chat_url or "ollama" in base_chat_url.lower():
            root = base_chat_url.replace("/v1/chat/completions", "").replace("/chat/completions", "").rstrip("/")
            # Forward the caller's headers so an Ollama behind an auth proxy
            # is probed with the same credentials as the primary request.
            r = httpx.get(root + "/api/tags", headers=request_headers, timeout=timeout)
            r.raise_for_status()
            return [m.get("name") or m.get("model") for m in (r.json().get("models") or []) if m.get("name") or m.get("model")]
    except Exception:
        pass
    return []
|
|
|
|
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
                       headers: Optional[Dict] = None) -> Optional[str]:
    """Normalize a model ID to match available models.

    Args:
        endpoint_url: Endpoint whose model list is consulted.
        requested: Model ID asked for by the caller.
        timeout: Timeout forwarded to ``list_model_ids``.
        headers: Optional request headers forwarded to ``list_model_ids`` so
            endpoints that require authentication for model listing can be
            queried. Defaults to ``None`` (previous behaviour).

    Returns:
        ``requested`` if it is available verbatim; otherwise the first
        available ID whose last ``/``-separated component equals that of
        ``requested``; otherwise ``None`` (also when ``requested`` is empty or
        no models could be listed).
    """
    # An empty/None request can never match and used to crash on .rstrip().
    if not requested:
        return None

    avail = list_model_ids(endpoint_url, timeout, headers)
    if not avail:
        return None
    if requested in avail:
        return requested

    # Model IDs use "/" as separator regardless of the host OS, so use the
    # POSIX flavour explicitly instead of the platform-dependent os.path.
    import posixpath as _pp

    req_base = _pp.basename(requested.rstrip("/"))
    for a in avail:
        if _pp.basename(a.rstrip("/")) == req_base:
            return a
    return None
|
|
|
|
def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
             max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
             timeout: int = LLMConfig.DEFAULT_TIMEOUT, prompt_type: Optional[str] = None) -> str:
    """Synchronous, non-streaming LLM call.

    Sanitizes the messages, merges all system messages into a single leading
    one, consults the response cache, then POSTs a provider-specific payload
    (Anthropic, Ollama, or OpenAI-compatible incl. Copilot) and returns the
    assistant text.

    Args:
        url: Endpoint URL; also used to detect the provider.
        model: Model identifier sent upstream.
        messages: Chat messages (dicts with ``role`` / ``content``).
        temperature: Sampling temperature; omitted for models that restrict it.
        max_tokens: Output token cap; only sent when greater than zero.
        headers: Extra request headers; a JSON-encoded string is tolerated.
        timeout: HTTP timeout passed to ``httpx.post``.
        prompt_type: Accepted for signature compatibility; not used here.

    Returns:
        The response text (possibly an empty string).

    Raises:
        HTTPException: 502 when the request fails, the upstream answers with a
            non-success status, the body is not JSON, or the JSON does not
            have the expected shape.
    """
    provider = _detect_provider(url)
    h = _provider_headers(provider)
    # Tolerate headers that arrive as a JSON string (some sessions stored them
    # double-encoded) — otherwise h.update() throws "dictionary update sequence
    # element #0 has length 1; 2 is required".
    if isinstance(headers, str):
        try:
            headers = json.loads(headers)
        except Exception:
            headers = None
    if isinstance(headers, dict):
        h.update(headers)

    messages_copy = _sanitize_llm_messages(messages)

    # Consolidate multiple system messages into one at the start.
    sys_parts = []
    non_sys = []
    for m in messages_copy:
        if m.get("role") == "system":
            sys_parts.append(m.get('content') or '')
        else:
            non_sys.append(m)
    if sys_parts:
        messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
    else:
        messages_copy = non_sys

    cache_key = _get_cache_key(url, model, messages_copy, temperature, max_tokens)
    cached_response = _get_cached_response(cache_key)
    if cached_response:
        logger.debug(f"Returning cached response for key: {cache_key}")
        return cached_response

    if provider == "anthropic":
        target_url = _normalize_anthropic_url(url)
        # Anthropic uses its own header set, replacing the generic one.
        h = _build_anthropic_headers(headers)
        payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
    elif provider == "ollama":
        target_url = _normalize_ollama_url(url)
        payload = _build_ollama_payload(
            model, messages_copy, temperature, max_tokens,
            stream=False, num_ctx=get_context_length(url, model),
        )
    else:
        target_url = url
        if provider == "copilot":
            from src.copilot import apply_request_headers
            apply_request_headers(h, messages_copy)
        payload = {
            "model": model,
            "messages": messages_copy,
            "temperature": temperature,
        }
        if _restricts_temperature(model):
            payload.pop("temperature", None)
        if max_tokens and max_tokens > 0:
            tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
            payload[tok_key] = max_tokens

    try:
        note_model_activity(target_url, model)
        r = httpx.post(target_url, headers=h, json=payload, timeout=timeout)
    except Exception as e:
        raise HTTPException(502, f"POST {target_url} failed: {e}") from e
    if not r.is_success:
        raise HTTPException(502, f"Upstream {target_url} -> {r.status_code}: {r.text}")

    # A 2xx body that is not JSON (HTML error page, truncated body, ...) used
    # to escape as a bare ValueError; report it as a bad-gateway error instead.
    try:
        data = r.json()
    except ValueError as e:
        raise HTTPException(502, f"Non-JSON response from {target_url}: {r.text[:400]}") from e

    try:
        if provider == "anthropic":
            response = _parse_anthropic_response(data)
        elif provider == "ollama":
            response = _parse_ollama_response(data)
        else:
            msg = data["choices"][0]["message"]
            response = msg.get("content") or msg.get("reasoning_content") or ""
        _set_cached_response(cache_key, response)
        return response
    except Exception as e:
        raise HTTPException(502, f"Unexpected schema from {target_url}: {str(data)[:400]}") from e
|
|
|
|
|
|
def _dedupe_candidates(candidates):
|
|
"""Filter malformed entries and drop a later repeat of an already-seen
|
|
``(url, model)`` route, preserving order (first occurrence wins).
|
|
|
|
The chain is the primary target followed by the configured fallbacks, so a
|
|
fallback that repeats the session's current model — a common misconfiguration,
|
|
since callers prepend the live ``(url, model)`` to ``default_model_fallbacks``
|
|
— would otherwise make the chain re-attempt the very route that just failed:
|
|
a wasted round-trip plus a spurious ``fallback`` notice for a switch that did
|
|
not happen. Headers are not part of the key; the first tuple (with its
|
|
headers) is the one kept.
|
|
"""
|
|
seen = set()
|
|
out = []
|
|
for c in candidates or []:
|
|
if not c or not c[0] or not c[1]:
|
|
continue
|
|
key = (c[0], c[1])
|
|
if key in seen:
|
|
continue
|
|
seen.add(key)
|
|
out.append(c)
|
|
return out
|
|
|
|
|
|
def llm_call_with_fallback(candidates, messages, **kwargs) -> str:
    """Sync `llm_call` with an ordered fallback chain.

    `candidates` is a list of (url, model, headers). They are de-duplicated
    and tried in order; the first one that returns without an exception wins.
    Any exception raised by `llm_call` falls through to the next candidate.

    Args:
        candidates: Ordered ``(url, model, headers)`` tuples, primary first.
        messages: Chat messages forwarded to `llm_call`.
        **kwargs: Extra keyword arguments forwarded to `llm_call`.

    Returns:
        The response text of the first candidate that succeeds.

    Raises:
        HTTPException: 503 when no usable candidate is configured.
        Exception: The error of the last candidate when all of them fail.
    """
    cands = _dedupe_candidates(candidates)
    if not cands:
        raise HTTPException(503, "No model endpoint configured")

    last_err = None
    last_index = len(cands) - 1
    for i, (url, model, headers) in enumerate(cands):
        try:
            return llm_call(url, model, messages, headers=headers, **kwargs)
        except Exception as e:
            last_err = e
            tag = "primary" if i == 0 else "candidate"
            # Only promise another attempt when one is actually left.
            outcome = "trying next" if i < last_index else "no candidates left"
            logger.warning("[fallback] %s %s failed (%s); %s", tag, model, type(e).__name__, outcome)
    raise last_err if last_err else HTTPException(503, "All fallback candidates failed")
|
|
|
|
|
|
async def llm_call_async_with_fallback(candidates, messages, **kwargs) -> str:
    """Async `llm_call_async` with an ordered fallback chain.

    `candidates` is a list of (url, model, headers). They are de-duplicated
    and awaited in order; the first one that returns without an exception
    wins. Any exception raised by `llm_call_async` falls through to the next
    candidate.

    Args:
        candidates: Ordered ``(url, model, headers)`` tuples, primary first.
        messages: Chat messages forwarded to `llm_call_async`.
        **kwargs: Extra keyword arguments forwarded to `llm_call_async`.

    Returns:
        The response text of the first candidate that succeeds.

    Raises:
        HTTPException: 503 when no usable candidate is configured.
        Exception: The error of the last candidate when all of them fail.
    """
    cands = _dedupe_candidates(candidates)
    if not cands:
        raise HTTPException(503, "No model endpoint configured")

    last_err = None
    last_index = len(cands) - 1
    for i, (url, model, headers) in enumerate(cands):
        try:
            return await llm_call_async(url, model, messages, headers=headers, **kwargs)
        except Exception as e:
            last_err = e
            tag = "primary" if i == 0 else "candidate"
            # Only promise another attempt when one is actually left.
            outcome = "trying next" if i < last_index else "no candidates left"
            logger.warning("[fallback] %s %s failed (%s); %s", tag, model, type(e).__name__, outcome)
    raise last_err if last_err else HTTPException(503, "All fallback candidates failed")
|
|
|
|
|
|
async def llm_call_async(
    url: str,
    model: str,
    messages: List[Dict],
    temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
    max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS,
    headers: Optional[Dict] = None,
    timeout: int = LLMConfig.STREAM_TIMEOUT,
    max_retries: int = LLMConfig.MAX_RETRIES,
    prompt_type: Optional[str] = None
) -> str:
    """Asynchronous, non-streaming LLM call with retries and a dead-host cooldown.

    Sanitizes the messages, merges all system messages into a single leading
    one, consults the response cache, then POSTs a provider-specific payload
    through the shared HTTP client.

    Args:
        url: Endpoint URL; also used to detect the provider.
        model: Model identifier sent upstream.
        messages: Chat messages (dicts with ``role`` / ``content``).
        temperature: Sampling temperature; omitted for models that restrict it.
        max_tokens: Output token cap; only sent when greater than zero.
        headers: Extra request headers; a JSON-encoded string is tolerated.
        timeout: Read timeout in seconds (connect/write/pool are fixed).
        max_retries: Maximum number of attempts; values below 1 are treated
            as 1 so the call is always attempted once.
        prompt_type: Accepted for signature compatibility; not used here.

    Returns:
        The response text (possibly an empty string).

    Raises:
        HTTPException: 503 when the host is in cooldown or unreachable; the
            upstream status for a non-success response once retries are
            exhausted (429/502/503/504 are retried); 502 for transport
            errors after the last attempt, a non-JSON body, or an unexpected
            response schema.
    """
    provider = _detect_provider(url)

    # Some sessions stored headers double-encoded as a JSON string; decode
    # them here so every provider branch below receives a dict or None.
    if isinstance(headers, str):
        try:
            decoded_headers = json.loads(headers)
        except ValueError:
            decoded_headers = None
        headers = decoded_headers if isinstance(decoded_headers, dict) else None

    messages_copy = _sanitize_llm_messages(messages)

    # Consolidate multiple system messages into one at the start.
    sys_parts = []
    non_sys = []
    for m in messages_copy:
        if m.get("role") == "system":
            sys_parts.append(m.get('content') or '')
        else:
            non_sys.append(m)
    if sys_parts:
        messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
    else:
        messages_copy = non_sys

    cache_key = _get_cache_key(url, model, messages_copy, temperature, max_tokens)
    cached_response = _get_cached_response(cache_key)
    if cached_response:
        logger.debug(f"Returning cached response for key: {cache_key}")
        return cached_response

    if provider == "anthropic":
        target_url = _normalize_anthropic_url(url)
        h = _build_anthropic_headers(headers)
        payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
    elif provider == "ollama":
        target_url = _normalize_ollama_url(url)
        h = {"Content-Type": "application/json"}
        if headers:
            h.update(headers)
        payload = _build_ollama_payload(
            model, messages_copy, temperature, max_tokens,
            stream=False, num_ctx=get_context_length(url, model),
        )
    else:
        target_url = url
        h = _provider_headers(provider, headers)
        if provider == "copilot":
            from src.copilot import apply_request_headers
            apply_request_headers(h, messages_copy)
        payload = {
            "model": model,
            "messages": messages_copy,
            "temperature": temperature,
        }
        if _restricts_temperature(model):
            payload.pop("temperature", None)
        if max_tokens and max_tokens > 0:
            tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
            payload[tok_key] = max_tokens

    if _is_host_dead(target_url):
        raise HTTPException(503, f"Upstream {_host_key(target_url)} marked unreachable (cooldown active)")

    call_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=10.0, pool=5.0)
    # Guarantee at least one attempt: with max_retries <= 0 the loop used to
    # be skipped entirely and the function silently returned None.
    attempts_allowed = max(1, max_retries)
    attempt = 0
    while attempt < attempts_allowed:
        attempt += 1
        start = time.time()
        try:
            note_model_activity(target_url, model)
            client = _get_http_client()
            r = await client.post(target_url, headers=h, json=payload, timeout=call_timeout)
            duration = time.time() - start
            if not r.is_success:
                friendly = _format_upstream_error(r.status_code, r.text, target_url)
                logger.warning(
                    f"LLM async call to {target_url} failed in {duration:.2f}s "
                    f"(attempt {attempt}): HTTP {r.status_code} {friendly}"
                )
                # Rate-limit and gateway-style statuses are worth retrying.
                if r.status_code in (429, 502, 503, 504) and attempt < attempts_allowed:
                    await asyncio.sleep(LLMConfig.RETRY_DELAY)
                    continue
                raise HTTPException(r.status_code, friendly)
            logger.info(f"LLM async call to {target_url} succeeded in {duration:.2f}s (attempt {attempt})")
            _clear_host_dead(target_url)
            # A 2xx body that is not JSON used to escape as a bare ValueError.
            try:
                data = r.json()
            except ValueError as e:
                raise HTTPException(502, f"Non-JSON response from {target_url}: {r.text[:400]}") from e
            try:
                if provider == "anthropic":
                    response = _parse_anthropic_response(data)
                elif provider == "ollama":
                    response = _parse_ollama_response(data)
                else:
                    msg = data["choices"][0]["message"]
                    response = msg.get("content") or msg.get("reasoning_content") or ""
                _set_cached_response(cache_key, response)
                return response
            except Exception as e:
                raise HTTPException(502, f"Unexpected schema from {target_url}: {str(data)[:400]}") from e
        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            _cooled = _mark_host_dead(target_url)
            duration = time.time() - start
            _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
            logger.warning(f"LLM async connect to {target_url} failed after {duration:.2f}s: {e}{_tail}")
            # Once the host is in cooldown further attempts are pointless.
            if _cooled or attempt >= attempts_allowed:
                raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}") from e
            await asyncio.sleep(LLMConfig.RETRY_DELAY)
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            duration = time.time() - start
            logger.warning(f"LLM async call attempt {attempt} failed after {duration:.2f}s: {e}")
            if attempt >= attempts_allowed:
                raise HTTPException(502, f"POST {target_url} failed after {attempts_allowed} attempts: {e}") from e
            await asyncio.sleep(LLMConfig.RETRY_DELAY)
|
|
|
|
async def stream_llm(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
                     max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
                     timeout: int = LLMConfig.STREAM_TIMEOUT, prompt_type: Optional[str] = None,
                     tools: Optional[List[Dict]] = None):
    """Stream LLM responses with improved error handling.

    Async generator. Builds a provider-specific streaming request (Anthropic,
    native Ollama, or OpenAI-compatible incl. Copilot) and translates the
    upstream stream into one common SSE chunk protocol. Transport and
    upstream errors handled here are yielded as ``event: error`` chunks
    instead of being raised.

    Yields SSE chunks:
    - data: {"delta": "text"}                   — text content
    - data: {"delta": "text", "thinking": true} — reasoning / thinking content
    - data: {"type": "tool_call_delta", ...}    — incremental arguments for document tools
    - data: {"type": "tool_calls", ...}         — accumulated native tool calls (before DONE)
    - data: {"type": "usage", ...}              — token usage, when the upstream reports it
    - event: error                              — errors
    - data: [DONE]                              — end of stream

    Args:
        url: Endpoint URL; also used to detect the provider.
        model: Model identifier sent upstream.
        messages: Chat messages (dicts with ``role`` / ``content``).
        temperature: Sampling temperature; omitted for models that restrict it.
        max_tokens: Output token cap; only sent when greater than zero.
        headers: Extra request headers.
        timeout: Read timeout in seconds for the stream.
        prompt_type: Not referenced in this function body.
        tools: Optional native tool schemas forwarded to the provider.
    """
    provider = _detect_provider(url)
    messages_copy = _sanitize_llm_messages(messages)

    # Consolidate multiple system messages into one at the start.
    # Some models (e.g. Qwen3.5) reject system messages that aren't first.
    sys_parts = []
    non_sys = []
    for m in messages_copy:
        if m.get("role") == "system":
            sys_parts.append(m.get('content') or '')
        else:
            non_sys.append(m)
    if sys_parts:
        messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
    else:
        messages_copy = non_sys

    # Build target URL, headers and payload per provider.
    if provider == "anthropic":
        target_url = _normalize_anthropic_url(url)
        h = _build_anthropic_headers(headers)
        payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
    elif provider == "ollama":
        target_url = _normalize_ollama_url(url)
        h = {"Content-Type": "application/json"}
        if headers:
            h.update(headers)
        payload = _build_ollama_payload(
            model, messages_copy, temperature, max_tokens,
            stream=True, tools=tools, num_ctx=get_context_length(url, model),
        )
    else:
        target_url = url
        payload = {
            "model": model,
            "messages": messages_copy,
            "temperature": temperature,
            "stream": True,
        }
        if _restricts_temperature(model):
            payload.pop("temperature", None)
        # stream_options is not sent to openrouter / groq.
        if provider not in {"openrouter", "groq"}:
            payload["stream_options"] = {"include_usage": True}
        if max_tokens and max_tokens > 0:
            tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
            payload[tok_key] = max_tokens
        if tools:
            payload["tools"] = tools
        h = _provider_headers(provider, headers)
        if provider == "copilot":
            from src.copilot import apply_request_headers
            apply_request_headers(h, messages_copy)

    # Short connect timeout: a reachable peer answers SYN in <100ms even on
    # Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
    stream_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=30.0, pool=5.0)

    # Fail fast while the host is in its dead-host cooldown.
    if _is_host_dead(target_url):
        yield f'event: error\ndata: {json.dumps({"error": f"Upstream {_host_key(target_url)} unreachable (cooldown active)", "status": 503})}\n\n'
        return
    note_model_activity(target_url, model)

    # ── Native Ollama streaming ──
    # The response is read as one JSON object per line; tool calls are
    # collected and emitted once the object flagged "done" arrives.
    if provider == "ollama":
        _ollama_tool_calls: List[Dict] = []
        try:
            client = _get_http_client()
            async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
                _clear_host_dead(target_url)
                if r.status_code != 200:
                    raw = (await r.aread()).decode(errors="replace")
                    friendly = _format_upstream_error(r.status_code, raw, target_url)
                    yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
                    return
                async for line in r.aiter_lines():
                    if not line:
                        continue
                    try:
                        j = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    message = j.get("message") or {}
                    thinking = message.get("thinking") or ""
                    if thinking:
                        yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n'
                    content = message.get("content") or ""
                    if content:
                        yield f'data: {json.dumps({"delta": content})}\n\n'
                    for tc in message.get("tool_calls") or []:
                        fn = tc.get("function") or {}
                        if fn.get("name"):
                            # Arguments are re-serialized to a JSON string so the
                            # emitted shape matches the other providers' tool calls.
                            _ollama_tool_calls.append({
                                "id": tc.get("id") or f"call_{len(_ollama_tool_calls)}",
                                "name": fn.get("name") or "",
                                "arguments": json.dumps(fn.get("arguments") or {}),
                            })
                    if j.get("done"):
                        if _ollama_tool_calls:
                            yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
                        if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
                            yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
                        yield "data: [DONE]\n\n"
                        return
                # Stream ended without a "done" object.
                yield "data: [DONE]\n\n"
        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            _cooled = _mark_host_dead(target_url)
            _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
            logger.warning(f"Ollama stream connect to {target_url} failed: {e}{_tail}")
            yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
        except httpx.ReadTimeout:
            yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
        except httpx.NetworkError:
            yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
        except Exception as e:
            logger.error(f"Ollama stream error: {e}")
            yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
        return

    # ── Anthropic streaming ──
    if provider == "anthropic":
        _anth_input_tokens = 0
        _anth_output_tokens = 0
        # Track tool_use blocks: {index: {id, name, arguments_json}}
        _anth_tool_blocks: Dict[int, Dict] = {}
        _anth_block_idx = -1
        _anth_block_type = ""
        try:
            client = _get_http_client()
            async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
                _clear_host_dead(target_url)
                if r.status_code != 200:
                    raw = (await r.aread()).decode(errors="replace")
                    friendly = _format_upstream_error(r.status_code, raw, target_url)
                    yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
                    return
                async for line in r.aiter_lines():
                    # SSE allows "data:value" with no space after the colon
                    # (the space is optional per the spec). Some gateways and
                    # local servers omit it; gating on "data: " dropped their
                    # entire stream.
                    if not line or not line.startswith("data:"):
                        continue
                    data = line[5:].strip()
                    if not data or not data.startswith("{"):
                        continue
                    try:
                        j = json.loads(data)
                        evt = j.get("type", "")
                        if evt == "content_block_start":
                            _anth_block_idx = j.get("index", _anth_block_idx + 1)
                            cb = j.get("content_block") or {}
                            _anth_block_type = cb.get("type", "text")
                            if _anth_block_type == "tool_use":
                                _anth_tool_blocks[_anth_block_idx] = {
                                    "id": cb.get("id") or f"call_{_anth_block_idx}",
                                    "name": cb.get("name") or "",
                                    "arguments": "",
                                }
                        elif evt == "content_block_delta":
                            delta = j.get("delta") or {}
                            delta_type = delta.get("type", "")
                            if delta_type == "text_delta":
                                text = delta.get("text") or ""
                                if text:
                                    yield f'data: {json.dumps({"delta": text})}\n\n'
                            elif delta_type == "input_json_delta":
                                # Accumulate tool arguments JSON
                                idx = j.get("index", _anth_block_idx)
                                if idx in _anth_tool_blocks:
                                    partial = delta.get("partial_json") or ""
                                    _anth_tool_blocks[idx]["arguments"] += partial
                                    # Stream tool arg deltas for doc tools
                                    if partial and _anth_tool_blocks[idx].get("name") in ("create_document", "update_document", "edit_document"):
                                        yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _anth_tool_blocks[idx]["name"], "arg_delta": partial})}\n\n'
                        elif evt == "message_start":
                            _u = j.get("message", {}).get("usage", {})
                            _anth_input_tokens = _u.get("input_tokens", 0)
                            # Surface prompt-cache effectiveness: cache_read > 0 means the
                            # stable system+tools prefix was served from cache this round.
                            _c_read = _u.get("cache_read_input_tokens", 0)
                            _c_write = _u.get("cache_creation_input_tokens", 0)
                            if _c_read or _c_write:
                                logger.info(
                                    "[anthropic-cache] read=%s write=%s fresh_input=%s",
                                    _c_read, _c_write, _anth_input_tokens,
                                )
                        elif evt == "message_delta":
                            _anth_output_tokens = j.get("usage", {}).get("output_tokens", 0)
                        elif evt == "message_stop":
                            # Emit accumulated tool calls in OpenAI-compatible format
                            if _anth_tool_blocks:
                                calls = []
                                for idx in sorted(_anth_tool_blocks):
                                    tb = _anth_tool_blocks[idx]
                                    calls.append({
                                        "id": tb["id"],
                                        "name": tb["name"],
                                        "arguments": tb["arguments"],
                                    })
                                yield f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'
                            if _anth_input_tokens or _anth_output_tokens:
                                yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": _anth_input_tokens, "output_tokens": _anth_output_tokens}})}\n\n'
                            yield "data: [DONE]\n\n"
                            return
                        elif evt == "error":
                            err_msg = j.get("error", {}).get("message", "Unknown error")
                            yield f'event: error\ndata: {json.dumps({"error": err_msg, "status": 400})}\n\n'
                            return
                    except json.JSONDecodeError:
                        continue
                # Stream ended without a message_stop event.
                yield "data: [DONE]\n\n"
        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            _cooled = _mark_host_dead(target_url)
            _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
            logger.warning(f"Anthropic stream connect to {target_url} failed: {e}{_tail}")
            yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
        except httpx.ReadTimeout:
            yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
        except httpx.NetworkError:
            yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
        except Exception as e:
            logger.error(f"Anthropic stream error: {e}")
            yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
        return

    # ── OpenAI-compatible streaming ──
    # Accumulate native tool_calls across streaming chunks
    _tc_acc: Dict[int, Dict] = {}  # index -> {id, name, arguments}
    _tc_last_idx = [-1]  # most-recently-touched slot, for providers that omit `index`
    # For thinking models: prepend <think> to first content delta so frontend
    # can detect thinking-in-progress (some models output </think> but no <think>)
    _thinking_model = _supports_thinking(model)
    _first_content_sent = False
    _in_think_tag = False  # True while consuming <think>…</think> content
    _think_open_stripped = False  # opening <think> tag already removed

    def _emit_tool_calls():
        """Build the tool_calls event string if any were accumulated."""
        if not _tc_acc:
            return None
        calls = [_tc_acc[i] for i in sorted(_tc_acc)]
        return f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'

    try:
        client = _get_http_client()
        async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
            _clear_host_dead(target_url)
            if r.status_code != 200:
                raw = (await r.aread()).decode(errors="replace")
                friendly = _format_upstream_error(r.status_code, raw, target_url)
                yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
                return

            async for line in r.aiter_lines():
                if not line:
                    continue

                # SSE allows "data:value" with no space after the colon; gating
                # on "data: " silently dropped content + usage from providers
                # that omit it.
                if line.startswith("data:"):
                    data = line[5:].strip()
                    if data == "[DONE]":
                        tc_event = _emit_tool_calls()
                        if tc_event:
                            yield tc_event
                        yield "data: [DONE]\n\n"
                        return

                    try:
                        if data.strip():
                            if data.startswith("{"):
                                j = json.loads(data)
                                # Usage chunk (from stream_options)
                                _choices = j.get("choices") or []
                                _delta0 = _choices[0].get("delta") if (_choices and _choices[0] is not None) else None
                                # Capture usage whenever the chunk carries it and
                                # the delta has no actual output. Some gateways /
                                # local servers attach usage to the FINAL delta,
                                # which also carries role/finish_reason (so it is
                                # not exactly None/{}/{"content": None}); gating on
                                # those exact shapes discarded their token counts.
                                _delta_has_output = isinstance(_delta0, dict) and (
                                    _delta0.get("content")
                                    or _delta0.get("reasoning_content")
                                    or _delta0.get("reasoning")
                                    or _delta0.get("tool_calls")
                                )
                                if "usage" in j and not _delta_has_output:
                                    u = j["usage"] or {}
                                    _usage_data = {"input_tokens": u.get("prompt_tokens", 0), "output_tokens": u.get("completion_tokens", 0)}
                                    # llama.cpp puts a `timings` block alongside `usage` with the
                                    # TRUE generation speed (predicted_per_second) — pure decode,
                                    # excluding prefill/network. Pass it through so the UI shows the
                                    # real gen t/s instead of recomputing tokens/wall-clock (which
                                    # includes prefill and reads ~20-40% low). Prefill speed too.
                                    _tm = j.get("timings")
                                    if isinstance(_tm, dict):
                                        if _tm.get("predicted_per_second"):
                                            _usage_data["gen_tps"] = round(_tm["predicted_per_second"], 2)
                                        if _tm.get("prompt_per_second"):
                                            _usage_data["prefill_tps"] = round(_tm["prompt_per_second"], 2)
                                    yield f'data: {json.dumps({"type": "usage", "data": _usage_data})}\n\n'
                                elif "choices" in j:
                                    _c0 = (j["choices"] or [None])[0]
                                    if _c0 is None:
                                        continue
                                    delta = _c0.get("delta") or {}
                                    if isinstance(delta, dict):
                                        # Text content
                                        # Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1, Nemotron). vLLM 0.20.2 / NIM emit the field as `reasoning`; older builds use `reasoning_content`. Accept either.
                                        reasoning = delta.get("reasoning_content") or delta.get("reasoning") or ""
                                        if reasoning:
                                            yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
                                        content = delta.get("content") or ""
                                        if content:
                                            stripped = content.lstrip()
                                            # Auto-detect <think>…</think> in content stream.
                                            # Covers Qwen3-derived models (Qwopus, QwQ forks) whose
                                            # names don't match _THINKING_MODEL_PATTERNS but still
                                            # emit literal <think> markup via llama.cpp --jinja.
                                            if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
                                                _thinking_model = True
                                                _in_think_tag = True
                                            if _in_think_tag:
                                                close_idx = content.lower().find("</think>")
                                                if close_idx != -1:
                                                    # Split: up-to-</think> → thinking, remainder → content
                                                    think_part = content[:close_idx]
                                                    if not _think_open_stripped:
                                                        # Strip the opening <think[...] > from the first chunk.
                                                        # Use a dedicated flag — _first_content_sent stays False
                                                        # throughout the think block, so it must not be reused.
                                                        tag_end = think_part.lower().find(">")
                                                        if tag_end != -1:
                                                            think_part = think_part[tag_end + 1:]
                                                        _think_open_stripped = True
                                                    regular_part = content[close_idx + len("</think>"):]
                                                    _in_think_tag = False
                                                    if think_part:
                                                        yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
                                                    if regular_part:
                                                        _first_content_sent = True
                                                        yield f'data: {json.dumps({"delta": regular_part})}\n\n'
                                                else:
                                                    # Still inside <think>: route to thinking channel
                                                    if not _think_open_stripped:
                                                        # Strip the opening <think[...] > tag (first chunk only)
                                                        tag_end = stripped.lower().find(">")
                                                        if tag_end != -1:
                                                            content = stripped[tag_end + 1:]
                                                        _think_open_stripped = True
                                                    if content:
                                                        yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
                                            else:
                                                # Some thinking backends start normal content with a
                                                # stray closing tag. Repair only that shape; do not
                                                # wrap every first token for model families like
                                                # MiniMax, which often stream ordinary answers.
                                                if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
                                                    content = "<think>" + content
                                                _first_content_sent = True
                                                yield f'data: {json.dumps({"delta": content})}\n\n'
                                        # Native tool calls — accumulate across chunks
                                        for tc in delta.get("tool_calls") or []:
                                            if tc is None:
                                                continue
                                            func = tc.get("function") or {}
                                            raw_idx = tc.get("index")
                                            if raw_idx is None:
                                                # Gemini's OpenAI-compat layer omits `index` on
                                                # parallel tool calls (every delta arrives as
                                                # index=None) and sends each call complete in one
                                                # delta. Without this, all parallel calls collide
                                                # into slot 0 — later calls overwrite the first's
                                                # name and CORRUPT its arguments by concatenation,
                                                # so only one malformed call survives and the
                                                # follow-up round 400s. A function name marks the
                                                # start of a new call → allocate a fresh slot;
                                                # an arg-only continuation attaches to the last.
                                                if func.get("name") or _tc_last_idx[0] < 0:
                                                    # Next free slot ABOVE any existing key (not
                                                    # len()), so a provider mixing integer indices
                                                    # with index=None can never collide.
                                                    idx = max(_tc_acc, default=-1) + 1
                                                else:
                                                    idx = _tc_last_idx[0]
                                            else:
                                                idx = raw_idx
                                            _tc_last_idx[0] = idx
                                            if idx not in _tc_acc:
                                                _tc_acc[idx] = {"id": "", "name": "", "arguments": ""}
                                            if tc.get("id"):
                                                _tc_acc[idx]["id"] = tc["id"]
                                            # Gemini 3 returns an opaque thought_signature in
                                            # extra_content on the function-call delta. It MUST be
                                            # echoed back on the assistant tool_call next round or the
                                            # follow-up request 400s ("Function call is missing a
                                            # thought_signature"). Preserve it verbatim; other
                                            # providers never send it, so this is a no-op for them.
                                            if tc.get("extra_content"):
                                                _tc_acc[idx]["extra_content"] = tc["extra_content"]
                                            if func.get("name"):
                                                _tc_acc[idx]["name"] = func["name"]
                                            if "arguments" in func:
                                                _tc_acc[idx]["arguments"] += func["arguments"]
                                                # Stream tool arg deltas for doc tools
                                                if func["arguments"] and _tc_acc[idx].get("name") in ("create_document", "update_document", "edit_document"):
                                                    yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _tc_acc[idx]["name"], "arg_delta": func["arguments"]})}\n\n'
                                elif "text" in j:
                                    # Legacy completion-style chunk with a top-level "text".
                                    if j["text"]:
                                        yield f'data: {json.dumps({"delta": j["text"]})}\n\n'
                            else:
                                # Non-JSON data line: pass the raw text through as content.
                                if data.strip():
                                    yield f'data: {json.dumps({"delta": data})}\n\n'
                    except Exception as e:
                        # A single malformed chunk must not abort the stream.
                        logger.error(f"Error parsing stream data: {e}")
                        continue

            # End of stream (no explicit [DONE] received)
            tc_event = _emit_tool_calls()
            if tc_event:
                yield tc_event
            yield "data: [DONE]\n\n"

    except (httpx.ConnectError, httpx.ConnectTimeout) as e:
        _cooled = _mark_host_dead(target_url)
        _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
        logger.warning(f"Stream connect to {target_url} failed: {e}{_tail}")
        yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
    except httpx.ReadTimeout:
        yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
    except httpx.NetworkError:
        yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
    except Exception as e:
        logger.error(f"Stream error: {e}")
        yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
|
|
|
|
|
def _summarize_stream_error(err_chunk: Optional[str]) -> str:
|
|
"""Pull a short human reason out of an `event: error` SSE chunk for the
|
|
fallback notice. Returns a generic message if it can't be parsed."""
|
|
if not err_chunk:
|
|
return "primary model failed"
|
|
try:
|
|
for line in err_chunk.split("\n"):
|
|
if line.startswith("data: "):
|
|
j = json.loads(line[6:])
|
|
txt = j.get("text") or j.get("error") or ""
|
|
status = j.get("status")
|
|
msg = (f"HTTP {status}: " if status else "") + str(txt)
|
|
return msg[:200].strip() or "primary model failed"
|
|
except Exception:
|
|
pass
|
|
return "primary model failed"
|
|
|
|
|
|
async def stream_llm_with_fallback(candidates, messages, **kwargs):
    """Wrap stream_llm with an ordered fallback chain.

    `candidates` is a list of (url, model, headers). Each is tried in order,
    but only retried on a *pre-content* failure — i.e. an ``event: error``
    that arrives before any assistant text / tool-call data has been yielded.
    Once a candidate has emitted real output we never switch (that would
    duplicate streamed tokens); a later error from that candidate passes
    through unchanged. The dead-host cooldown in stream_llm makes repeat
    attempts at an offline primary effectively instant.

    When a non-primary candidate produces the first real output, a
    ``{"type": "fallback", ...}`` data chunk is emitted just before it so the
    client can see which model actually answered and why.

    Every per-candidate stream is closed explicitly once we are done with it
    (including when we abandon it to fall back, or when the consumer stops
    reading), so its resources are released immediately rather than whenever
    the abandoned async generator is garbage-collected.

    Args:
        candidates: ordered (url, model, headers) triples; de-duplicated via
            ``_dedupe_candidates`` before use.
        messages: forwarded unchanged to ``stream_llm``.
        **kwargs: forwarded unchanged to ``stream_llm``.

    Yields the same SSE chunk protocol as stream_llm.
    """
    cands = _dedupe_candidates(candidates)
    if not cands:
        yield f'event: error\ndata: {json.dumps({"error": "No model endpoint configured", "status": 503})}\n\n'
        return

    primary_model = cands[0][1]
    last_error = None
    for i, (url, model, headers) in enumerate(cands):
        is_last = (i == len(cands) - 1)
        emitted = False
        retried = False
        stream = stream_llm(url, model, messages, headers=headers, **kwargs)
        try:
            async for chunk in stream:
                if chunk.startswith("event: error"):
                    if not emitted and not is_last:
                        # Pre-content failure with fallbacks left — swallow and
                        # move to the next candidate.
                        last_error = chunk
                        retried = True
                        if i == 0:
                            logger.warning(f"[fallback] primary {model} failed before output; trying fallback")
                        else:
                            logger.warning(f"[fallback] candidate {model} failed; trying next")
                        break
                    yield chunk
                    continue
                # Any data chunk other than the terminal [DONE] means real output.
                if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
                    # First real output from a NON-primary candidate: tell the client
                    # the selected model failed and another answered. Without this the
                    # fallback is invisible — a misconfigured provider looks like it
                    # works because the reply is shown under the originally selected
                    # model's name (e.g. a Bedrock/Claude endpoint that 400s every
                    # request but appears fine because another model silently answered).
                    if not emitted and i > 0:
                        yield ('data: ' + json.dumps({
                            "type": "fallback",
                            "selected_model": primary_model,
                            "answered_by": model,
                            "reason": _summarize_stream_error(last_error),
                        }) + '\n\n')
                    emitted = True
                yield chunk
        finally:
            # Breaking out of `async for` (or being closed by our own consumer)
            # does not finalize the inner async generator; close it ourselves.
            # Guarded so a cleanup hiccup never masks the fallback or the
            # original exception.
            closer = getattr(stream, "aclose", None)
            if closer is not None:
                try:
                    await closer()
                except Exception as close_err:
                    logger.debug(f"[fallback] closing stream for {model} failed: {close_err}")
        if not retried:
            return  # candidate finished (success, or terminal error already sent)
    # Every candidate failed pre-content — surface the last error. Defensive:
    # the final candidate forwards its own errors above and returns, so this is
    # only reached if that invariant ever changes.
    if last_error:
        yield last_error
|