"""
model_context.py

Query and cache model context window sizes from OpenAI-compatible APIs.
Provides token estimation for context usage tracking.
"""

import ipaddress
import logging
from typing import Dict, List, Optional
from urllib.parse import urlparse

import httpx

logger = logging.getLogger(__name__)

# Host names/addresses that always count as local.
_LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1"}

# Legacy string prefixes, kept unchanged so existing importers keep working.
# _is_local_endpoint() no longer uses them: the bare "100." prefix also
# matches public 100.0.0.0/8 addresses outside the CGNAT block, and prefix
# matching accepts DNS names such as "10.example.com".
_PRIVATE_PREFIXES = (
    "10.",
    "172.16.", "172.17.", "172.18.", "172.19.",
    "172.20.", "172.21.", "172.22.", "172.23.",
    "172.24.", "172.25.", "172.26.", "172.27.",
    "172.28.", "172.29.", "172.30.", "172.31.",
    "192.168.",
    "100.",
)

# Address ranges treated as local/private/tailscale.
_PRIVATE_NETWORKS = tuple(
    ipaddress.ip_network(cidr)
    for cidr in (
        "10.0.0.0/8",       # RFC 1918
        "172.16.0.0/12",    # RFC 1918
        "192.168.0.0/16",   # RFC 1918
        "100.64.0.0/10",    # RFC 6598 shared address space (Tailscale IPv4)
        "fc00::/7",         # RFC 4193 unique local addresses
    )
)


def _is_local_endpoint(url: str) -> bool:
    """Check if URL points to a local/private/tailscale address.

    Args:
        url: Endpoint URL, e.g. ``http://192.168.1.5:8080/v1/chat/completions``.

    Returns:
        True when the URL's host is listed in ``_LOCAL_HOSTS``, or is an IP
        literal that is loopback, unspecified, or inside one of
        ``_PRIVATE_NETWORKS``. False for every other host, including DNS
        names and URLs that cannot be parsed.
    """
    try:
        host = urlparse(url).hostname or ""
    except (ValueError, TypeError, AttributeError):
        # Malformed URL or non-string input: treat as not local.
        return False
    if not isinstance(host, str):
        # bytes URLs yield a bytes hostname; they never matched before either.
        return False
    if host in _LOCAL_HOSTS:
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal (ordinary DNS name).
        return False

    # Judge IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) by their IPv4 part.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped

    if addr.is_loopback or addr.is_unspecified:
        return True
    return any(addr in network for network in _PRIVATE_NETWORKS)


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

DEFAULT_CONTEXT = 128000
REQUEST_TIMEOUT = 5

# Known context windows for major API models (used as fallback when /models
# endpoint doesn't report context_length).
# Substring matching — use the shortest unique prefix so variants get caught.
KNOWN_CONTEXT_WINDOWS = {
    # --- Anthropic ---
    'claude-sonnet-4-5': 200000,
    'claude-sonnet-4-6': 200000,
    'claude-sonnet-4': 200000,
    'claude-opus-4': 200000,
    'claude-haiku-4': 200000,
    'claude-haiku-3-5': 200000,
    'claude-3-5-sonnet': 200000,
    'claude-3-5-haiku': 200000,
    'claude-3-opus': 200000,
    'claude-3-sonnet': 200000,
    'claude-3-haiku': 200000,
    # --- OpenAI ---
    'gpt-5': 400000,
    'gpt-4.1': 1047576,
    'gpt-4.1-mini': 1047576,
    'gpt-4.1-nano': 1047576,
    'gpt-4o': 128000,
    'gpt-4o-mini': 128000,
    'gpt-4-turbo': 128000,
    'gpt-4': 8192,
    'gpt-3.5-turbo': 16385,
    'o1': 200000,
    'o1-mini': 128000,
    'o1-pro': 200000,
    'o3': 200000,
    'o3-mini': 200000,
    'o4-mini': 200000,
    # --- DeepSeek ---
    'deepseek-chat': 64000,
    'deepseek-coder': 64000,
    'deepseek-reasoner': 64000,
    'deepseek-r1': 64000,
    'deepseek-v3': 64000,
    'deepseek-v2': 64000,
    # --- Google ---
    'gemini-2.5-pro': 1048576,
    'gemini-2.5-flash': 1048576,
    'gemini-2.0-flash': 1048576,
    'gemini-1.5-pro': 1048576,
    'gemini-1.5-flash': 1048576,
    'gemma-3': 128000,
    'gemma-2': 8192,
    # --- Mistral ---
    'mistral-large': 128000,
    'mistral-medium': 32000,
    'mistral-small': 32000,
    'mistral-nemo': 128000,
    'mistral-7b': 32000,
    'mixtral': 32000,
    'codestral': 32000,
    'pixtral': 128000,
    # --- xAI ---
    'grok-4': 131072,
    'grok-3': 131072,
    'grok-2': 131072,
    # --- Meta / Llama ---
    'llama-4': 1048576,
    'llama-3.3': 131072,
    'llama-3.2': 131072,
    'llama-3.1': 131072,
    'llama-3': 131072,
    # --- Qwen ---
    'qwen3': 131072,
    'qwen2.5': 131072,
    'qwen2': 32768,
    'qwq': 32768,
    # --- Cohere ---
    'command-r-plus': 128000,
    'command-r': 128000,
    'command-a': 256000,
    # --- Perplexity ---
    'sonar-pro': 200000,
    'sonar': 128000,
    # --- MiniMax ---
    'minimax': 1000000,
    # --- Moonshot / Kimi ---
    'moonshot': 128000,
    'kimi': 128000,
    # --- Microsoft ---
    'phi-4': 16000,
    'phi-3': 128000,
    # --- Nvidia ---
    'nemotron': 131072,
    # --- Yi ---
    'yi-large': 32768,
    'yi-1.5': 16384,
    # --- 01.ai ---
    'yi-lightning': 16384,
    # --- Nous ---
    'hermes': 131072,
    'nous-hermes': 131072,
    # --- Open community ---
    'dolphin': 32768,
    'mythomax': 4096,
    'wizard': 32768,
    'openchat': 8192,
    'solar': 32768,
}

# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------

_context_cache: Dict[str, int] = {}

# Fields on a /models entry that may carry the context size, in lookup order.
_MODEL_CTX_FIELDS = (
    "context_length",
    "context_window",
    "max_model_len",
    "max_context_length",
    "max_seq_len",
)

# Fields inside an entry's "meta"/"model_extra" dict. n_ctx comes first
# because it is the actual serving context (set via -c flag in llama.cpp).
_META_CTX_FIELDS = ("n_ctx", "context_length", "context_window", "max_model_len")

# Heuristic used by estimate_tokens().
_TOKENS_PER_CHAR = 0.3
_PER_MESSAGE_OVERHEAD = 4


def get_context_length(endpoint_url: str, model: str) -> int:
    """Get the context window size for a model.

    Queries the endpoint (and the built-in table of known models) for the
    context size and caches the result per model ID.

    Args:
        endpoint_url: Chat-completions URL of the OpenAI-compatible endpoint.
        model: Model ID as sent to the API.

    Returns:
        The context window in tokens, or DEFAULT_CONTEXT when neither the
        endpoint nor the known-model table yields a value.
    """
    if model in _context_cache:
        return _context_cache[model]

    ctx = _resolve_context_length(endpoint_url, model)
    if ctx is None:
        # Nothing determined: do not cache, so the next request retries.
        return DEFAULT_CONTEXT

    # Cache every determined value, including ones that happen to equal
    # DEFAULT_CONTEXT; comparing against the default would leave genuine
    # 128000-token models uncached and re-queried on every call.
    _context_cache[model] = ctx
    logger.info(f"Context length for {model}: {ctx}")
    return ctx


def _match_known_key(text: str) -> Optional[str]:
    """Return the most specific KNOWN_CONTEXT_WINDOWS key found in *text*.

    The first key in table order that occurs in *text* selects the model
    family. If longer keys that contain it also occur (e.g. 'o1-mini'
    extending 'o1'), the longest of those wins, so nested keys do not depend
    on their position in the table.
    """
    matches = [key for key in KNOWN_CONTEXT_WINDOWS if key in text]
    if not matches:
        return None
    first = matches[0]
    return max((key for key in matches if first in key), key=len)


def _lookup_known(model: str) -> Optional[int]:
    """Check known context windows by substring match.

    The model's basename (provider prefix and ``:tag`` suffix removed) is
    tried first so a provider name cannot shadow the model name; the full
    lower-cased ID is the fallback.

    Returns:
        The known context window, or None when no table key matches.
    """
    name = model.lower()
    basename = name.rsplit("/", 1)[-1]
    basename = basename.split(":")[0]  # strip :free, :extended etc.

    for candidate in (basename, name):
        key = _match_known_key(candidate)
        if key is not None:
            return KNOWN_CONTEXT_WINDOWS[key]
    return None


