fix: add threading lock to AuthManager config mutations (#1226)

This commit is contained in:
Isak
2026-06-05 10:04:37 +02:00
committed by GitHub
parent 04df7255fb
commit ec7691956b
2 changed files with 271 additions and 59 deletions

View File

@@ -76,6 +76,10 @@ class AuthManager:
# Guards mutations of self._sessions and the on-disk sessions.json. # Guards mutations of self._sessions and the on-disk sessions.json.
# Validate/create/revoke run concurrently from the FastAPI threadpool. # Validate/create/revoke run concurrently from the FastAPI threadpool.
self._sessions_lock = threading.RLock() self._sessions_lock = threading.RLock()
# Guards all mutations of self._config and the on-disk auth.json so
# concurrent create/delete/rename/privilege operations don't interleave
# and corrupt the user database.
self._config_lock = threading.Lock()
# Guards the first-run setup check-and-write so concurrent requests # Guards the first-run setup check-and-write so concurrent requests
# cannot both observe is_configured==False and both create admin accounts. # cannot both observe is_configured==False and both create admin accounts.
self._setup_lock = threading.Lock() self._setup_lock = threading.Lock()
@@ -172,8 +176,9 @@ class AuthManager:
@signup_enabled.setter @signup_enabled.setter
def signup_enabled(self, value: bool): def signup_enabled(self, value: bool):
self._config["signup_enabled"] = value with self._config_lock:
self._save() self._config["signup_enabled"] = value
self._save()
@property @property
def is_configured(self) -> bool: def is_configured(self) -> bool:
@@ -198,17 +203,18 @@ class AuthManager:
if username in RESERVED_USERNAMES: if username in RESERVED_USERNAMES:
logger.warning("Refused to create reserved username '%s'", username) logger.warning("Refused to create reserved username '%s'", username)
return False return False
if username in self.users: with self._config_lock:
return False if username in self.users:
if "users" not in self._config: return False
self._config["users"] = {} if "users" not in self._config:
self._config["users"][username] = { self._config["users"] = {}
"password_hash": _hash_password(password), self._config["users"][username] = {
"created": time.time(), "password_hash": _hash_password(password),
"is_admin": is_admin, "created": time.time(),
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES), "is_admin": is_admin,
} "privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
self._save() }
self._save()
logger.info(f"Created user '{username}' (admin={is_admin})") logger.info(f"Created user '{username}' (admin={is_admin})")
return True return True
@@ -221,14 +227,15 @@ class AuthManager:
their cookie expired naturally (default ~30 days). their cookie expired naturally (default ~30 days).
""" """
username = username.strip().lower() username = username.strip().lower()
if username not in self.users: with self._config_lock:
return False if username not in self.users:
if username == requesting_user: return False
return False if username == requesting_user:
if not self.users.get(requesting_user, {}).get("is_admin"): return False
return False if not self.users.get(requesting_user, {}).get("is_admin"):
del self._config["users"][username] return False
self._save() del self._config["users"][username]
self._save()
# Purge all sessions belonging to this user. validate_token doesn't # Purge all sessions belonging to this user. validate_token doesn't
# cross-check `self.users`, so without this step a deleted user's # cross-check `self.users`, so without this step a deleted user's
# cookie keeps authenticating. # cookie keeps authenticating.
@@ -266,14 +273,15 @@ class AuthManager:
if new_username in RESERVED_USERNAMES: if new_username in RESERVED_USERNAMES:
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username) logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
return False return False
if old_username not in self.users: with self._config_lock:
return False if old_username not in self.users:
if new_username in self.users: return False
return False if new_username in self.users:
if not self.users.get(requesting_user, {}).get("is_admin"): return False
return False if not self.users.get(requesting_user, {}).get("is_admin"):
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username) return False
self._save() self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
self._save()
renamed_sessions = 0 renamed_sessions = 0
with self._sessions_lock: with self._sessions_lock:
@@ -311,17 +319,18 @@ class AuthManager:
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool: def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
"""Update privileges for a user. Can't modify admin privileges.""" """Update privileges for a user. Can't modify admin privileges."""
username = username.strip().lower() username = username.strip().lower()
if username not in self.users: with self._config_lock:
return False if username not in self.users:
if self.users[username].get("is_admin"): return False
return False # admins always have full access if self.users[username].get("is_admin"):
# Only allow known privilege keys return False # admins always have full access
current = self.get_privileges(username) # Only allow known privilege keys
for k, v in privileges.items(): current = self.get_privileges(username)
if k in DEFAULT_PRIVILEGES: for k, v in privileges.items():
current[k] = v if k in DEFAULT_PRIVILEGES:
self._config["users"][username]["privileges"] = current current[k] = v
self._save() self._config["users"][username]["privileges"] = current
self._save()
logger.info(f"Updated privileges for '{username}': {current}") logger.info(f"Updated privileges for '{username}': {current}")
return True return True
@@ -331,8 +340,9 @@ class AuthManager:
return False return False
if not _verify_password(current_password, self.users[username]["password_hash"]): if not _verify_password(current_password, self.users[username]["password_hash"]):
return False return False
self._config["users"][username]["password_hash"] = _hash_password(new_password) with self._config_lock:
self._save() self._config["users"][username]["password_hash"] = _hash_password(new_password)
self._save()
return True return True
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -350,8 +360,9 @@ class AuthManager:
if username not in self.users: if username not in self.users:
return None return None
secret = pyotp.random_base32() secret = pyotp.random_base32()
self._config["users"][username]["totp_secret_pending"] = secret with self._config_lock:
self._save() self._config["users"][username]["totp_secret_pending"] = secret
self._save()
return secret return secret
def totp_get_provisioning_uri(self, username: str, secret: str) -> str: def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
@@ -370,13 +381,14 @@ class AuthManager:
if not totp.verify(code, valid_window=1): if not totp.verify(code, valid_window=1):
return False return False
# Enable 2FA # Enable 2FA
self._config["users"][username]["totp_secret"] = secret with self._config_lock:
self._config["users"][username]["totp_enabled"] = True self._config["users"][username]["totp_secret"] = secret
self._config["users"][username].pop("totp_secret_pending", None) self._config["users"][username]["totp_enabled"] = True
# Generate backup codes self._config["users"][username].pop("totp_secret_pending", None)
backup = [secrets.token_hex(4) for _ in range(8)] # Generate backup codes
self._config["users"][username]["totp_backup_codes"] = backup backup = [secrets.token_hex(4) for _ in range(8)]
self._save() self._config["users"][username]["totp_backup_codes"] = backup
self._save()
logger.info(f"2FA enabled for '{username}'") logger.info(f"2FA enabled for '{username}'")
return True return True
@@ -395,9 +407,10 @@ class AuthManager:
# Check backup codes first # Check backup codes first
backup = user.get("totp_backup_codes", []) backup = user.get("totp_backup_codes", [])
if code in backup: if code in backup:
backup.remove(code) with self._config_lock:
self._config["users"][username]["totp_backup_codes"] = backup backup.remove(code)
self._save() self._config["users"][username]["totp_backup_codes"] = backup
self._save()
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)") logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
return True return True
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
@@ -408,11 +421,12 @@ class AuthManager:
username = username.strip().lower() username = username.strip().lower()
if not self.verify_password(username, password): if not self.verify_password(username, password):
return False return False
self._config["users"][username].pop("totp_secret", None) with self._config_lock:
self._config["users"][username].pop("totp_secret_pending", None) self._config["users"][username].pop("totp_secret", None)
self._config["users"][username].pop("totp_backup_codes", None) self._config["users"][username].pop("totp_secret_pending", None)
self._config["users"][username]["totp_enabled"] = False self._config["users"][username].pop("totp_backup_codes", None)
self._save() self._config["users"][username]["totp_enabled"] = False
self._save()
logger.info(f"2FA disabled for '{username}'") logger.info(f"2FA disabled for '{username}'")
return True return True

