Odysseus v1.0
This commit is contained in:
913
src/llm_core.py
Normal file
913
src/llm_core.py
Normal file
@@ -0,0 +1,913 @@
|
||||
# src/llm_core.py
|
||||
import httpx
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LLMConfig:
|
||||
"""Configuration constants for LLM operations."""
|
||||
DEFAULT_TIMEOUT = 30
|
||||
DEFAULT_TEMPERATURE = 1.0
|
||||
DEFAULT_MAX_TOKENS = 0
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 0.5
|
||||
STREAM_TIMEOUT = 300
|
||||
|
||||
|
||||
# Cache for LLM responses
|
||||
def _get_cache_key(url: str, model: str, messages: List[Dict],
|
||||
temperature: float, max_tokens: int) -> str:
|
||||
"""Generate cache key for LLM requests."""
|
||||
hashable_messages = []
|
||||
for msg in messages:
|
||||
sorted_items = tuple(sorted(msg.items()))
|
||||
hashable_messages.append(sorted_items)
|
||||
|
||||
content = json.dumps({
|
||||
'url': url,
|
||||
'model': model,
|
||||
'messages': hashable_messages,
|
||||
'temp': temperature,
|
||||
'max_tokens': max_tokens
|
||||
}, sort_keys=True)
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
_response_cache = {}
|
||||
|
||||
# Dead-host cooldown: maps host (scheme://host:port) -> unix ts when cooldown expires.
|
||||
# When a connect to a host fails, we mark it dead for DEAD_HOST_COOLDOWN seconds so
|
||||
# subsequent calls fail instantly instead of waiting on the connect timeout. Keeps
|
||||
# one unreachable upstream from jamming chat across the rest of the app.
|
||||
#
|
||||
# But a SINGLE transient blip (local model briefly busy, a momentary
|
||||
# Tailscale hiccup) used to trip a full 60s lockout — the user saw a
|
||||
# 503 and thought the model died when it was fine a second later. So:
|
||||
# - require FAIL_THRESHOLD consecutive failures before cooling
|
||||
# - shorter cooldown so recovery is quick
|
||||
# - any success resets the failure counter immediately
|
||||
DEAD_HOST_COOLDOWN = 20.0
|
||||
_HOST_FAIL_THRESHOLD = 2
|
||||
_dead_hosts: Dict[str, float] = {}
|
||||
_host_fails: Dict[str, int] = {}
|
||||
_model_activity: Dict[str, float] = {}
|
||||
|
||||
def _model_activity_key(url: str, model: str) -> str:
|
||||
return f"{(url or '').strip().rstrip()}|{(model or '').strip()}"
|
||||
|
||||
def note_model_activity(url: str, model: str):
|
||||
"""Record that a real upstream request used this endpoint/model."""
|
||||
if not url or not model:
|
||||
return
|
||||
_model_activity[_model_activity_key(url, model)] = time.time()
|
||||
|
||||
def seconds_since_model_activity(url: str, model: str) -> Optional[float]:
|
||||
"""Seconds since the endpoint/model was last used in this process."""
|
||||
ts = _model_activity.get(_model_activity_key(url, model))
|
||||
if not ts:
|
||||
return None
|
||||
return max(0.0, time.time() - ts)
|
||||
|
||||
def _host_key(url: str) -> str:
|
||||
from urllib.parse import urlsplit
|
||||
s = urlsplit(url)
|
||||
return f"{s.scheme}://{s.netloc}" if s.scheme and s.netloc else url
|
||||
|
||||
def _is_host_dead(url: str) -> bool:
|
||||
key = _host_key(url)
|
||||
exp = _dead_hosts.get(key)
|
||||
if exp is None:
|
||||
return False
|
||||
if time.time() >= exp:
|
||||
_dead_hosts.pop(key, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _mark_host_dead(url: str) -> bool:
|
||||
"""Record a connect failure. Only actually cools the host after
|
||||
_HOST_FAIL_THRESHOLD consecutive failures. Returns True if the host
|
||||
is now cooled (so callers can log accurately), False if it's still
|
||||
within its allowed-failure grace."""
|
||||
key = _host_key(url)
|
||||
n = _host_fails.get(key, 0) + 1
|
||||
_host_fails[key] = n
|
||||
if n >= _HOST_FAIL_THRESHOLD:
|
||||
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
|
||||
return True
|
||||
return False
|
||||
|
||||
def _clear_host_dead(url: str) -> None:
|
||||
key = _host_key(url)
|
||||
_dead_hosts.pop(key, None)
|
||||
_host_fails.pop(key, None)
|
||||
|
||||
|
||||
# Shared async HTTP client. Reusing one client keeps connections warm:
|
||||
# repeat calls to api.anthropic.com / api.openai.com / openrouter skip the
|
||||
# 100-500ms TCP+TLS handshake. Lazy init so we bind to the running event loop.
|
||||
_http_client: Optional[httpx.AsyncClient] = None
|
||||
_http_limits = httpx.Limits(max_connections=100, max_keepalive_connections=30, keepalive_expiry=30.0)
|
||||
|
||||
def _get_http_client() -> httpx.AsyncClient:
|
||||
"""Return process-wide AsyncClient. Per-request timeout is passed at call time."""
|
||||
global _http_client
|
||||
if _http_client is None or _http_client.is_closed:
|
||||
_http_client = httpx.AsyncClient(limits=_http_limits, http2=False)
|
||||
return _http_client
|
||||
|
||||
def _get_cached_response(cache_key: str) -> Optional[str]:
|
||||
"""Get cached response if it exists."""
|
||||
return _response_cache.get(cache_key)
|
||||
|
||||
def _set_cached_response(cache_key: str, response: str) -> None:
|
||||
"""Store response in cache."""
|
||||
if len(_response_cache) > 128:
|
||||
keys_to_remove = list(_response_cache.keys())[:64]
|
||||
for key in keys_to_remove:
|
||||
del _response_cache[key]
|
||||
_response_cache[cache_key] = response
|
||||
|
||||
# ── Anthropic native API adapter ──
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
"claude-opus-4-20250514", "claude-opus-4",
|
||||
"claude-sonnet-4-20250514", "claude-sonnet-4", "claude-sonnet-4-5-20250929", "claude-sonnet-4-5",
|
||||
"claude-haiku-4-20250514", "claude-haiku-4", "claude-haiku-3-5-20241022", "claude-haiku-3-5",
|
||||
]
|
||||
|
||||
def _detect_provider(url: str) -> str:
|
||||
"""Detect API provider from URL."""
|
||||
if "anthropic.com" in (url or ""):
|
||||
return "anthropic"
|
||||
return "openai"
|
||||
|
||||
|
||||
def _provider_label(url: str) -> str:
|
||||
"""Human-friendly provider name for error messages."""
|
||||
u = (url or "").lower()
|
||||
if "anthropic.com" in u: return "Anthropic"
|
||||
if "api.x.ai" in u or "x.ai/" in u: return "xAI"
|
||||
if "openai.com" in u: return "OpenAI"
|
||||
if "openrouter.ai" in u: return "OpenRouter"
|
||||
if "groq.com" in u: return "Groq"
|
||||
if "mistral.ai" in u: return "Mistral"
|
||||
if "deepseek.com" in u: return "DeepSeek"
|
||||
if "googleapis.com" in u or "generativelanguage" in u: return "Google"
|
||||
if "together.xyz" in u or "together.ai" in u: return "Together"
|
||||
if "fireworks.ai" in u: return "Fireworks"
|
||||
if "localhost" in u or "127.0.0.1" in u: return "local endpoint"
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
host = urlparse(url).hostname or "provider"
|
||||
return host
|
||||
except Exception:
|
||||
return "provider"
|
||||
|
||||
|
||||
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
|
||||
"""Turn an upstream HTTP error into a user-readable sentence.
|
||||
|
||||
Auth failures (401/403) become 'xAI rejected the API key' etc., so the UI
|
||||
stops showing raw JSON like '{"error":{"message":"User not found."}}'.
|
||||
"""
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
body = body.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
body = str(body)
|
||||
provider = _provider_label(url)
|
||||
# Try to pull a message out of the body
|
||||
detail = ""
|
||||
try:
|
||||
j = json.loads(body) if body else {}
|
||||
if isinstance(j, dict):
|
||||
err = j.get("error") or j
|
||||
if isinstance(err, dict):
|
||||
detail = (err.get("message") or err.get("detail") or "").strip()
|
||||
elif isinstance(err, str):
|
||||
detail = err.strip()
|
||||
except Exception:
|
||||
detail = (body or "").strip()[:240]
|
||||
|
||||
if status in (401, 403):
|
||||
msg = f"{provider} rejected the API key"
|
||||
if status == 403:
|
||||
msg = f"{provider} denied access (403)"
|
||||
if detail:
|
||||
msg += f" — {detail}"
|
||||
msg += ". Check Model Endpoints → {} and re-paste the key.".format(provider)
|
||||
return msg
|
||||
if status == 404:
|
||||
return f"{provider} returned 404 — check the base URL and model name." + (f" ({detail})" if detail else "")
|
||||
if status == 429:
|
||||
return f"{provider} rate-limited the request (429)." + (f" {detail}" if detail else "")
|
||||
if status >= 500:
|
||||
return f"{provider} is having an outage (HTTP {status})." + (f" {detail}" if detail else "")
|
||||
return f"{provider} returned HTTP {status}" + (f": {detail}" if detail else "")
|
||||
|
||||
# Models that require max_completion_tokens instead of max_tokens
|
||||
_MAX_COMPLETION_TOKENS_MODELS = {"o1", "o3", "o4", "gpt-4.5", "gpt-5"}
|
||||
|
||||
def _uses_max_completion_tokens(model: str) -> bool:
|
||||
"""Check if a model requires max_completion_tokens instead of max_tokens."""
|
||||
if not model:
|
||||
return False
|
||||
m = model.lower()
|
||||
return any(m.startswith(p) or f"/{p}" in m for p in _MAX_COMPLETION_TOKENS_MODELS)
|
||||
|
||||
# Models that support structured thinking — may output </think> without opening tag
|
||||
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap")
|
||||
|
||||
def _supports_thinking(model: str) -> bool:
|
||||
"""Check if model supports structured thinking output."""
|
||||
if not model:
|
||||
return False
|
||||
m = model.lower()
|
||||
return any(p in m for p in _THINKING_MODEL_PATTERNS)
|
||||
|
||||
def _convert_openai_content_to_anthropic(content):
|
||||
"""Convert OpenAI multimodal content blocks to Anthropic format.
|
||||
|
||||
Converts image_url blocks (data URI) → Anthropic image blocks.
|
||||
Passes text blocks through unchanged.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
converted = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
converted.append(block)
|
||||
continue
|
||||
if block.get("type") == "image_url":
|
||||
url = (block.get("image_url") or {}).get("url", "")
|
||||
# Parse data URI: data:image/<fmt>;base64,<data>
|
||||
if url.startswith("data:"):
|
||||
try:
|
||||
header, b64_data = url.split(",", 1)
|
||||
media_type = header.split(";")[0].replace("data:", "")
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
converted.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": b64_data,
|
||||
},
|
||||
})
|
||||
else:
|
||||
# External URL — use Anthropic's URL source
|
||||
converted.append({
|
||||
"type": "image",
|
||||
"source": {"type": "url", "url": url},
|
||||
})
|
||||
elif block.get("type") == "text":
|
||||
converted.append(block)
|
||||
else:
|
||||
converted.append(block)
|
||||
return converted
|
||||
|
||||
|
||||
def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=False, tools=None):
|
||||
"""Convert OpenAI-style messages to Anthropic format."""
|
||||
system_parts = []
|
||||
chat_messages = []
|
||||
for m in messages:
|
||||
if m.get("role") == "system":
|
||||
system_parts.append(m["content"])
|
||||
elif m.get("role") == "tool":
|
||||
# Convert OpenAI tool result to Anthropic format
|
||||
chat_messages.append({
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": m.get("tool_call_id", ""),
|
||||
"content": m.get("content", ""),
|
||||
}],
|
||||
})
|
||||
elif m.get("role") == "assistant" and isinstance(m.get("tool_calls"), list):
|
||||
# Convert OpenAI assistant tool_calls to Anthropic format
|
||||
content = []
|
||||
if m.get("content"):
|
||||
content.append({"type": "text", "text": m["content"]})
|
||||
for tc in m["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
args_str = fn.get("arguments", "{}")
|
||||
try:
|
||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
content.append({
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", ""),
|
||||
"name": fn.get("name", ""),
|
||||
"input": args,
|
||||
})
|
||||
chat_messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
# Convert multimodal content (image_url → image) for Anthropic
|
||||
content = _convert_openai_content_to_anthropic(m["content"])
|
||||
chat_messages.append({"role": m["role"], "content": content})
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": chat_messages,
|
||||
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if system_parts:
|
||||
payload["system"] = "\n\n".join(system_parts)
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
# Convert OpenAI-format tools to Anthropic format
|
||||
if tools:
|
||||
anthropic_tools = []
|
||||
for t in tools:
|
||||
if t.get("type") == "function":
|
||||
fn = t["function"]
|
||||
anthropic_tools.append({
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description", ""),
|
||||
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
||||
})
|
||||
if anthropic_tools:
|
||||
payload["tools"] = anthropic_tools
|
||||
return payload
|
||||
|
||||
def _build_anthropic_headers(headers):
|
||||
"""Convert Bearer auth to x-api-key for Anthropic."""
|
||||
h = {"Content-Type": "application/json", "anthropic-version": "2023-06-01"}
|
||||
if headers:
|
||||
for k, v in headers.items():
|
||||
if k.lower() == "authorization" and isinstance(v, str) and v.startswith("Bearer "):
|
||||
h["x-api-key"] = v[7:]
|
||||
else:
|
||||
h[k] = v
|
||||
return h
|
||||
|
||||
def _parse_anthropic_response(data: dict) -> str:
|
||||
"""Extract text from Anthropic response."""
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
def _normalize_anthropic_url(url: str) -> str:
|
||||
"""Ensure Anthropic URL points to /v1/messages."""
|
||||
url = url.rstrip("/")
|
||||
if url.endswith("/v1/messages"):
|
||||
return url
|
||||
if url.endswith("/v1"):
|
||||
return url + "/messages"
|
||||
return url + "/v1/messages"
|
||||
|
||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
||||
"""List available model IDs from an endpoint."""
|
||||
if _detect_provider(base_chat_url) == "anthropic":
|
||||
return list(ANTHROPIC_MODELS)
|
||||
try:
|
||||
h = {}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
r = httpx.get(base_chat_url.replace("/chat/completions", "/models"), headers=h, timeout=timeout)
|
||||
r.raise_for_status()
|
||||
return [m.get("id") for m in (r.json().get("data") or []) if m.get("id")]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
|
||||
"""Normalize a model ID to match available models."""
|
||||
avail = list_model_ids(endpoint_url, timeout)
|
||||
if not avail:
|
||||
return None
|
||||
if requested in avail:
|
||||
return requested
|
||||
import os as _os
|
||||
req_base = _os.path.basename(requested.rstrip("/"))
|
||||
for a in avail:
|
||||
if _os.path.basename(a.rstrip("/")) == req_base:
|
||||
return a
|
||||
return None
|
||||
|
||||
def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
|
||||
timeout: int = LLMConfig.DEFAULT_TIMEOUT, prompt_type: Optional[str] = None) -> str:
|
||||
"""Synchronous LLM call with optional prompt type enhancement."""
|
||||
h = {"Content-Type": "application/json"}
|
||||
# Tolerate headers that arrive as a JSON string (some sessions stored them
|
||||
# double-encoded) — otherwise h.update() throws "dictionary update sequence
|
||||
# element #0 has length 1; 2 is required".
|
||||
if isinstance(headers, str):
|
||||
try:
|
||||
headers = json.loads(headers)
|
||||
except Exception:
|
||||
headers = None
|
||||
if isinstance(headers, dict):
|
||||
h.update(headers)
|
||||
|
||||
messages_copy = [msg.copy() for msg in messages]
|
||||
|
||||
# Consolidate multiple system messages into one at the start.
|
||||
sys_parts = []
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
|
||||
else:
|
||||
messages_copy = non_sys
|
||||
|
||||
provider = _detect_provider(url)
|
||||
cache_key = _get_cache_key(url, model, messages_copy, temperature, max_tokens)
|
||||
cached_response = _get_cached_response(cache_key)
|
||||
if cached_response:
|
||||
logger.debug(f"Returning cached response for key: {cache_key}")
|
||||
return cached_response
|
||||
|
||||
if provider == "anthropic":
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages_copy,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if max_tokens and max_tokens > 0:
|
||||
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
|
||||
payload[tok_key] = max_tokens
|
||||
try:
|
||||
note_model_activity(target_url, model)
|
||||
r = httpx.post(target_url, headers=h, json=payload, timeout=timeout)
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"POST {target_url} failed: {e}")
|
||||
if not r.is_success:
|
||||
raise HTTPException(502, f"Upstream {target_url} -> {r.status_code}: {r.text}")
|
||||
data = r.json()
|
||||
try:
|
||||
if provider == "anthropic":
|
||||
response = _parse_anthropic_response(data)
|
||||
else:
|
||||
response = data["choices"][0]["message"]["content"]
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(502, f"Unexpected schema from {target_url}: {str(data)[:400]}")
|
||||
|
||||
|
||||
def llm_call_with_fallback(candidates, messages, **kwargs) -> str:
|
||||
"""Sync `llm_call` with an ordered fallback chain.
|
||||
|
||||
`candidates` is a list of (url, model, headers). The first one that returns
|
||||
without an exception wins. Connection / 5xx-style failures fall through to
|
||||
the next candidate. The dead-host cooldown inside `llm_call` makes repeat
|
||||
attempts at an offline primary effectively free.
|
||||
"""
|
||||
cands = [c for c in (candidates or []) if c and c[0] and c[1]]
|
||||
if not cands:
|
||||
raise HTTPException(503, "No model endpoint configured")
|
||||
last_err = None
|
||||
for i, (url, model, headers) in enumerate(cands):
|
||||
try:
|
||||
return llm_call(url, model, messages, headers=headers, **kwargs)
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
tag = "primary" if i == 0 else "candidate"
|
||||
logger.warning(f"[fallback] {tag} {model} failed ({type(e).__name__}); trying next")
|
||||
continue
|
||||
raise last_err if last_err else HTTPException(503, "All fallback candidates failed")
|
||||
|
||||
|
||||
async def llm_call_async_with_fallback(candidates, messages, **kwargs) -> str:
|
||||
"""Async variant of `llm_call_with_fallback` — same semantics."""
|
||||
cands = [c for c in (candidates or []) if c and c[0] and c[1]]
|
||||
if not cands:
|
||||
raise HTTPException(503, "No model endpoint configured")
|
||||
last_err = None
|
||||
for i, (url, model, headers) in enumerate(cands):
|
||||
try:
|
||||
return await llm_call_async(url, model, messages, headers=headers, **kwargs)
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
tag = "primary" if i == 0 else "candidate"
|
||||
logger.warning(f"[fallback] {tag} {model} failed ({type(e).__name__}); trying next")
|
||||
continue
|
||||
raise last_err if last_err else HTTPException(503, "All fallback candidates failed")
|
||||
|
||||
|
||||
async def llm_call_async(
|
||||
url: str,
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS,
|
||||
headers: Optional[Dict] = None,
|
||||
timeout: int = LLMConfig.STREAM_TIMEOUT,
|
||||
max_retries: int = LLMConfig.MAX_RETRIES,
|
||||
prompt_type: Optional[str] = None
|
||||
) -> str:
|
||||
"""Asynchronous LLM call using httpx with connection pooling, timeout, retry logic, and performance logging."""
|
||||
provider = _detect_provider(url)
|
||||
messages_copy = [msg.copy() for msg in messages]
|
||||
|
||||
# Consolidate multiple system messages into one at the start.
|
||||
sys_parts = []
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
|
||||
else:
|
||||
messages_copy = non_sys
|
||||
|
||||
cache_key = _get_cache_key(url, model, messages_copy, temperature, max_tokens)
|
||||
cached_response = _get_cached_response(cache_key)
|
||||
if cached_response:
|
||||
logger.debug(f"Returning cached response for key: {cache_key}")
|
||||
return cached_response
|
||||
|
||||
if provider == "anthropic":
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens)
|
||||
else:
|
||||
target_url = url
|
||||
h = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages_copy,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if max_tokens and max_tokens > 0:
|
||||
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
|
||||
payload[tok_key] = max_tokens
|
||||
|
||||
if _is_host_dead(target_url):
|
||||
raise HTTPException(503, f"Upstream {_host_key(target_url)} marked unreachable (cooldown active)")
|
||||
|
||||
call_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=10.0, pool=5.0)
|
||||
attempt = 0
|
||||
while attempt < max_retries:
|
||||
attempt += 1
|
||||
start = time.time()
|
||||
try:
|
||||
note_model_activity(target_url, model)
|
||||
client = _get_http_client()
|
||||
r = await client.post(target_url, headers=h, json=payload, timeout=call_timeout)
|
||||
duration = time.time() - start
|
||||
if not r.is_success:
|
||||
friendly = _format_upstream_error(r.status_code, r.text, target_url)
|
||||
logger.warning(
|
||||
f"LLM async call to {target_url} failed in {duration:.2f}s "
|
||||
f"(attempt {attempt}): HTTP {r.status_code} {friendly}"
|
||||
)
|
||||
raise HTTPException(r.status_code, friendly)
|
||||
logger.info(f"LLM async call to {target_url} succeeded in {duration:.2f}s (attempt {attempt})")
|
||||
_clear_host_dead(target_url)
|
||||
data = r.json()
|
||||
try:
|
||||
if provider == "anthropic":
|
||||
response = _parse_anthropic_response(data)
|
||||
else:
|
||||
response = data["choices"][0]["message"]["content"]
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(502, f"Unexpected schema from {target_url}: {str(data)[:400]}")
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
duration = time.time() - start
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"LLM async connect to {target_url} failed after {duration:.2f}s: {e}{_tail}")
|
||||
raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
duration = time.time() - start
|
||||
logger.warning(f"LLM async call attempt {attempt} failed after {duration:.2f}s: {e}")
|
||||
if attempt >= max_retries:
|
||||
raise HTTPException(502, f"POST {target_url} failed after {max_retries} attempts: {e}")
|
||||
await asyncio.sleep(LLMConfig.RETRY_DELAY)
|
||||
|
||||
async def stream_llm(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
|
||||
timeout: int = LLMConfig.STREAM_TIMEOUT, prompt_type: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None):
|
||||
"""Stream LLM responses with improved error handling.
|
||||
|
||||
Yields SSE chunks:
|
||||
- data: {"delta": "text"} — text content
|
||||
- data: {"type": "tool_calls", ...} — accumulated native tool calls (before DONE)
|
||||
- event: error — errors
|
||||
- data: [DONE] — end of stream
|
||||
"""
|
||||
provider = _detect_provider(url)
|
||||
messages_copy = [msg.copy() for msg in messages]
|
||||
|
||||
# Consolidate multiple system messages into one at the start.
|
||||
# Some models (e.g. Qwen3.5) reject system messages that aren't first.
|
||||
sys_parts = []
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
messages_copy = [{"role": "system", "content": "\n\n".join(sys_parts)}] + non_sys
|
||||
else:
|
||||
messages_copy = non_sys
|
||||
|
||||
if provider == "anthropic":
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages_copy,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
if max_tokens and max_tokens > 0:
|
||||
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
|
||||
payload[tok_key] = max_tokens
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
h = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
h.update(headers)
|
||||
|
||||
# Short connect timeout: a reachable peer answers SYN in <100ms even on
|
||||
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
|
||||
stream_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=30.0, pool=5.0)
|
||||
|
||||
if _is_host_dead(target_url):
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Upstream {_host_key(target_url)} unreachable (cooldown active)", "status": 503})}\n\n'
|
||||
return
|
||||
note_model_activity(target_url, model)
|
||||
|
||||
# ── Anthropic streaming ──
|
||||
if provider == "anthropic":
|
||||
_anth_input_tokens = 0
|
||||
_anth_output_tokens = 0
|
||||
# Track tool_use blocks: {index: {id, name, arguments_json}}
|
||||
_anth_tool_blocks: Dict[int, Dict] = {}
|
||||
_anth_block_idx = -1
|
||||
_anth_block_type = ""
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
_clear_host_dead(target_url)
|
||||
if r.status_code != 200:
|
||||
raw = (await r.aread()).decode(errors="replace")
|
||||
friendly = _format_upstream_error(r.status_code, raw, target_url)
|
||||
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||
return
|
||||
async for line in r.aiter_lines():
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:].strip()
|
||||
if not data or not data.startswith("{"):
|
||||
continue
|
||||
try:
|
||||
j = json.loads(data)
|
||||
evt = j.get("type", "")
|
||||
if evt == "content_block_start":
|
||||
_anth_block_idx = j.get("index", _anth_block_idx + 1)
|
||||
cb = j.get("content_block", {})
|
||||
_anth_block_type = cb.get("type", "text")
|
||||
if _anth_block_type == "tool_use":
|
||||
_anth_tool_blocks[_anth_block_idx] = {
|
||||
"id": cb.get("id", f"call_{_anth_block_idx}"),
|
||||
"name": cb.get("name", ""),
|
||||
"arguments": "",
|
||||
}
|
||||
elif evt == "content_block_delta":
|
||||
delta = j.get("delta", {})
|
||||
delta_type = delta.get("type", "")
|
||||
if delta_type == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
yield f'data: {json.dumps({"delta": text})}\n\n'
|
||||
elif delta_type == "input_json_delta":
|
||||
# Accumulate tool arguments JSON
|
||||
idx = j.get("index", _anth_block_idx)
|
||||
if idx in _anth_tool_blocks:
|
||||
partial = delta.get("partial_json", "")
|
||||
_anth_tool_blocks[idx]["arguments"] += partial
|
||||
# Stream tool arg deltas for doc tools
|
||||
if partial and _anth_tool_blocks[idx].get("name") in ("create_document", "update_document", "edit_document"):
|
||||
yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _anth_tool_blocks[idx]["name"], "arg_delta": partial})}\n\n'
|
||||
elif evt == "message_start":
|
||||
_anth_input_tokens = j.get("message", {}).get("usage", {}).get("input_tokens", 0)
|
||||
elif evt == "message_delta":
|
||||
_anth_output_tokens = j.get("usage", {}).get("output_tokens", 0)
|
||||
elif evt == "message_stop":
|
||||
# Emit accumulated tool calls in OpenAI-compatible format
|
||||
if _anth_tool_blocks:
|
||||
calls = []
|
||||
for idx in sorted(_anth_tool_blocks):
|
||||
tb = _anth_tool_blocks[idx]
|
||||
calls.append({
|
||||
"id": tb["id"],
|
||||
"name": tb["name"],
|
||||
"arguments": tb["arguments"],
|
||||
})
|
||||
yield f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'
|
||||
if _anth_input_tokens or _anth_output_tokens:
|
||||
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": _anth_input_tokens, "output_tokens": _anth_output_tokens}})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
elif evt == "error":
|
||||
err_msg = j.get("error", {}).get("message", "Unknown error")
|
||||
yield f'event: error\ndata: {json.dumps({"error": err_msg, "status": 400})}\n\n'
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
yield "data: [DONE]\n\n"
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"Anthropic stream connect to {target_url} failed: {e}{_tail}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||
except httpx.ReadTimeout:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||
except httpx.NetworkError:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic stream error: {e}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||
return
|
||||
|
||||
# ── OpenAI-compatible streaming ──
|
||||
# Accumulate native tool_calls across streaming chunks
|
||||
_tc_acc: Dict[int, Dict] = {} # index -> {id, name, arguments}
|
||||
# For thinking models: prepend <think> to first content delta so frontend
|
||||
# can detect thinking-in-progress (some models output </think> but no <think>)
|
||||
_thinking_model = _supports_thinking(model)
|
||||
_first_content_sent = False
|
||||
|
||||
def _emit_tool_calls():
|
||||
"""Build the tool_calls event string if any were accumulated."""
|
||||
if not _tc_acc:
|
||||
return None
|
||||
calls = [_tc_acc[i] for i in sorted(_tc_acc)]
|
||||
return f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'
|
||||
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
_clear_host_dead(target_url)
|
||||
if r.status_code != 200:
|
||||
raw = (await r.aread()).decode(errors="replace")
|
||||
friendly = _format_upstream_error(r.status_code, raw, target_url)
|
||||
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||
return
|
||||
|
||||
async for line in r.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
data = line[6:].strip()
|
||||
if data == "[DONE]":
|
||||
tc_event = _emit_tool_calls()
|
||||
if tc_event:
|
||||
yield tc_event
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
try:
|
||||
if data.strip():
|
||||
if data.startswith("{"):
|
||||
j = json.loads(data)
|
||||
# Usage chunk (from stream_options)
|
||||
_choices = j.get("choices") or []
|
||||
_delta0 = _choices[0].get("delta") if _choices else None
|
||||
if "usage" in j and _delta0 in (None, {}, {"content": None}):
|
||||
u = j["usage"]
|
||||
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": u.get("prompt_tokens", 0), "output_tokens": u.get("completion_tokens", 0)}})}\n\n'
|
||||
elif "choices" in j:
|
||||
delta = j["choices"][0].get("delta", {})
|
||||
if isinstance(delta, dict):
|
||||
# Text content
|
||||
# Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1)
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
# Some thinking backends start normal content with a
|
||||
# stray closing tag. Repair only that shape; do not
|
||||
# wrap every first token for model families like
|
||||
# MiniMax, which often stream ordinary answers.
|
||||
if _thinking_model and not _first_content_sent and content.lstrip().lower().startswith("</think"):
|
||||
content = "<think>" + content
|
||||
_first_content_sent = True
|
||||
yield f'data: {json.dumps({"delta": content})}\n\n'
|
||||
# Native tool calls — accumulate across chunks
|
||||
for tc in delta.get("tool_calls", []):
|
||||
idx = tc.get("index", 0)
|
||||
if idx not in _tc_acc:
|
||||
_tc_acc[idx] = {"id": "", "name": "", "arguments": ""}
|
||||
if tc.get("id"):
|
||||
_tc_acc[idx]["id"] = tc["id"]
|
||||
func = tc.get("function", {})
|
||||
if func.get("name"):
|
||||
_tc_acc[idx]["name"] = func["name"]
|
||||
if "arguments" in func:
|
||||
_tc_acc[idx]["arguments"] += func["arguments"]
|
||||
# Stream tool arg deltas for doc tools
|
||||
if func["arguments"] and _tc_acc[idx].get("name") in ("create_document", "update_document", "edit_document"):
|
||||
yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _tc_acc[idx]["name"], "arg_delta": func["arguments"]})}\n\n'
|
||||
elif "text" in j:
|
||||
if j["text"]:
|
||||
yield f'data: {json.dumps({"delta": j["text"]})}\n\n'
|
||||
else:
|
||||
if data.strip():
|
||||
yield f'data: {json.dumps({"delta": data})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing stream data: {e}")
|
||||
continue
|
||||
|
||||
# End of stream (no explicit [DONE] received)
|
||||
tc_event = _emit_tool_calls()
|
||||
if tc_event:
|
||||
yield tc_event
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"Stream connect to {target_url} failed: {e}{_tail}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||
except httpx.ReadTimeout:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||
except httpx.NetworkError:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {e}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||
|
||||
|
||||
async def stream_llm_with_fallback(candidates, messages, **kwargs):
|
||||
"""Wrap stream_llm with an ordered fallback chain.
|
||||
|
||||
`candidates` is a list of (url, model, headers). Each is tried in order,
|
||||
but only retried on a *pre-content* failure — i.e. an ``event: error``
|
||||
that arrives before any assistant text / tool-call data has been yielded.
|
||||
Once a candidate has emitted real output we never switch (that would
|
||||
duplicate streamed tokens); a later error from that candidate passes
|
||||
through unchanged. The dead-host cooldown in stream_llm makes repeat
|
||||
attempts at an offline primary effectively instant.
|
||||
|
||||
Yields the same SSE chunk protocol as stream_llm.
|
||||
"""
|
||||
cands = [c for c in (candidates or []) if c and c[0] and c[1]]
|
||||
if not cands:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "No model endpoint configured", "status": 503})}\n\n'
|
||||
return
|
||||
|
||||
last_error = None
|
||||
for i, (url, model, headers) in enumerate(cands):
|
||||
is_last = (i == len(cands) - 1)
|
||||
emitted = False
|
||||
retried = False
|
||||
async for chunk in stream_llm(url, model, messages, headers=headers, **kwargs):
|
||||
if chunk.startswith("event: error"):
|
||||
if not emitted and not is_last:
|
||||
# Pre-content failure with fallbacks left — swallow and
|
||||
# move to the next candidate.
|
||||
last_error = chunk
|
||||
retried = True
|
||||
if i == 0:
|
||||
logger.warning(f"[fallback] primary {model} failed before output; trying fallback")
|
||||
else:
|
||||
logger.warning(f"[fallback] candidate {model} failed; trying next")
|
||||
break
|
||||
yield chunk
|
||||
continue
|
||||
# Any data chunk other than the terminal [DONE] means real output.
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
emitted = True
|
||||
yield chunk
|
||||
if not retried:
|
||||
return # candidate finished (success, or terminal error already sent)
|
||||
# Every candidate failed pre-content — surface the last error.
|
||||
if last_error:
|
||||
yield last_error
|
||||
Reference in New Issue
Block a user