Make LLM host health maps thread-safe
The synchronous llm_call() runs in FastAPI's threadpool (sync route handlers such as POST /sessions/auto-sort), while llm_call_async() runs on the event loop. Both mutate the module-level _response_cache, _host_fails and _dead_hosts dicts, so these are touched from multiple OS threads concurrently. Two races result: - _set_cached_response() snapshots 64 keys then deletes them with `del _response_cache[key]`; if another thread evicts the same key first, the del raises KeyError mid-eviction. Switched to pop(key, None). - _mark_host_dead() does get()+1+set() on _host_fails with no lock, so concurrent connect failures lose increments and a genuinely dead host can stay under its cooldown threshold. Guarded the host-health maps with a threading.Lock (also applied to _is_host_dead / _clear_host_dead for consistent reads). Adds tests/test_llm_core_concurrency.py with deterministic regression tests (phantom snapshot key for the eviction race; a slow-read dict that forces the lost-update window for the counter). Both fail on the unpatched code and pass with the fix.
This commit is contained in:
@@ -5,6 +5,7 @@ import time
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import threading
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, List
|
||||
from urllib.parse import urlparse
|
||||
@@ -56,6 +57,12 @@ DEAD_HOST_COOLDOWN = 20.0
|
||||
_HOST_FAIL_THRESHOLD = 2
|
||||
_dead_hosts: Dict[str, float] = {}
|
||||
_host_fails: Dict[str, int] = {}
|
||||
# Guards the two maps above. The synchronous llm_call() runs inside FastAPI's
|
||||
# threadpool (sync routes such as /sessions/auto-sort) while llm_call_async()
|
||||
# runs on the event loop, so these maps are mutated from multiple OS threads.
|
||||
# Without the lock the get()+1+set on _host_fails is a read-modify-write that
|
||||
# loses failure counts under concurrent connect errors (issue #659).
|
||||
_host_health_lock = threading.Lock()
|
||||
_model_activity: Dict[str, float] = {}
|
||||
|
||||
def _model_activity_key(url: str, model: str) -> str:
|
||||
@@ -81,13 +88,14 @@ def _host_key(url: str) -> str:
|
||||
|
||||
def _is_host_dead(url: str) -> bool:
|
||||
key = _host_key(url)
|
||||
exp = _dead_hosts.get(key)
|
||||
if exp is None:
|
||||
return False
|
||||
if time.time() >= exp:
|
||||
_dead_hosts.pop(key, None)
|
||||
return False
|
||||
return True
|
||||
with _host_health_lock:
|
||||
exp = _dead_hosts.get(key)
|
||||
if exp is None:
|
||||
return False
|
||||
if time.time() >= exp:
|
||||
_dead_hosts.pop(key, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _mark_host_dead(url: str) -> bool:
|
||||
"""Record a connect failure. Only actually cools the host after
|
||||
@@ -95,17 +103,19 @@ def _mark_host_dead(url: str) -> bool:
|
||||
is now cooled (so callers can log accurately), False if it's still
|
||||
within its allowed-failure grace."""
|
||||
key = _host_key(url)
|
||||
n = _host_fails.get(key, 0) + 1
|
||||
_host_fails[key] = n
|
||||
if n >= _HOST_FAIL_THRESHOLD:
|
||||
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
|
||||
return True
|
||||
return False
|
||||
with _host_health_lock:
|
||||
n = _host_fails.get(key, 0) + 1
|
||||
_host_fails[key] = n
|
||||
if n >= _HOST_FAIL_THRESHOLD:
|
||||
_dead_hosts[key] = time.time() + DEAD_HOST_COOLDOWN
|
||||
return True
|
||||
return False
|
||||
|
||||
def _clear_host_dead(url: str) -> None:
|
||||
key = _host_key(url)
|
||||
_dead_hosts.pop(key, None)
|
||||
_host_fails.pop(key, None)
|
||||
with _host_health_lock:
|
||||
_dead_hosts.pop(key, None)
|
||||
_host_fails.pop(key, None)
|
||||
|
||||
|
||||
# Shared async HTTP client. Reusing one client keeps connections warm:
|
||||
@@ -130,7 +140,10 @@ def _set_cached_response(cache_key: str, response: str) -> None:
|
||||
if len(_response_cache) > 128:
|
||||
keys_to_remove = list(_response_cache.keys())[:64]
|
||||
for key in keys_to_remove:
|
||||
del _response_cache[key]
|
||||
# pop(), not del: another thread (sync llm_call runs in FastAPI's
|
||||
# threadpool) may have already evicted the same snapshotted key,
|
||||
# and del would raise KeyError mid-eviction (issue #659).
|
||||
_response_cache.pop(key, None)
|
||||
_response_cache[cache_key] = response
|
||||
|
||||
# ── Anthropic native API adapter ──
|
||||
|
||||
79
tests/test_llm_core_concurrency.py
Normal file
79
tests/test_llm_core_concurrency.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Regression tests for thread-safe access to llm_core's shared maps (issue #659).
|
||||
|
||||
The synchronous llm_call() runs inside FastAPI's threadpool (sync route handlers
|
||||
such as POST /sessions/auto-sort), while llm_call_async() runs on the event
|
||||
loop. Both mutate the module-level _response_cache / _host_fails / _dead_hosts
|
||||
dicts, so those mutations must tolerate concurrent access from multiple OS
|
||||
threads.
|
||||
|
||||
Plain thread stress can't reliably reproduce these races (CPython's GIL rarely
|
||||
preempts the short critical sections), so each test deterministically widens the
|
||||
vulnerable window: one injects a phantom snapshot key, the other forces every
|
||||
thread to read the counter before any writes it back.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
from src import llm_core
|
||||
|
||||
|
||||
def test_cache_eviction_tolerates_already_removed_key():
|
||||
"""Eviction must not raise when a snapshotted key is gone by delete time.
|
||||
|
||||
Models a concurrent evictor removing the same key: the old `del` raised
|
||||
KeyError mid-loop, `pop(key, None)` does not.
|
||||
"""
|
||||
class PhantomKeysCache(dict):
|
||||
def keys(self):
|
||||
# First key is absent from the dict — as if another thread evicted
|
||||
# it between the snapshot and the delete.
|
||||
return ["__phantom_removed__", *super().keys()]
|
||||
|
||||
original = llm_core._response_cache
|
||||
cache = PhantomKeysCache()
|
||||
for i in range(130): # exceed the 128 cap so the eviction branch runs
|
||||
cache[f"k{i}"] = "x"
|
||||
llm_core._response_cache = cache
|
||||
try:
|
||||
llm_core._set_cached_response("new-key", "y") # must not raise
|
||||
assert dict.get(cache, "new-key") == "y"
|
||||
finally:
|
||||
llm_core._response_cache = original
|
||||
|
||||
|
||||
def test_host_fail_counter_has_no_lost_updates():
|
||||
"""Concurrent _mark_host_dead calls must each count exactly once.
|
||||
|
||||
A SlowGetDict widens the read-modify-write window so the unguarded
|
||||
get()+1+set() loses every update but one; the lock serializes them.
|
||||
"""
|
||||
url = "http://race.example:1234/v1/chat/completions"
|
||||
key = llm_core._host_key(url)
|
||||
|
||||
class SlowGetDict(dict):
|
||||
def get(self, *args, **kwargs):
|
||||
value = super().get(*args, **kwargs)
|
||||
time.sleep(0.01) # widen the gap between the read and the caller's write
|
||||
return value
|
||||
|
||||
n_threads = 8
|
||||
barrier = threading.Barrier(n_threads)
|
||||
original_fails = llm_core._host_fails
|
||||
original_threshold = llm_core._HOST_FAIL_THRESHOLD
|
||||
llm_core._host_fails = SlowGetDict()
|
||||
llm_core._HOST_FAIL_THRESHOLD = 10 ** 9 # never cool: every call is a pure +1
|
||||
try:
|
||||
def worker():
|
||||
barrier.wait() # all threads enter the read window together
|
||||
llm_core._mark_host_dead(url)
|
||||
|
||||
threads = [threading.Thread(target=worker) for _ in range(n_threads)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert dict.get(llm_core._host_fails, key) == n_threads
|
||||
finally:
|
||||
llm_core._host_fails = original_fails
|
||||
llm_core._HOST_FAIL_THRESHOLD = original_threshold
|
||||
Reference in New Issue
Block a user