View File

@@ -0,0 +1,198 @@
"""Concurrency stress tests for AuthManager._config_lock.
Verifies that concurrent create/delete/rename operations don't lose data
or corrupt auth.json. If someone removes the lock, these tests should fail
with missing users or assertion errors.
"""
import json
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
def _fresh_auth_manager(tmp_path):
sys.modules.pop("core.auth", None)
if "core" in sys.modules and hasattr(sys.modules["core"], "auth"):
delattr(sys.modules["core"], "auth")
from core.auth import AuthManager
return AuthManager(str(tmp_path / "auth.json"))
class TestConcurrentCreateUser:
"""Concurrent create_user calls must not lose accounts."""
def test_parallel_creates_no_lost_users(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
num_users = 50
def create(i):
return mgr.create_user(f"user{i}", f"password{i}")
with ThreadPoolExecutor(max_workers=10) as pool:
futures = [pool.submit(create, i) for i in range(num_users)]
results = [f.result() for f in as_completed(futures)]
assert all(results), "Some create_user calls returned False unexpectedly"
assert len(mgr.users) == num_users
mgr2 = _fresh_auth_manager(tmp_path)
mgr2.auth_path = mgr.auth_path
mgr2._load()
assert len(mgr2.users) == num_users
def test_parallel_creates_same_username_only_one_wins(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
num_attempts = 20
def create(_):
return mgr.create_user("contested", "password123")
with ThreadPoolExecutor(max_workers=10) as pool:
futures = [pool.submit(create, i) for i in range(num_attempts)]
results = [f.result() for f in as_completed(futures)]
assert results.count(True) == 1
assert results.count(False) == num_attempts - 1
assert len(mgr.users) == 1
class TestConcurrentDeleteUser:
"""Concurrent deletes must not corrupt state."""
def test_parallel_deletes_no_corruption(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
mgr.create_user("admin", "adminpw", is_admin=True)
num_users = 30
for i in range(num_users):
mgr.create_user(f"target{i}", f"pw{i}")
assert len(mgr.users) == num_users + 1
def delete(i):
return mgr.delete_user(f"target{i}", "admin")
with ThreadPoolExecutor(max_workers=10) as pool:
futures = [pool.submit(delete, i) for i in range(num_users)]
results = [f.result() for f in as_completed(futures)]
assert all(results)
assert len(mgr.users) == 1
with open(mgr.auth_path, "r") as f:
data = json.load(f)
assert len(data["users"]) == 1
assert "admin" in data["users"]
class TestConcurrentRenameUser:
"""Concurrent renames must not lose or duplicate users."""
def test_parallel_renames_no_lost_users(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
mgr.create_user("admin", "adminpw", is_admin=True)
num_users = 20
for i in range(num_users):
mgr.create_user(f"old{i}", f"pw{i}")
def rename(i):
return mgr.rename_user(f"old{i}", f"new{i}", "admin")
with ThreadPoolExecutor(max_workers=10) as pool:
futures = [pool.submit(rename, i) for i in range(num_users)]
results = [f.result() for f in as_completed(futures)]
assert all(results)
for i in range(num_users):
assert f"new{i}" in mgr.users
assert f"old{i}" not in mgr.users
assert len(mgr.users) == num_users + 1
class TestConcurrentMixedOperations:
"""Mixed create/delete/rename at the same time."""
def test_mixed_operations_no_corruption(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
mgr.create_user("admin", "adminpw", is_admin=True)
for i in range(20):
mgr.create_user(f"existing{i}", f"pw{i}")
def create_batch():
for i in range(20):
mgr.create_user(f"newuser{i}", f"pw{i}")
def delete_batch():
for i in range(10):
mgr.delete_user(f"existing{i}", "admin")
def rename_batch():
for i in range(10, 20):
mgr.rename_user(f"existing{i}", f"renamed{i}", "admin")
threads = [
threading.Thread(target=create_batch),
threading.Thread(target=delete_batch),
threading.Thread(target=rename_batch),
]
for t in threads:
t.start()
for t in threads:
t.join()
assert "admin" in mgr.users
for i in range(10):
assert f"existing{i}" not in mgr.users
for i in range(10, 20):
assert f"renamed{i}" in mgr.users
assert f"existing{i}" not in mgr.users
for i in range(20):
assert f"newuser{i}" in mgr.users
with open(mgr.auth_path, "r") as f:
data = json.load(f)
assert set(data["users"].keys()) == set(mgr.users.keys())
class TestDiskConsistency:
"""Verify auth.json is never in a corrupt state during concurrent writes."""
def test_file_always_valid_json_during_concurrent_ops(self, tmp_path):
mgr = _fresh_auth_manager(tmp_path)
mgr.create_user("admin", "adminpw", is_admin=True)
stop_event = threading.Event()
corruption_found = []
def reader():
while not stop_event.is_set():
try:
with open(mgr.auth_path, "r") as f:
content = f.read()
json.loads(content)
except json.JSONDecodeError as e:
corruption_found.append(str(e))
break
except FileNotFoundError:
pass
time.sleep(0.001)
def writer():
for i in range(50):
mgr.create_user(f"stress{i}", f"pw{i}")
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
writer_thread.join()
stop_event.set()
reader_thread.join()
assert not corruption_found, f"Corrupt JSON detected: {corruption_found[0]}"