From 7268c49992118ce9988ac665a8c3ede5484ad39f Mon Sep 17 00:00:00 2001 From: SurprisedDuck Date: Mon, 1 Jun 2026 22:54:23 +0200 Subject: [PATCH] Make LLM host health maps thread-safe The synchronous llm_call() runs in FastAPI's threadpool (sync route handlers such as POST /sessions/auto-sort), while llm_call_async() runs on the event loop. Both mutate the module-level _response_cache, _host_fails and _dead_hosts dicts, so these are touched from multiple OS threads concurrently. Two races result: - _set_cached_response() snapshots 64 keys then deletes them with `del _response_cache[key]`; if another thread evicts the same key first, the del raises KeyError mid-eviction. Switched to pop(key, None). - _mark_host_dead() does get()+1+set() on _host_fails with no lock, so concurrent connect failures lose increments and a genuinely dead host can stay under its cooldown threshold. Guarded the host-health maps with a threading.Lock (also applied to _is_host_dead / _clear_host_dead for consistent reads). Adds tests/test_llm_core_concurrency.py with deterministic regression tests (phantom snapshot key for the eviction race; a slow-read dict that forces the lost-update window for the counter). Both fail on the unpatched code and pass with the fix. --- src/llm_core.py | 45 +++++++++++------ tests/test_llm_core_concurrency.py | 79 ++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 16 deletions(-) create mode 100644 tests/test_llm_core_concurrency.py diff --git a/src/llm_core.py b/src/llm_core.py index 210ed49..0d4ddc5 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -5,6 +5,7 @@ import time import json import logging import hashlib +import threading from fastapi import HTTPException from typing import Optional, Dict, List from urllib.parse import urlparse @@ -56,6 +57,12 @@ DEAD_HOST_COOLDOWN = 20.0 _HOST_FAIL_THRESHOLD = 2 _dead_hosts: Dict[str, float] = {} _host_fails: Dict[str, int] = {} +# Guards the two maps above. The synchronous llm_call() runs inside FastAPI's +# threadpool (sync routes such as /sessions/auto-sort) while llm_call_async() +# runs on the event loop, so these maps are mutated from multiple OS threads. +# Without the lock the get()+1+set on _host_fails is a read-modify-write that +# loses failure counts under concurrent connect errors (issue #659). +_host_health_lock = threading.Lock() _model_activity: Dict[str, float] = {} def _model_activity_key(url: str, model: str) -> str: @@ -81,13 +88,14 @@ def _host_key(url: str) -> str: def _is_host_dead(url: str) -> bool: key = _host_key(url) - exp = _dead_hosts.get(key) - if exp is None: - return False - if time.time() >= exp: - _dead_hosts.pop(key, None) - return False - return True + with _host_health_lock: + exp = _dead_hosts.get(key) + if exp is None: + return False + if time.time() >= exp: + _dead_hosts.pop(key, None) + return False + return True def _mark_host_dead(url: str) -> bool: """Record a connect failure. Only actually cools the host after @@ -95,17 +103,19 @@ def _mark_host_dead(url: str) -> bool: is now cooled (so callers can log accurately), False if it's still within its allowed-failure grace.""" key = _host_key(url) - n = _host_fails.get(key, 0) + 1 - _host_fails[key] = n - if n >= _HOST_FAIL_THRESHOLD: - _dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN - return True - return False + with _host_health_lock: + n = _host_fails.get(key, 0) + 1 + _host_fails[key] = n + if n >= _HOST_FAIL_THRESHOLD: + _dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN + return True + return False def _clear_host_dead(url: str) -> None: key = _host_key(url) - _dead_hosts.pop(key, None) - _host_fails.pop(key, None) + with _host_health_lock: + _dead_hosts.pop(key, None) + _host_fails.pop(key, None) # Shared async HTTP client. Reusing one client keeps connections warm: @@ -130,7 +140,10 @@ def _set_cached_response(cache_key: str, response: str) -> None: if len(_response_cache) > 128: keys_to_remove = list(_response_cache.keys())[:64] for key in keys_to_remove: - del _response_cache[key] + # pop(), not del: another thread (sync llm_call runs in FastAPI's + # threadpool) may have already evicted the same snapshotted key, + # and del would raise KeyError mid-eviction (issue #659). + _response_cache.pop(key, None) _response_cache[cache_key] = response # ── Anthropic native API adapter ── diff --git a/tests/test_llm_core_concurrency.py b/tests/test_llm_core_concurrency.py new file mode 100644 index 0000000..22a85a6 --- /dev/null +++ b/tests/test_llm_core_concurrency.py @@ -0,0 +1,79 @@ +"""Regression tests for thread-safe access to llm_core's shared maps (issue #659). + +The synchronous llm_call() runs inside FastAPI's threadpool (sync route handlers +such as POST /sessions/auto-sort), while llm_call_async() runs on the event +loop. Both mutate the module-level _response_cache / _host_fails / _dead_hosts +dicts, so those mutations must tolerate concurrent access from multiple OS +threads. + +Plain thread stress can't reliably reproduce these races (CPython's GIL rarely +preempts the short critical sections), so each test deterministically widens the +vulnerable window: one injects a phantom snapshot key, the other forces every +thread to read the counter before any writes it back. +""" +import threading +import time + +from src import llm_core + + +def test_cache_eviction_tolerates_already_removed_key(): + """Eviction must not raise when a snapshotted key is gone by delete time. + + Models a concurrent evictor removing the same key: the old `del` raised + KeyError mid-loop, `pop(key, None)` does not. + """ + class PhantomKeysCache(dict): + def keys(self): + # First key is absent from the dict — as if another thread evicted + # it between the snapshot and the delete. + return ["__phantom_removed__", *super().keys()] + + original = llm_core._response_cache + cache = PhantomKeysCache() + for i in range(130): # exceed the 128 cap so the eviction branch runs + cache[f"k{i}"] = "x" + llm_core._response_cache = cache + try: + llm_core._set_cached_response("new-key", "y") # must not raise + assert dict.get(cache, "new-key") == "y" + finally: + llm_core._response_cache = original + + +def test_host_fail_counter_has_no_lost_updates(): + """Concurrent _mark_host_dead calls must each count exactly once. + + A SlowGetDict widens the read-modify-write window so the unguarded + get()+1+set() loses every update but one; the lock serializes them. + """ + url = "http://race.example:1234/v1/chat/completions" + key = llm_core._host_key(url) + + class SlowGetDict(dict): + def get(self, *args, **kwargs): + value = super().get(*args, **kwargs) + time.sleep(0.01) # widen the gap between the read and the caller's write + return value + + n_threads = 8 + barrier = threading.Barrier(n_threads) + original_fails = llm_core._host_fails + original_threshold = llm_core._HOST_FAIL_THRESHOLD + llm_core._host_fails = SlowGetDict() + llm_core._HOST_FAIL_THRESHOLD = 10 ** 9 # never cool: every call is a pure +1 + try: + def worker(): + barrier.wait() # all threads enter the read window together + llm_core._mark_host_dead(url) + + threads = [threading.Thread(target=worker) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert dict.get(llm_core._host_fails, key) == n_threads + finally: + llm_core._host_fails = original_fails + llm_core._HOST_FAIL_THRESHOLD = original_threshold