fix: add threading lock to AuthManager config mutations (#1226)
This commit is contained in:
132
core/auth.py
132
core/auth.py
@@ -76,6 +76,10 @@ class AuthManager:
|
||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||
self._sessions_lock = threading.RLock()
|
||||
# Guards all mutations of self._config and the on-disk auth.json so
|
||||
# concurrent create/delete/rename/privilege operations don't interleave
|
||||
# and corrupt the user database.
|
||||
self._config_lock = threading.Lock()
|
||||
# Guards the first-run setup check-and-write so concurrent requests
|
||||
# cannot both observe is_configured==False and both create admin accounts.
|
||||
self._setup_lock = threading.Lock()
|
||||
@@ -172,8 +176,9 @@ class AuthManager:
|
||||
|
||||
@signup_enabled.setter
|
||||
def signup_enabled(self, value: bool):
|
||||
self._config["signup_enabled"] = value
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["signup_enabled"] = value
|
||||
self._save()
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
@@ -198,17 +203,18 @@ class AuthManager:
|
||||
if username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to create reserved username '%s'", username)
|
||||
return False
|
||||
if username in self.users:
|
||||
return False
|
||||
if "users" not in self._config:
|
||||
self._config["users"] = {}
|
||||
self._config["users"][username] = {
|
||||
"password_hash": _hash_password(password),
|
||||
"created": time.time(),
|
||||
"is_admin": is_admin,
|
||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||
}
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username in self.users:
|
||||
return False
|
||||
if "users" not in self._config:
|
||||
self._config["users"] = {}
|
||||
self._config["users"][username] = {
|
||||
"password_hash": _hash_password(password),
|
||||
"created": time.time(),
|
||||
"is_admin": is_admin,
|
||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||
}
|
||||
self._save()
|
||||
logger.info(f"Created user '{username}' (admin={is_admin})")
|
||||
return True
|
||||
|
||||
@@ -221,14 +227,15 @@ class AuthManager:
|
||||
their cookie expired naturally (default ~30 days).
|
||||
"""
|
||||
username = username.strip().lower()
|
||||
if username not in self.users:
|
||||
return False
|
||||
if username == requesting_user:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
del self._config["users"][username]
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if username == requesting_user:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
del self._config["users"][username]
|
||||
self._save()
|
||||
# Purge all sessions belonging to this user. validate_token doesn't
|
||||
# cross-check `self.users`, so without this step a deleted user's
|
||||
# cookie keeps authenticating.
|
||||
@@ -266,14 +273,15 @@ class AuthManager:
|
||||
if new_username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||
return False
|
||||
if old_username not in self.users:
|
||||
return False
|
||||
if new_username in self.users:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if old_username not in self.users:
|
||||
return False
|
||||
if new_username in self.users:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||
self._save()
|
||||
|
||||
renamed_sessions = 0
|
||||
with self._sessions_lock:
|
||||
@@ -311,17 +319,18 @@ class AuthManager:
|
||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
||||
"""Update privileges for a user. Can't modify admin privileges."""
|
||||
username = username.strip().lower()
|
||||
if username not in self.users:
|
||||
return False
|
||||
if self.users[username].get("is_admin"):
|
||||
return False # admins always have full access
|
||||
# Only allow known privilege keys
|
||||
current = self.get_privileges(username)
|
||||
for k, v in privileges.items():
|
||||
if k in DEFAULT_PRIVILEGES:
|
||||
current[k] = v
|
||||
self._config["users"][username]["privileges"] = current
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if self.users[username].get("is_admin"):
|
||||
return False # admins always have full access
|
||||
# Only allow known privilege keys
|
||||
current = self.get_privileges(username)
|
||||
for k, v in privileges.items():
|
||||
if k in DEFAULT_PRIVILEGES:
|
||||
current[k] = v
|
||||
self._config["users"][username]["privileges"] = current
|
||||
self._save()
|
||||
logger.info(f"Updated privileges for '{username}': {current}")
|
||||
return True
|
||||
|
||||
@@ -331,8 +340,9 @@ class AuthManager:
|
||||
return False
|
||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||
return False
|
||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||
self._save()
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -350,8 +360,9 @@ class AuthManager:
|
||||
if username not in self.users:
|
||||
return None
|
||||
secret = pyotp.random_base32()
|
||||
self._config["users"][username]["totp_secret_pending"] = secret
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret_pending"] = secret
|
||||
self._save()
|
||||
return secret
|
||||
|
||||
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
||||
@@ -370,13 +381,14 @@ class AuthManager:
|
||||
if not totp.verify(code, valid_window=1):
|
||||
return False
|
||||
# Enable 2FA
|
||||
self._config["users"][username]["totp_secret"] = secret
|
||||
self._config["users"][username]["totp_enabled"] = True
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
# Generate backup codes
|
||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret"] = secret
|
||||
self._config["users"][username]["totp_enabled"] = True
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
# Generate backup codes
|
||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
logger.info(f"2FA enabled for '{username}'")
|
||||
return True
|
||||
|
||||
@@ -395,9 +407,10 @@ class AuthManager:
|
||||
# Check backup codes first
|
||||
backup = user.get("totp_backup_codes", [])
|
||||
if code in backup:
|
||||
backup.remove(code)
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
backup.remove(code)
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
||||
return True
|
||||
totp = pyotp.TOTP(secret)
|
||||
@@ -408,11 +421,12 @@ class AuthManager:
|
||||
username = username.strip().lower()
|
||||
if not self.verify_password(username, password):
|
||||
return False
|
||||
self._config["users"][username].pop("totp_secret", None)
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
self._config["users"][username].pop("totp_backup_codes", None)
|
||||
self._config["users"][username]["totp_enabled"] = False
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username].pop("totp_secret", None)
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
self._config["users"][username].pop("totp_backup_codes", None)
|
||||
self._config["users"][username]["totp_enabled"] = False
|
||||
self._save()
|
||||
logger.info(f"2FA disabled for '{username}'")
|
||||
return True
|
||||
|
||||
|
||||
198
tests/test_auth_config_lock_concurrency.py
Normal file
198
tests/test_auth_config_lock_concurrency.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Concurrency stress tests for AuthManager._config_lock.
|
||||
|
||||
Verifies that concurrent create/delete/rename operations don't lose data
|
||||
or corrupt auth.json. If someone removes the lock, these tests should fail
|
||||
with missing users or assertion errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _fresh_auth_manager(tmp_path):
|
||||
sys.modules.pop("core.auth", None)
|
||||
if "core" in sys.modules and hasattr(sys.modules["core"], "auth"):
|
||||
delattr(sys.modules["core"], "auth")
|
||||
from core.auth import AuthManager
|
||||
|
||||
return AuthManager(str(tmp_path / "auth.json"))
|
||||
|
||||
|
||||
class TestConcurrentCreateUser:
|
||||
"""Concurrent create_user calls must not lose accounts."""
|
||||
|
||||
def test_parallel_creates_no_lost_users(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
num_users = 50
|
||||
|
||||
def create(i):
|
||||
return mgr.create_user(f"user{i}", f"password{i}")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(create, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results), "Some create_user calls returned False unexpectedly"
|
||||
assert len(mgr.users) == num_users
|
||||
|
||||
mgr2 = _fresh_auth_manager(tmp_path)
|
||||
mgr2.auth_path = mgr.auth_path
|
||||
mgr2._load()
|
||||
assert len(mgr2.users) == num_users
|
||||
|
||||
def test_parallel_creates_same_username_only_one_wins(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
num_attempts = 20
|
||||
|
||||
def create(_):
|
||||
return mgr.create_user("contested", "password123")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(create, i) for i in range(num_attempts)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert results.count(True) == 1
|
||||
assert results.count(False) == num_attempts - 1
|
||||
assert len(mgr.users) == 1
|
||||
|
||||
|
||||
class TestConcurrentDeleteUser:
|
||||
"""Concurrent deletes must not corrupt state."""
|
||||
|
||||
def test_parallel_deletes_no_corruption(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
num_users = 30
|
||||
for i in range(num_users):
|
||||
mgr.create_user(f"target{i}", f"pw{i}")
|
||||
|
||||
assert len(mgr.users) == num_users + 1
|
||||
|
||||
def delete(i):
|
||||
return mgr.delete_user(f"target{i}", "admin")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(delete, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results)
|
||||
assert len(mgr.users) == 1
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
data = json.load(f)
|
||||
assert len(data["users"]) == 1
|
||||
assert "admin" in data["users"]
|
||||
|
||||
|
||||
class TestConcurrentRenameUser:
|
||||
"""Concurrent renames must not lose or duplicate users."""
|
||||
|
||||
def test_parallel_renames_no_lost_users(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
num_users = 20
|
||||
for i in range(num_users):
|
||||
mgr.create_user(f"old{i}", f"pw{i}")
|
||||
|
||||
def rename(i):
|
||||
return mgr.rename_user(f"old{i}", f"new{i}", "admin")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(rename, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results)
|
||||
for i in range(num_users):
|
||||
assert f"new{i}" in mgr.users
|
||||
assert f"old{i}" not in mgr.users
|
||||
|
||||
assert len(mgr.users) == num_users + 1
|
||||
|
||||
|
||||
class TestConcurrentMixedOperations:
|
||||
"""Mixed create/delete/rename at the same time."""
|
||||
|
||||
def test_mixed_operations_no_corruption(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
|
||||
for i in range(20):
|
||||
mgr.create_user(f"existing{i}", f"pw{i}")
|
||||
|
||||
def create_batch():
|
||||
for i in range(20):
|
||||
mgr.create_user(f"newuser{i}", f"pw{i}")
|
||||
|
||||
def delete_batch():
|
||||
for i in range(10):
|
||||
mgr.delete_user(f"existing{i}", "admin")
|
||||
|
||||
def rename_batch():
|
||||
for i in range(10, 20):
|
||||
mgr.rename_user(f"existing{i}", f"renamed{i}", "admin")
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=create_batch),
|
||||
threading.Thread(target=delete_batch),
|
||||
threading.Thread(target=rename_batch),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert "admin" in mgr.users
|
||||
for i in range(10):
|
||||
assert f"existing{i}" not in mgr.users
|
||||
for i in range(10, 20):
|
||||
assert f"renamed{i}" in mgr.users
|
||||
assert f"existing{i}" not in mgr.users
|
||||
for i in range(20):
|
||||
assert f"newuser{i}" in mgr.users
|
||||
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
data = json.load(f)
|
||||
assert set(data["users"].keys()) == set(mgr.users.keys())
|
||||
|
||||
|
||||
class TestDiskConsistency:
|
||||
"""Verify auth.json is never in a corrupt state during concurrent writes."""
|
||||
|
||||
def test_file_always_valid_json_during_concurrent_ops(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
|
||||
stop_event = threading.Event()
|
||||
corruption_found = []
|
||||
|
||||
def reader():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
content = f.read()
|
||||
json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
corruption_found.append(str(e))
|
||||
break
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
time.sleep(0.001)
|
||||
|
||||
def writer():
|
||||
for i in range(50):
|
||||
mgr.create_user(f"stress{i}", f"pw{i}")
|
||||
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
|
||||
reader_thread.start()
|
||||
writer_thread.start()
|
||||
writer_thread.join()
|
||||
stop_event.set()
|
||||
reader_thread.join()
|
||||
|
||||
assert not corruption_found, f"Corrupt JSON detected: {corruption_found[0]}"
|
||||
Reference in New Issue
Block a user