def _first_positive_int(mapping: Dict, fields) -> Optional[int]:
    """Return the first positive numeric value among *fields*, as an int."""
    for field in fields:
        val = mapping.get(field)
        # bool is a subclass of int; a JSON true is not a context size.
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            continue
        if val > 0 and int(val) > 0:
            return int(val)
    return None


def _query_slots_n_ctx(endpoint_url: str, model: str) -> Optional[int]:
    """Ask a llama.cpp server's /slots endpoint for its serving context."""
    try:
        if "/v1" in endpoint_url:
            base = endpoint_url.split("/v1")[0]
        else:
            base = endpoint_url.rsplit("/", 1)[0]
        r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
        if not r.is_success:
            return None
        slots = r.json()
        if isinstance(slots, list) and slots:
            n_ctx = slots[0].get("n_ctx")
            if isinstance(n_ctx, int) and not isinstance(n_ctx, bool) and n_ctx > 0:
                logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
                return n_ctx
    except Exception as exc:
        # Best-effort probe: most servers have no /slots endpoint at all.
        logger.debug(f"/slots probe failed for {model}: {exc}")
    return None


def _query_models_ctx(endpoint_url: str, model: str) -> Optional[int]:
    """Read the context size for *model* from the endpoint's /models list."""
    models_url = endpoint_url.replace("/chat/completions", "/models")
    try:
        r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
        if not r.is_success:
            return None
        data = r.json()
        wanted_tail = model.split("/")[-1]
        for entry in data.get("data") or []:
            mid = entry.get("id", "")
            if mid != model and mid.split("/")[-1] != wanted_tail:
                continue
            ctx = _first_positive_int(entry, _MODEL_CTX_FIELDS)
            if not ctx:
                meta = entry.get("meta") or entry.get("model_extra") or {}
                if isinstance(meta, dict):
                    ctx = _first_positive_int(meta, _META_CTX_FIELDS)
            # Only the first entry matching the model ID is considered.
            return ctx
    except Exception as e:
        logger.debug(f"Failed to query context length for {model}: {e}")
    return None


def _resolve_context_length(endpoint_url: str, model: str) -> Optional[int]:
    """Determine the context size, or None when nothing reports one."""
    known = _lookup_known(model)
    is_local = _is_local_endpoint(endpoint_url)

    # Try llama.cpp /slots endpoint first — reports actual serving context
    if is_local:
        n_ctx = _query_slots_n_ctx(endpoint_url, model)
        if n_ctx:
            return n_ctx

    api_ctx = _query_models_ctx(endpoint_url, model)

    # For local/self-hosted endpoints, trust the API value (user set --max-model-len)
    # For cloud APIs, use the larger value (API can report low defaults)
    if api_ctx and known:
        if is_local and api_ctx < known:
            logger.info(
                f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value"
            )
            return api_ctx
        if api_ctx < known:
            logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
        return max(api_ctx, known)
    if api_ctx:
        return api_ctx
    if known:
        logger.info(f"Using known context window for {model}: {known}")
        return known
    return None


def _query_context_length(endpoint_url: str, model: str) -> int:
    """Query the model API for context length.

    Returns:
        The resolved context window, or DEFAULT_CONTEXT when neither the
        endpoint nor the known-model table provides one.
    """
    resolved = _resolve_context_length(endpoint_url, model)
    return DEFAULT_CONTEXT if resolved is None else resolved


def estimate_tokens(messages: List[Dict]) -> int:
    """Rough token estimate for a list of messages.

    Uses chars * 0.3 which is closer to real BPE tokenizer output than the
    commonly-cited chars/4 (which underestimates by ~20-30%). Also adds
    ~4 tokens per message for role/formatting overhead.

    Args:
        messages: Chat messages; ``content`` may be a string or a list of
            parts, of which only ``{"type": "text"}`` parts are counted.

    Returns:
        Estimated token count. Each text piece is truncated to an int
        before summing.
    """
    total = 0
    for msg in messages:
        total += _PER_MESSAGE_OVERHEAD  # per-message overhead (role, separators)
        content = msg.get("content", "")
        if isinstance(content, str):
            total += int(len(content) * _TOKENS_PER_CHAR)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text", "")
                    # A null/non-string text part contributes nothing.
                    if isinstance(text, str):
                        total += int(len(text) * _TOKENS_PER_CHAR)
    return total