diff --git a/src/model_context.py b/src/model_context.py index 2fd0b82..3a445fe 100644 --- a/src/model_context.py +++ b/src/model_context.py @@ -7,7 +7,7 @@ Provides token estimation for context usage tracking. import logging import sys -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse @@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = { # --------------------------------------------------------------------------- # Cache # --------------------------------------------------------------------------- -_context_cache: Dict[str, int] = {} +_context_cache: Dict[Tuple[str, str], int] = {} def get_context_length(endpoint_url: str, model: str) -> int: """Get the context window size for a model. Queries /v1/models on the endpoint and looks for context_length - or context_window fields. Caches result per model ID. + or context_window fields. Caches result per (endpoint, model). Falls back to DEFAULT_CONTEXT if unavailable. """ configured_kind = _configured_endpoint_kind(endpoint_url) is_local = _is_local_endpoint(endpoint_url) - if not is_local and model in _context_cache: - return _context_cache[model] + # Key on (endpoint_url, model): the same model id can be served by two + # different remote endpoints with different real context windows (e.g. a + # capped proxy vs. the full provider), so caching by model id alone would + # serve one endpoint's window for the other (issue #2603). + cache_key = (endpoint_url, model) + if not is_local and cache_key in _context_cache: + return _context_cache[cache_key] ctx = _query_context_length(endpoint_url, model) # Only cache non-default values to allow retry on next request. # Local endpoints can restart with a different --max-model-len while keeping # the same model id, so always re-query them instead of serving stale cache. if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")): - _context_cache[model] = ctx + _context_cache[cache_key] = ctx logger.info(f"Context length for {model}: {ctx}") return ctx diff --git a/tests/test_context_cache_per_endpoint.py b/tests/test_context_cache_per_endpoint.py new file mode 100644 index 0000000..3bffd7b --- /dev/null +++ b/tests/test_context_cache_per_endpoint.py @@ -0,0 +1,39 @@ +"""Regression for #2603 — model context-window cache must be keyed per endpoint. + +`get_context_length()` cached by model id alone, so two different remote endpoints +serving the same model id (e.g. a capped proxy at 8k vs. the full provider at 200k) +collided: whichever resolved first won process-wide and the other was served the +wrong window. The fix keys the cache on (endpoint_url, model). +""" + +import src.model_context as mc + + +def _setup(monkeypatch, windows): + """windows: {endpoint_url: context_length}. Force the remote path.""" + monkeypatch.setattr(mc, "_is_local_endpoint", lambda url: False) + monkeypatch.setattr(mc, "_configured_endpoint_kind", lambda url: "api") + monkeypatch.setattr(mc, "_query_context_length", lambda url, model: windows[url]) + mc._context_cache.clear() + + +def test_same_model_two_remote_endpoints_get_their_own_window(monkeypatch): + a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1" + _setup(monkeypatch, {a: 8000, b: 200000}) + + assert mc.get_context_length(a, "shared-model") == 8000 + # Same model id, different endpoint: must NOT return endpoint A's cached 8000. + assert mc.get_context_length(b, "shared-model") == 200000 + + +def test_cache_hit_still_works_per_endpoint(monkeypatch): + a, b = "https://proxy-a.example/v1", "https://provider-b.example/v1" + _setup(monkeypatch, {a: 8000, b: 200000}) + mc.get_context_length(a, "shared-model") + mc.get_context_length(b, "shared-model") + + # Both endpoints are now cached under their own key; flip the underlying + # query to prove subsequent reads come from the per-endpoint cache, not a re-query. + monkeypatch.setattr(mc, "_query_context_length", lambda url, model: 999) + assert mc.get_context_length(a, "shared-model") == 8000 + assert mc.get_context_length(b, "shared-model") == 200000