Make LLM host health maps thread-safe
The synchronous llm_call() runs in FastAPI's threadpool (sync route handlers such as POST /sessions/auto-sort), while llm_call_async() runs on the event loop. Both mutate the module-level _response_cache, _host_fails and _dead_hosts dicts, so these dicts are touched from multiple OS threads concurrently.

Two races result:

- _set_cached_response() snapshots 64 keys and then deletes them with `del _response_cache[key]`; if another thread evicts the same key first, the del raises KeyError mid-eviction. Switched to pop(key, None).
- _mark_host_dead() does get()+1+set() on _host_fails with no lock, so concurrent connect failures lose increments and a genuinely dead host can stay under its cooldown threshold. Guarded the host-health maps with a threading.Lock (also applied to _is_host_dead / _clear_host_dead for consistent reads).

Adds tests/test_llm_core_concurrency.py with deterministic regression tests: a phantom snapshot key for the eviction race, and a slow-read dict that forces the lost-update window for the counter. Both tests fail on the unpatched code and pass with the fix.
This commit is contained in:
@@ -5,6 +5,7 @@ import time
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import threading
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -56,6 +57,12 @@ DEAD_HOST_COOLDOWN = 20.0
|
|||||||
_HOST_FAIL_THRESHOLD = 2
|
_HOST_FAIL_THRESHOLD = 2
|
||||||
_dead_hosts: Dict[str, float] = {}
|
_dead_hosts: Dict[str, float] = {}
|
||||||
_host_fails: Dict[str, int] = {}
|
_host_fails: Dict[str, int] = {}
|
||||||
|
# Guards the two maps above. The synchronous llm_call() runs inside FastAPI's
|
||||||
|
# threadpool (sync routes such as /sessions/auto-sort) while llm_call_async()
|
||||||
|
# runs on the event loop, so these maps are mutated from multiple OS threads.
|
||||||
|
# Without the lock the get()+1+set on _host_fails is a read-modify-write that
|
||||||
|
# loses failure counts under concurrent connect errors (issue #659).
|
||||||
|
_host_health_lock = threading.Lock()
|
||||||
_model_activity: Dict[str, float] = {}
|
_model_activity: Dict[str, float] = {}
|
||||||
|
|
||||||
def _model_activity_key(url: str, model: str) -> str:
|
def _model_activity_key(url: str, model: str) -> str:
|
||||||
@@ -81,13 +88,14 @@ def _host_key(url: str) -> str:
|
|||||||
|
|
||||||
def _is_host_dead(url: str) -> bool:
|
def _is_host_dead(url: str) -> bool:
|
||||||
key = _host_key(url)
|
key = _host_key(url)
|
||||||
exp = _dead_hosts.get(key)
|
with _host_health_lock:
|
||||||
if exp is None:
|
exp = _dead_hosts.get(key)
|
||||||
return False
|
if exp is None:
|
||||||
if time.time() >= exp:
|
return False
|
||||||
_dead_hosts.pop(key, None)
|
if time.time() >= exp:
|
||||||
return False
|
_dead_hosts.pop(key, None)
|
||||||
return True
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _mark_host_dead(url: str) -> bool:
|
def _mark_host_dead(url: str) -> bool:
|
||||||
"""Record a connect failure. Only actually cools the host after
|
"""Record a connect failure. Only actually cools the host after
|
||||||
@@ -95,17 +103,19 @@ def _mark_host_dead(url: str) -> bool:
|
|||||||
is now cooled (so callers can log accurately), False if it's still
|
is now cooled (so callers can log accurately), False if it's still
|
||||||
within its allowed-failure grace."""
|
within its allowed-failure grace."""
|
||||||
key = _host_key(url)
|
key = _host_key(url)
|
||||||
n = _host_fails.get(key, 0) + 1
|
with _host_health_lock:
|
||||||
_host_fails[key] = n
|
n = _host_fails.get(key, 0) + 1
|
||||||
if n >= _HOST_FAIL_THRESHOLD:
|
_host_fails[key] = n
|
||||||
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
|
if n >= _HOST_FAIL_THRESHOLD:
|
||||||
return True
|
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
|
||||||
return False
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _clear_host_dead(url: str) -> None:
|
def _clear_host_dead(url: str) -> None:
|
||||||
key = _host_key(url)
|
key = _host_key(url)
|
||||||
_dead_hosts.pop(key, None)
|
with _host_health_lock:
|
||||||
_host_fails.pop(key, None)
|
_dead_hosts.pop(key, None)
|
||||||
|
_host_fails.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
# Shared async HTTP client. Reusing one client keeps connections warm:
|
# Shared async HTTP client. Reusing one client keeps connections warm:
|
||||||
@@ -130,7 +140,10 @@ def _set_cached_response(cache_key: str, response: str) -> None:
|
|||||||
if len(_response_cache) > 128:
|
if len(_response_cache) > 128:
|
||||||
keys_to_remove = list(_response_cache.keys())[:64]
|
keys_to_remove = list(_response_cache.keys())[:64]
|
||||||
for key in keys_to_remove:
|
for key in keys_to_remove:
|
||||||
del _response_cache[key]
|
# pop(), not del: another thread (sync llm_call runs in FastAPI's
|
||||||
|
# threadpool) may have already evicted the same snapshotted key,
|
||||||
|
# and del would raise KeyError mid-eviction (issue #659).
|
||||||
|
_response_cache.pop(key, None)
|
||||||
_response_cache[cache_key] = response
|
_response_cache[cache_key] = response
|
||||||
|
|
||||||
# ── Anthropic native API adapter ──
|
# ── Anthropic native API adapter ──
|
||||||
|
|||||||
79
tests/test_llm_core_concurrency.py
Normal file
79
tests/test_llm_core_concurrency.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Regression tests for thread-safe access to llm_core's shared maps (issue #659).
|
||||||
|
|
||||||
|
The synchronous llm_call() runs inside FastAPI's threadpool (sync route handlers
|
||||||
|
such as POST /sessions/auto-sort), while llm_call_async() runs on the event
|
||||||
|
loop. Both mutate the module-level _response_cache / _host_fails / _dead_hosts
|
||||||
|
dicts, so those mutations must tolerate concurrent access from multiple OS
|
||||||
|
threads.
|
||||||
|
|
||||||
|
Plain thread stress can't reliably reproduce these races (CPython's GIL rarely
|
||||||
|
preempts the short critical sections), so each test deterministically widens the
|
||||||
|
vulnerable window: one injects a phantom snapshot key, the other forces every
|
||||||
|
thread to read the counter before any writes it back.
|
||||||
|
"""
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src import llm_core
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_eviction_tolerates_already_removed_key():
|
||||||
|
"""Eviction must not raise when a snapshotted key is gone by delete time.
|
||||||
|
|
||||||
|
Models a concurrent evictor removing the same key: the old `del` raised
|
||||||
|
KeyError mid-loop, `pop(key, None)` does not.
|
||||||
|
"""
|
||||||
|
class PhantomKeysCache(dict):
|
||||||
|
def keys(self):
|
||||||
|
# First key is absent from the dict — as if another thread evicted
|
||||||
|
# it between the snapshot and the delete.
|
||||||
|
return ["__phantom_removed__", *super().keys()]
|
||||||
|
|
||||||
|
original = llm_core._response_cache
|
||||||
|
cache = PhantomKeysCache()
|
||||||
|
for i in range(130): # exceed the 128 cap so the eviction branch runs
|
||||||
|
cache[f"k{i}"] = "x"
|
||||||
|
llm_core._response_cache = cache
|
||||||
|
try:
|
||||||
|
llm_core._set_cached_response("new-key", "y") # must not raise
|
||||||
|
assert dict.get(cache, "new-key") == "y"
|
||||||
|
finally:
|
||||||
|
llm_core._response_cache = original
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_fail_counter_has_no_lost_updates():
|
||||||
|
"""Concurrent _mark_host_dead calls must each count exactly once.
|
||||||
|
|
||||||
|
A SlowGetDict widens the read-modify-write window so the unguarded
|
||||||
|
get()+1+set() loses every update but one; the lock serializes them.
|
||||||
|
"""
|
||||||
|
url = "http://race.example:1234/v1/chat/completions"
|
||||||
|
key = llm_core._host_key(url)
|
||||||
|
|
||||||
|
class SlowGetDict(dict):
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
value = super().get(*args, **kwargs)
|
||||||
|
time.sleep(0.01) # widen the gap between the read and the caller's write
|
||||||
|
return value
|
||||||
|
|
||||||
|
n_threads = 8
|
||||||
|
barrier = threading.Barrier(n_threads)
|
||||||
|
original_fails = llm_core._host_fails
|
||||||
|
original_threshold = llm_core._HOST_FAIL_THRESHOLD
|
||||||
|
llm_core._host_fails = SlowGetDict()
|
||||||
|
llm_core._HOST_FAIL_THRESHOLD = 10 ** 9 # never cool: every call is a pure +1
|
||||||
|
try:
|
||||||
|
def worker():
|
||||||
|
barrier.wait() # all threads enter the read window together
|
||||||
|
llm_core._mark_host_dead(url)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker) for _ in range(n_threads)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert dict.get(llm_core._host_fails, key) == n_threads
|
||||||
|
finally:
|
||||||
|
llm_core._host_fails = original_fails
|
||||||
|
llm_core._HOST_FAIL_THRESHOLD = original_threshold
|
||||||
Reference in New Issue
Block a user