fix(model-context): key context-window cache by (endpoint, model) (#2614)

get_context_length() cached the resolved context window by model id alone,
so two different remote endpoints serving the same model id (e.g. a capped
proxy at 8k vs. the full provider at 200k) collided: the first to resolve
won process-wide and the other endpoint was served the wrong window. That
silently over-trims conversations on the larger-window endpoint (it feeds
context_compactor) or overflows the smaller one (provider 400s).

Key the cache on (endpoint_url, model). Local endpoints already always
re-query, so they are unaffected.

Fixes #2603
This commit is contained in:
nubs
2026-06-05 00:50:56 +00:00
committed by GitHub
parent f8cf791491
commit 19a3fc59c9
2 changed files with 50 additions and 6 deletions

View File

@@ -7,7 +7,7 @@ Provides token estimation for context usage tracking.
import logging import logging
import sys import sys
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = {
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Cache # Cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_context_cache: Dict[str, int] = {} _context_cache: Dict[Tuple[str, str], int] = {}
def get_context_length(endpoint_url: str, model: str) -> int: def get_context_length(endpoint_url: str, model: str) -> int:
"""Get the context window size for a model. """Get the context window size for a model.
Queries /v1/models on the endpoint and looks for context_length Queries /v1/models on the endpoint and looks for context_length
or context_window fields. Caches result per model ID. or context_window fields. Caches result per (endpoint, model).
Falls back to DEFAULT_CONTEXT if unavailable. Falls back to DEFAULT_CONTEXT if unavailable.
""" """
configured_kind = _configured_endpoint_kind(endpoint_url) configured_kind = _configured_endpoint_kind(endpoint_url)
is_local = _is_local_endpoint(endpoint_url) is_local = _is_local_endpoint(endpoint_url)
if not is_local and model in _context_cache: # Key on (endpoint_url, model): the same model id can be served by two
return _context_cache[model] # different remote endpoints with different real context windows (e.g. a
# capped proxy vs. the full provider), so caching by model id alone would
# serve one endpoint's window for the other (issue #2603).
cache_key = (endpoint_url, model)
if not is_local and cache_key in _context_cache:
return _context_cache[cache_key]
ctx = _query_context_length(endpoint_url, model) ctx = _query_context_length(endpoint_url, model)
# Only cache non-default values to allow retry on next request. # Only cache non-default values to allow retry on next request.
# Local endpoints can restart with a different --max-model-len while keeping # Local endpoints can restart with a different --max-model-len while keeping
# the same model id, so always re-query them instead of serving stale cache. # the same model id, so always re-query them instead of serving stale cache.
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")): if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
_context_cache[model] = ctx _context_cache[cache_key] = ctx
logger.info(f"Context length for {model}: {ctx}") logger.info(f"Context length for {model}: {ctx}")
return ctx return ctx

View File

@@ -0,0 +1,39 @@
"""Regression for #2603 — model context-window cache must be keyed per endpoint.
`get_context_length()` cached by model id alone, so two different remote endpoints
serving the same model id (e.g. a capped proxy at 8k vs. the full provider at 200k)
collided: whichever resolved first won process-wide and the other was served the
wrong window. The fix keys the cache on (endpoint_url, model).
"""
import src.model_context as mc
def _setup(monkeypatch, windows):
"""windows: {endpoint_url: context_length}. Force the remote path."""
monkeypatch.setattr(mc, "_is_local_endpoint", lambda url: False)
monkeypatch.setattr(mc, "_configured_endpoint_kind", lambda url: "api")
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: windows[url])
mc._context_cache.clear()
def test_same_model_two_remote_endpoints_get_their_own_window(monkeypatch):
a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1"
_setup(monkeypatch, {a: 8000, b: 200000})
assert mc.get_context_length(a, "shared-model") == 8000
# Same model id, different endpoint: must NOT return endpoint A's cached 8000.
assert mc.get_context_length(b, "shared-model") == 200000
def test_cache_hit_still_works_per_endpoint(monkeypatch):
a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1"
_setup(monkeypatch, {a: 8000, b: 200000})
mc.get_context_length(a, "shared-model")
mc.get_context_length(b, "shared-model")
# Both endpoints are now cached under their own key; flip the underlying
# query to prove subsequent reads come from the per-endpoint cache, not a re-query.
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: 999)
assert mc.get_context_length(a, "shared-model") == 8000
assert mc.get_context_length(b, "shared-model") == 200000