From cd6041477c611dc22ca60ed983ef440fe37cc170 Mon Sep 17 00:00:00 2001 From: ooovenenoso <120500656+ooovenenoso@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:54:06 -0400 Subject: [PATCH] Refresh local model context after restart Co-authored-by: Kevin <120500656+oooindefatigable@users.noreply.github.com> --- src/model_context.py | 9 +++++--- tests/test_model_context.py | 44 +++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/model_context.py b/src/model_context.py index 23cdb86..dd32a7b 100644 --- a/src/model_context.py +++ b/src/model_context.py @@ -169,12 +169,15 @@ def get_context_length(endpoint_url: str, model: str) -> int: or context_window fields. Caches result per model ID. Falls back to DEFAULT_CONTEXT if unavailable. """ - if model in _context_cache: + is_local = _is_local_endpoint(endpoint_url) + if not is_local and model in _context_cache: return _context_cache[model] ctx = _query_context_length(endpoint_url, model) - # Only cache non-default values to allow retry on next request - if ctx != DEFAULT_CONTEXT: + # Only cache non-default values to allow retry on next request. + # Local endpoints can restart with a different --max-model-len while keeping + # the same model id, so always re-query them instead of serving stale cache. + if not is_local and ctx != DEFAULT_CONTEXT: _context_cache[model] = ctx logger.info(f"Context length for {model}: {ctx}") return ctx diff --git a/tests/test_model_context.py b/tests/test_model_context.py index 619f0a8..9067b8c 100644 --- a/tests/test_model_context.py +++ b/tests/test_model_context.py @@ -2,6 +2,7 @@ import pytest +import src.model_context as model_context from src.model_context import _is_local_endpoint, estimate_tokens, _lookup_known @@ -107,3 +108,46 @@ class TestLookupKnown: """Models with :free or :extended suffixes should still match.""" result = _lookup_known("deepseek-r1:free") assert result == 64000 + + +class TestGetContextLength: + def setup_method(self): + model_context._context_cache.clear() + + def test_local_endpoint_requeries_same_model_after_restart(self, monkeypatch): + calls = [] + + def fake_query(endpoint_url, model): + calls.append((endpoint_url, model)) + return 8192 if len(calls) == 1 else 27000 + + monkeypatch.setattr(model_context, "_query_context_length", fake_query) + + endpoint = "http://127.0.0.1:8000/v1/chat/completions" + model = "Qwen/Qwen3-14B" + + first = model_context.get_context_length(endpoint, model) + second = model_context.get_context_length(endpoint, model) + + assert first == 8192 + assert second == 27000 + assert len(calls) == 2 + + def test_remote_endpoint_keeps_cached_context(self, monkeypatch): + calls = [] + + def fake_query(endpoint_url, model): + calls.append((endpoint_url, model)) + return 200000 if len(calls) == 1 else 12345 + + monkeypatch.setattr(model_context, "_query_context_length", fake_query) + + endpoint = "https://api.openai.com/v1/chat/completions" + model = "gpt-5" + + first = model_context.get_context_length(endpoint, model) + second = model_context.get_context_length(endpoint, model) + + assert first == 200000 + assert second == 200000 + assert len(calls) == 1