Make LLM host health maps thread-safe

The synchronous llm_call() runs in FastAPI's threadpool (sync route
handlers such as POST /sessions/auto-sort), while llm_call_async() runs
on the event loop. Both mutate the module-level _response_cache,
_host_fails and _dead_hosts dicts, so these are touched from multiple OS
threads concurrently. Two races result:

- _set_cached_response() snapshots 64 keys then deletes them with
  `del _response_cache[key]`; if another thread evicts the same key
  first, the del raises KeyError mid-eviction. Switched to
  pop(key, None).
- _mark_host_dead() does get()+1+set() on _host_fails with no lock, so
  concurrent connect failures lose increments and a genuinely dead host
  can stay below _HOST_FAIL_THRESHOLD and never enter its cooldown.
  Guarded the host-health maps with a threading.Lock (also applied to
  _is_host_dead / _clear_host_dead for consistent reads).

Adds tests/test_llm_core_concurrency.py with deterministic regression
tests (phantom snapshot key for the eviction race; a slow-read dict that
forces the lost-update window for the counter). Both fail on the
unpatched code and pass with the fix.
This commit is contained in:
SurprisedDuck
2026-06-01 22:54:23 +02:00
committed by GitHub
parent cd6041477c
commit 7268c49992
2 changed files with 108 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ import time
import json import json
import logging import logging
import hashlib import hashlib
import threading
from fastapi import HTTPException from fastapi import HTTPException
from typing import Optional, Dict, List from typing import Optional, Dict, List
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -56,6 +57,12 @@ DEAD_HOST_COOLDOWN = 20.0
_HOST_FAIL_THRESHOLD = 2 _HOST_FAIL_THRESHOLD = 2
_dead_hosts: Dict[str, float] = {} _dead_hosts: Dict[str, float] = {}
_host_fails: Dict[str, int] = {} _host_fails: Dict[str, int] = {}
# Guards the two maps above. The synchronous llm_call() runs inside FastAPI's
# threadpool (sync routes such as /sessions/auto-sort) while llm_call_async()
# runs on the event loop, so these maps are mutated from multiple OS threads.
# Without the lock the get()+1+set on _host_fails is a read-modify-write that
# loses failure counts under concurrent connect errors (issue #659).
_host_health_lock = threading.Lock()
_model_activity: Dict[str, float] = {} _model_activity: Dict[str, float] = {}
def _model_activity_key(url: str, model: str) -> str: def _model_activity_key(url: str, model: str) -> str:
@@ -81,13 +88,14 @@ def _host_key(url: str) -> str:
def _is_host_dead(url: str) -> bool: def _is_host_dead(url: str) -> bool:
key = _host_key(url) key = _host_key(url)
exp = _dead_hosts.get(key) with _host_health_lock:
if exp is None: exp = _dead_hosts.get(key)
return False if exp is None:
if time.time() >= exp: return False
_dead_hosts.pop(key, None) if time.time() >= exp:
return False _dead_hosts.pop(key, None)
return True return False
return True
def _mark_host_dead(url: str) -> bool: def _mark_host_dead(url: str) -> bool:
"""Record a connect failure. Only actually cools the host after """Record a connect failure. Only actually cools the host after
@@ -95,17 +103,19 @@ def _mark_host_dead(url: str) -> bool:
is now cooled (so callers can log accurately), False if it's still is now cooled (so callers can log accurately), False if it's still
within its allowed-failure grace.""" within its allowed-failure grace."""
key = _host_key(url) key = _host_key(url)
n = _host_fails.get(key, 0) + 1 with _host_health_lock:
_host_fails[key] = n n = _host_fails.get(key, 0) + 1
if n >= _HOST_FAIL_THRESHOLD: _host_fails[key] = n
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN if n >= _HOST_FAIL_THRESHOLD:
return True _dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
return False return True
return False
def _clear_host_dead(url: str) -> None: def _clear_host_dead(url: str) -> None:
key = _host_key(url) key = _host_key(url)
_dead_hosts.pop(key, None) with _host_health_lock:
_host_fails.pop(key, None) _dead_hosts.pop(key, None)
_host_fails.pop(key, None)
# Shared async HTTP client. Reusing one client keeps connections warm: # Shared async HTTP client. Reusing one client keeps connections warm:
@@ -130,7 +140,10 @@ def _set_cached_response(cache_key: str, response: str) -> None:
if len(_response_cache) > 128: if len(_response_cache) > 128:
keys_to_remove = list(_response_cache.keys())[:64] keys_to_remove = list(_response_cache.keys())[:64]
for key in keys_to_remove: for key in keys_to_remove:
del _response_cache[key] # pop(), not del: another thread (sync llm_call runs in FastAPI's
# threadpool) may have already evicted the same snapshotted key,
# and del would raise KeyError mid-eviction (issue #659).
_response_cache.pop(key, None)
_response_cache[cache_key] = response _response_cache[cache_key] = response
# ── Anthropic native API adapter ── # ── Anthropic native API adapter ──

