KNOWN_CONTEXT_WINDOWS lists 'o1' (200k) before 'o1-mini' (128k), and _lookup_known returned on the first substring hit — so "o1-mini" matched 'o1' and reported 200000 instead of 128000. Track the longest matching key instead, so the most specific entry wins regardless of table order.
298 lines
9.5 KiB
Python
298 lines
9.5 KiB
Python
"""
|
|
model_context.py
|
|
|
|
Query and cache model context window sizes from OpenAI-compatible APIs.
|
|
Provides token estimation for context usage tracking.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.internal"}
|
|
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
|
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
|
|
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
|
|
"172.30.", "172.31.", "192.168.", "100.")
|
|
|
|
|
|
def _is_local_endpoint(url: str) -> bool:
|
|
"""Check if URL points to a local/private/tailscale address."""
|
|
try:
|
|
host = urlparse(url).hostname or ""
|
|
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
|
|
except Exception:
|
|
return False
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
DEFAULT_CONTEXT = 128000
|
|
REQUEST_TIMEOUT = 5
|
|
|
|
# Known context windows for major API models (used as fallback when /models
|
|
# endpoint doesn't report context_length).
|
|
# Substring matching — use the shortest unique prefix so variants get caught.
|
|
KNOWN_CONTEXT_WINDOWS = {
|
|
# --- Anthropic ---
|
|
'claude-sonnet-4-5': 200000,
|
|
'claude-sonnet-4-6': 200000,
|
|
'claude-sonnet-4': 200000,
|
|
'claude-opus-4': 200000,
|
|
'claude-haiku-4': 200000,
|
|
'claude-haiku-3-5': 200000,
|
|
'claude-3-5-sonnet': 200000,
|
|
'claude-3-5-haiku': 200000,
|
|
'claude-3-opus': 200000,
|
|
'claude-3-sonnet': 200000,
|
|
'claude-3-haiku': 200000,
|
|
|
|
# --- OpenAI ---
|
|
'gpt-5': 400000,
|
|
'gpt-4.1': 1047576,
|
|
'gpt-4.1-mini': 1047576,
|
|
'gpt-4.1-nano': 1047576,
|
|
'gpt-4o': 128000,
|
|
'gpt-4o-mini': 128000,
|
|
'gpt-4-turbo': 128000,
|
|
'gpt-4': 8192,
|
|
'gpt-3.5-turbo': 16385,
|
|
'o1': 200000,
|
|
'o1-mini': 128000,
|
|
'o1-pro': 200000,
|
|
'o3': 200000,
|
|
'o3-mini': 200000,
|
|
'o4-mini': 200000,
|
|
|
|
# --- DeepSeek ---
|
|
'deepseek-chat': 64000,
|
|
'deepseek-coder': 64000,
|
|
'deepseek-reasoner': 64000,
|
|
'deepseek-r1': 64000,
|
|
'deepseek-v3': 64000,
|
|
'deepseek-v2': 64000,
|
|
|
|
# --- Google ---
|
|
'gemini-2.5-pro': 1048576,
|
|
'gemini-2.5-flash': 1048576,
|
|
'gemini-2.0-flash': 1048576,
|
|
'gemini-1.5-pro': 1048576,
|
|
'gemini-1.5-flash': 1048576,
|
|
'gemma-3': 128000,
|
|
'gemma-2': 8192,
|
|
|
|
# --- Mistral ---
|
|
'mistral-large': 128000,
|
|
'mistral-medium': 32000,
|
|
'mistral-small': 32000,
|
|
'mistral-nemo': 128000,
|
|
'mistral-7b': 32000,
|
|
'mixtral': 32000,
|
|
'codestral': 32000,
|
|
'pixtral': 128000,
|
|
|
|
# --- xAI ---
|
|
'grok-4': 131072,
|
|
'grok-3': 131072,
|
|
'grok-2': 131072,
|
|
|
|
# --- Meta / Llama ---
|
|
'llama-4': 1048576,
|
|
'llama-3.3': 131072,
|
|
'llama-3.2': 131072,
|
|
'llama-3.1': 131072,
|
|
'llama-3': 131072,
|
|
|
|
# --- Qwen ---
|
|
'qwen3': 131072,
|
|
'qwen2.5': 131072,
|
|
'qwen2': 32768,
|
|
'qwq': 32768,
|
|
|
|
# --- Cohere ---
|
|
'command-r-plus': 128000,
|
|
'command-r': 128000,
|
|
'command-a': 256000,
|
|
|
|
# --- Perplexity ---
|
|
'sonar-pro': 200000,
|
|
'sonar': 128000,
|
|
|
|
# --- MiniMax ---
|
|
'minimax': 1000000,
|
|
|
|
# --- Moonshot / Kimi ---
|
|
'moonshot': 128000,
|
|
'kimi': 128000,
|
|
|
|
# --- Microsoft ---
|
|
'phi-4': 16000,
|
|
'phi-3': 128000,
|
|
|
|
# --- Nvidia ---
|
|
'nemotron': 131072,
|
|
|
|
# --- Yi ---
|
|
'yi-large': 32768,
|
|
'yi-1.5': 16384,
|
|
|
|
# --- 01.ai ---
|
|
'yi-lightning': 16384,
|
|
|
|
# --- Nous ---
|
|
'hermes': 131072,
|
|
'nous-hermes': 131072,
|
|
|
|
# --- Open community ---
|
|
'dolphin': 32768,
|
|
'mythomax': 4096,
|
|
'wizard': 32768,
|
|
'openchat': 8192,
|
|
'solar': 32768,
|
|
}
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cache
|
|
# ---------------------------------------------------------------------------
|
|
_context_cache: Dict[str, int] = {}
|
|
|
|
|
|
def get_context_length(endpoint_url: str, model: str) -> int:
|
|
"""Get the context window size for a model.
|
|
|
|
Queries /v1/models on the endpoint and looks for context_length
|
|
or context_window fields. Caches result per model ID.
|
|
Falls back to DEFAULT_CONTEXT if unavailable.
|
|
"""
|
|
is_local = _is_local_endpoint(endpoint_url)
|
|
if not is_local and model in _context_cache:
|
|
return _context_cache[model]
|
|
|
|
ctx = _query_context_length(endpoint_url, model)
|
|
# Only cache non-default values to allow retry on next request.
|
|
# Local endpoints can restart with a different --max-model-len while keeping
|
|
# the same model id, so always re-query them instead of serving stale cache.
|
|
if not is_local and ctx != DEFAULT_CONTEXT:
|
|
_context_cache[model] = ctx
|
|
logger.info(f"Context length for {model}: {ctx}")
|
|
return ctx
|
|
|
|
|
|
def _lookup_known(model: str) -> Optional[int]:
|
|
"""Check known context windows by substring match.
|
|
|
|
Picks the LONGEST matching key so a short key never shadows a more specific
|
|
one. Without this, 'o1' (200k) precedes 'o1-mini' (128k) in the table and a
|
|
first-match return would report o1-mini's window as 200k.
|
|
"""
|
|
name = model.lower()
|
|
basename = name.split("/")[-1] if "/" in name else name
|
|
basename = basename.split(":")[0] # strip :free, :extended etc.
|
|
best_key: Optional[str] = None
|
|
best_ctx: Optional[int] = None
|
|
for key, ctx in KNOWN_CONTEXT_WINDOWS.items():
|
|
if key in basename or key in name:
|
|
if best_key is None or len(key) > len(best_key):
|
|
best_key, best_ctx = key, ctx
|
|
return best_ctx
|
|
|
|
|
|
def _query_context_length(endpoint_url: str, model: str) -> int:
|
|
"""Query the model API for context length."""
|
|
known = _lookup_known(model)
|
|
api_ctx = None
|
|
|
|
# Try llama.cpp /slots endpoint first — reports actual serving context
|
|
if _is_local_endpoint(endpoint_url):
|
|
try:
|
|
base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0]
|
|
r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
|
|
if r.is_success:
|
|
slots = r.json()
|
|
if isinstance(slots, list) and slots:
|
|
n_ctx = slots[0].get("n_ctx")
|
|
if n_ctx and isinstance(n_ctx, int) and n_ctx > 0:
|
|
logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
|
|
return n_ctx
|
|
except Exception:
|
|
pass
|
|
|
|
models_url = endpoint_url.replace("/chat/completions", "/models")
|
|
try:
|
|
r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
|
|
if r.is_success:
|
|
data = r.json()
|
|
models_list = data.get("data") or []
|
|
|
|
for m in models_list:
|
|
mid = m.get("id", "")
|
|
if mid == model or mid.split("/")[-1] == model.split("/")[-1]:
|
|
for field in (
|
|
"context_length",
|
|
"context_window",
|
|
"max_model_len",
|
|
"max_context_length",
|
|
"max_seq_len",
|
|
):
|
|
val = m.get(field)
|
|
if val and isinstance(val, (int, float)) and val > 0:
|
|
api_ctx = int(val)
|
|
break
|
|
|
|
if not api_ctx:
|
|
meta = m.get("meta") or m.get("model_extra") or {}
|
|
if isinstance(meta, dict):
|
|
# n_ctx is the actual serving context (set via -c flag in llama.cpp)
|
|
for field in ("n_ctx", "context_length", "context_window", "max_model_len"):
|
|
val = meta.get(field)
|
|
if val and isinstance(val, (int, float)) and val > 0:
|
|
api_ctx = int(val)
|
|
break
|
|
break
|
|
except Exception as e:
|
|
logger.debug(f"Failed to query context length for {model}: {e}")
|
|
|
|
# For local/self-hosted endpoints, trust the API value (user set --max-model-len)
|
|
# For cloud APIs, use the larger value (API can report low defaults)
|
|
if api_ctx and known:
|
|
_is_local = _is_local_endpoint(endpoint_url)
|
|
if _is_local and api_ctx < known:
|
|
logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
|
|
return api_ctx
|
|
result = max(api_ctx, known)
|
|
if api_ctx < known:
|
|
logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
|
|
return result
|
|
if api_ctx:
|
|
return api_ctx
|
|
if known:
|
|
logger.info(f"Using known context window for {model}: {known}")
|
|
return known
|
|
|
|
return DEFAULT_CONTEXT
|
|
|
|
|
|
def estimate_tokens(messages: List[Dict]) -> int:
|
|
"""Rough token estimate for a list of messages.
|
|
|
|
Uses chars * 0.3 which is closer to real BPE tokenizer output
|
|
than the commonly-cited chars/4 (which underestimates by ~20-30%).
|
|
Also adds ~4 tokens per message for role/formatting overhead.
|
|
"""
|
|
total = 0
|
|
for msg in messages:
|
|
total += 4 # per-message overhead (role, separators)
|
|
content = msg.get("content", "")
|
|
if isinstance(content, str):
|
|
total += int(len(content) * 0.3)
|
|
elif isinstance(content, list):
|
|
for item in content:
|
|
if isinstance(item, dict) and item.get("type") == "text":
|
|
total += int(len(item.get("text", "")) * 0.3)
|
|
return total
|