fix(model-context): key context-window cache by (endpoint, model) (#2614)
get_context_length() cached the resolved context window by model id alone, so two different remote endpoints serving the same model id (e.g. a capped proxy at 8k vs. the full provider at 200k) collided: the first to resolve won process-wide and the other endpoint was served the wrong window. That silently over-trims conversations on the larger-window endpoint (it feeds context_compactor) or overflows the smaller one (provider 400s). Key the cache on (endpoint_url, model). Local endpoints already always re-query, so they are unaffected. Fixes #2603
This commit is contained in:
@@ -7,7 +7,7 @@ Provides token estimation for context usage tracking.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = {
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Cache
|
# Cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
_context_cache: Dict[str, int] = {}
|
_context_cache: Dict[Tuple[str, str], int] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_context_length(endpoint_url: str, model: str) -> int:
|
def get_context_length(endpoint_url: str, model: str) -> int:
|
||||||
"""Get the context window size for a model.
|
"""Get the context window size for a model.
|
||||||
|
|
||||||
Queries /v1/models on the endpoint and looks for context_length
|
Queries /v1/models on the endpoint and looks for context_length
|
||||||
or context_window fields. Caches result per model ID.
|
or context_window fields. Caches result per (endpoint, model).
|
||||||
Falls back to DEFAULT_CONTEXT if unavailable.
|
Falls back to DEFAULT_CONTEXT if unavailable.
|
||||||
"""
|
"""
|
||||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||||
is_local = _is_local_endpoint(endpoint_url)
|
is_local = _is_local_endpoint(endpoint_url)
|
||||||
if not is_local and model in _context_cache:
|
# Key on (endpoint_url, model): the same model id can be served by two
|
||||||
return _context_cache[model]
|
# different remote endpoints with different real context windows (e.g. a
|
||||||
|
# capped proxy vs. the full provider), so caching by model id alone would
|
||||||
|
# serve one endpoint's window for the other (issue #2603).
|
||||||
|
cache_key = (endpoint_url, model)
|
||||||
|
if not is_local and cache_key in _context_cache:
|
||||||
|
return _context_cache[cache_key]
|
||||||
|
|
||||||
ctx = _query_context_length(endpoint_url, model)
|
ctx = _query_context_length(endpoint_url, model)
|
||||||
# Only cache non-default values to allow retry on next request.
|
# Only cache non-default values to allow retry on next request.
|
||||||
# Local endpoints can restart with a different --max-model-len while keeping
|
# Local endpoints can restart with a different --max-model-len while keeping
|
||||||
# the same model id, so always re-query them instead of serving stale cache.
|
# the same model id, so always re-query them instead of serving stale cache.
|
||||||
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
||||||
_context_cache[model] = ctx
|
_context_cache[cache_key] = ctx
|
||||||
logger.info(f"Context length for {model}: {ctx}")
|
logger.info(f"Context length for {model}: {ctx}")
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
|||||||
39
tests/test_context_cache_per_endpoint.py
Normal file
39
tests/test_context_cache_per_endpoint.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Regression for #2603 — model context-window cache must be keyed per endpoint.
|
||||||
|
|
||||||
|
`get_context_length()` cached by model id alone, so two different remote endpoints
|
||||||
|
serving the same model id (e.g. a capped proxy at 8k vs. the full provider at 200k)
|
||||||
|
collided: whichever resolved first won process-wide and the other was served the
|
||||||
|
wrong window. The fix keys the cache on (endpoint_url, model).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import src.model_context as mc
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(monkeypatch, windows):
|
||||||
|
"""windows: {endpoint_url: context_length}. Force the remote path."""
|
||||||
|
monkeypatch.setattr(mc, "_is_local_endpoint", lambda url: False)
|
||||||
|
monkeypatch.setattr(mc, "_configured_endpoint_kind", lambda url: "api")
|
||||||
|
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: windows[url])
|
||||||
|
mc._context_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_same_model_two_remote_endpoints_get_their_own_window(monkeypatch):
|
||||||
|
a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1"
|
||||||
|
_setup(monkeypatch, {a: 8000, b: 200000})
|
||||||
|
|
||||||
|
assert mc.get_context_length(a, "shared-model") == 8000
|
||||||
|
# Same model id, different endpoint: must NOT return endpoint A's cached 8000.
|
||||||
|
assert mc.get_context_length(b, "shared-model") == 200000
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_hit_still_works_per_endpoint(monkeypatch):
|
||||||
|
a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1"
|
||||||
|
_setup(monkeypatch, {a: 8000, b: 200000})
|
||||||
|
mc.get_context_length(a, "shared-model")
|
||||||
|
mc.get_context_length(b, "shared-model")
|
||||||
|
|
||||||
|
# Both endpoints are now cached under their own key; flip the underlying
|
||||||
|
# query to prove subsequent reads come from the per-endpoint cache, not a re-query.
|
||||||
|
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: 999)
|
||||||
|
assert mc.get_context_length(a, "shared-model") == 8000
|
||||||
|
assert mc.get_context_length(b, "shared-model") == 200000
|
||||||
Reference in New Issue
Block a user