View File

@@ -0,0 +1,79 @@
"""Regression tests for thread-safe access to llm_core's shared maps (issue #659).
The synchronous llm_call() runs inside FastAPI's threadpool (sync route handlers
such as POST /sessions/auto-sort), while llm_call_async() runs on the event
loop. Both mutate the module-level _response_cache / _host_fails / _dead_hosts
dicts, so those mutations must tolerate concurrent access from multiple OS
threads.
Plain thread stress can't reliably reproduce these races (CPython's GIL rarely
preempts the short critical sections), so each test deterministically widens the
vulnerable window: one injects a phantom snapshot key, the other forces every
thread to read the counter before any writes it back.
"""
import threading
import time
from src import llm_core
def test_cache_eviction_tolerates_already_removed_key():
    """Eviction must not raise when a snapshotted key is gone by delete time.

    Models a concurrent evictor removing the same key: the old `del` raised
    KeyError mid-loop, `pop(key, None)` does not.

    The module-level cache is swapped for a dict subclass whose ``keys()``
    reports one extra key that is not actually stored, and is restored in the
    ``finally`` block regardless of the outcome.
    """

    class PhantomKeysCache(dict):
        def keys(self):
            # First key is absent from the dict — as if another thread evicted
            # it between the snapshot and the delete.
            return ["__phantom_removed__", *super().keys()]

    original = llm_core._response_cache
    cache = PhantomKeysCache()
    for i in range(130):  # exceed the 128 cap so the eviction branch runs
        cache[f"k{i}"] = "x"
    llm_core._response_cache = cache
    try:
        llm_core._set_cached_response("new-key", "y")  # must not raise
        # dict.get bypasses any subclass overrides and reads the real storage.
        assert dict.get(cache, "new-key") == "y"
        # Guard against a vacuous pass: the eviction branch must really have
        # run (the oldest real key sits right after the phantom in the
        # snapshot, so it is among the evicted ones) ...
        assert "k0" not in cache
        # ... and tolerating the missing key must not have inserted it.
        assert "__phantom_removed__" not in cache
    finally:
        llm_core._response_cache = original
def test_host_fail_counter_has_no_lost_updates():
    """Every concurrent _mark_host_dead call must be counted exactly once.

    The failure map is replaced by a dict whose ``get()`` sleeps after reading,
    which stretches the gap between the read and the caller's write-back. An
    unguarded get()+1+set() then loses all but one update; with the lock the
    calls are serialized and the final count equals the number of threads.
    """
    target_url = "http://race.example:1234/v1/chat/completions"
    host_key = llm_core._host_key(target_url)

    worker_count = 8
    start_gate = threading.Barrier(worker_count)

    class DelayedReadDict(dict):
        def get(self, *args, **kwargs):
            found = super().get(*args, **kwargs)
            # Pause after the read so the caller's write happens late.
            time.sleep(0.01)
            return found

    def record_failure():
        # Release all workers at once so they hit the read window together.
        start_gate.wait()
        llm_core._mark_host_dead(target_url)

    saved_fails = llm_core._host_fails
    saved_threshold = llm_core._HOST_FAIL_THRESHOLD
    llm_core._host_fails = DelayedReadDict()
    # Unreachable threshold: the host is never cooled, so each call is a pure +1.
    llm_core._HOST_FAIL_THRESHOLD = 10 ** 9
    try:
        workers = [
            threading.Thread(target=record_failure)
            for _ in range(worker_count)
        ]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        # Read through dict.get to skip the artificial delay.
        assert dict.get(llm_core._host_fails, host_key) == worker_count
    finally:
        llm_core._host_fails = saved_fails
        llm_core._HOST_FAIL_THRESHOLD = saved_threshold