diff --git a/core/auth.py b/core/auth.py index d4f5d36..3c7669d 100644 --- a/core/auth.py +++ b/core/auth.py @@ -76,6 +76,10 @@ class AuthManager: # Guards mutations of self._sessions and the on-disk sessions.json. # Validate/create/revoke run concurrently from the FastAPI threadpool. self._sessions_lock = threading.RLock() + # Guards all mutations of self._config and the on-disk auth.json so + # concurrent create/delete/rename/privilege operations don't interleave + # and corrupt the user database. + self._config_lock = threading.Lock() # Guards the first-run setup check-and-write so concurrent requests # cannot both observe is_configured==False and both create admin accounts. self._setup_lock = threading.Lock() @@ -172,8 +176,9 @@ class AuthManager: @signup_enabled.setter def signup_enabled(self, value: bool): - self._config["signup_enabled"] = value - self._save() + with self._config_lock: + self._config["signup_enabled"] = value + self._save() @property def is_configured(self) -> bool: @@ -198,17 +203,18 @@ class AuthManager: if username in RESERVED_USERNAMES: logger.warning("Refused to create reserved username '%s'", username) return False - if username in self.users: - return False - if "users" not in self._config: - self._config["users"] = {} - self._config["users"][username] = { - "password_hash": _hash_password(password), - "created": time.time(), - "is_admin": is_admin, - "privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES), - } - self._save() + with self._config_lock: + if username in self.users: + return False + if "users" not in self._config: + self._config["users"] = {} + self._config["users"][username] = { + "password_hash": _hash_password(password), + "created": time.time(), + "is_admin": is_admin, + "privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES), + } + self._save() logger.info(f"Created user '{username}' (admin={is_admin})") return True @@ -221,14 +227,15 @@ class AuthManager: their cookie expired naturally (default ~30 days). """ username = username.strip().lower() - if username not in self.users: - return False - if username == requesting_user: - return False - if not self.users.get(requesting_user, {}).get("is_admin"): - return False - del self._config["users"][username] - self._save() + with self._config_lock: + if username not in self.users: + return False + if username == requesting_user: + return False + if not self.users.get(requesting_user, {}).get("is_admin"): + return False + del self._config["users"][username] + self._save() # Purge all sessions belonging to this user. validate_token doesn't # cross-check `self.users`, so without this step a deleted user's # cookie keeps authenticating. @@ -266,14 +273,15 @@ class AuthManager: if new_username in RESERVED_USERNAMES: logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username) return False - if old_username not in self.users: - return False - if new_username in self.users: - return False - if not self.users.get(requesting_user, {}).get("is_admin"): - return False - self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username) - self._save() + with self._config_lock: + if old_username not in self.users: + return False + if new_username in self.users: + return False + if not self.users.get(requesting_user, {}).get("is_admin"): + return False + self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username) + self._save() renamed_sessions = 0 with self._sessions_lock: @@ -311,17 +319,18 @@ class AuthManager: def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool: """Update privileges for a user. Can't modify admin privileges.""" username = username.strip().lower() - if username not in self.users: - return False - if self.users[username].get("is_admin"): - return False # admins always have full access - # Only allow known privilege keys - current = self.get_privileges(username) - for k, v in privileges.items(): - if k in DEFAULT_PRIVILEGES: - current[k] = v - self._config["users"][username]["privileges"] = current - self._save() + with self._config_lock: + if username not in self.users: + return False + if self.users[username].get("is_admin"): + return False # admins always have full access + # Only allow known privilege keys + current = self.get_privileges(username) + for k, v in privileges.items(): + if k in DEFAULT_PRIVILEGES: + current[k] = v + self._config["users"][username]["privileges"] = current + self._save() logger.info(f"Updated privileges for '{username}': {current}") return True @@ -331,8 +340,9 @@ class AuthManager: return False if not _verify_password(current_password, self.users[username]["password_hash"]): return False - self._config["users"][username]["password_hash"] = _hash_password(new_password) - self._save() + with self._config_lock: + self._config["users"][username]["password_hash"] = _hash_password(new_password) + self._save() return True # ------------------------------------------------------------------ @@ -350,8 +360,9 @@ class AuthManager: if username not in self.users: return None secret = pyotp.random_base32() - self._config["users"][username]["totp_secret_pending"] = secret - self._save() + with self._config_lock: + self._config["users"][username]["totp_secret_pending"] = secret + self._save() return secret def totp_get_provisioning_uri(self, username: str, secret: str) -> str: @@ -370,13 +381,14 @@ class AuthManager: if not totp.verify(code, valid_window=1): return False # Enable 2FA - self._config["users"][username]["totp_secret"] = secret - self._config["users"][username]["totp_enabled"] = True - self._config["users"][username].pop("totp_secret_pending", None) - # Generate backup codes - backup = [secrets.token_hex(4) for _ in range(8)] - self._config["users"][username]["totp_backup_codes"] = backup - self._save() + with self._config_lock: + self._config["users"][username]["totp_secret"] = secret + self._config["users"][username]["totp_enabled"] = True + self._config["users"][username].pop("totp_secret_pending", None) + # Generate backup codes + backup = [secrets.token_hex(4) for _ in range(8)] + self._config["users"][username]["totp_backup_codes"] = backup + self._save() logger.info(f"2FA enabled for '{username}'") return True @@ -395,9 +407,10 @@ class AuthManager: # Check backup codes first backup = user.get("totp_backup_codes", []) if code in backup: - backup.remove(code) - self._config["users"][username]["totp_backup_codes"] = backup - self._save() + with self._config_lock: + backup.remove(code) + self._config["users"][username]["totp_backup_codes"] = backup + self._save() logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)") return True totp = pyotp.TOTP(secret) @@ -408,11 +421,12 @@ class AuthManager: username = username.strip().lower() if not self.verify_password(username, password): return False - self._config["users"][username].pop("totp_secret", None) - self._config["users"][username].pop("totp_secret_pending", None) - self._config["users"][username].pop("totp_backup_codes", None) - self._config["users"][username]["totp_enabled"] = False - self._save() + with self._config_lock: + self._config["users"][username].pop("totp_secret", None) + self._config["users"][username].pop("totp_secret_pending", None) + self._config["users"][username].pop("totp_backup_codes", None) + self._config["users"][username]["totp_enabled"] = False + self._save() logger.info(f"2FA disabled for '{username}'") return True diff --git a/tests/test_auth_config_lock_concurrency.py b/tests/test_auth_config_lock_concurrency.py new file mode 100644 index 0000000..39e196a --- /dev/null +++ b/tests/test_auth_config_lock_concurrency.py @@ -0,0 +1,198 @@ +"""Concurrency stress tests for AuthManager._config_lock. + +Verifies that concurrent create/delete/rename operations don't lose data +or corrupt auth.json. If someone removes the lock, these tests should fail +with missing users or assertion errors. +""" + +import json +import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + + +def _fresh_auth_manager(tmp_path): + sys.modules.pop("core.auth", None) + if "core" in sys.modules and hasattr(sys.modules["core"], "auth"): + delattr(sys.modules["core"], "auth") + from core.auth import AuthManager + + return AuthManager(str(tmp_path / "auth.json")) + + +class TestConcurrentCreateUser: + """Concurrent create_user calls must not lose accounts.""" + + def test_parallel_creates_no_lost_users(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + num_users = 50 + + def create(i): + return mgr.create_user(f"user{i}", f"password{i}") + + with ThreadPoolExecutor(max_workers=10) as pool: + futures = [pool.submit(create, i) for i in range(num_users)] + results = [f.result() for f in as_completed(futures)] + + assert all(results), "Some create_user calls returned False unexpectedly" + assert len(mgr.users) == num_users + + mgr2 = _fresh_auth_manager(tmp_path) + mgr2.auth_path = mgr.auth_path + mgr2._load() + assert len(mgr2.users) == num_users + + def test_parallel_creates_same_username_only_one_wins(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + num_attempts = 20 + + def create(_): + return mgr.create_user("contested", "password123") + + with ThreadPoolExecutor(max_workers=10) as pool: + futures = [pool.submit(create, i) for i in range(num_attempts)] + results = [f.result() for f in as_completed(futures)] + + assert results.count(True) == 1 + assert results.count(False) == num_attempts - 1 + assert len(mgr.users) == 1 + + +class TestConcurrentDeleteUser: + """Concurrent deletes must not corrupt state.""" + + def test_parallel_deletes_no_corruption(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + mgr.create_user("admin", "adminpw", is_admin=True) + num_users = 30 + for i in range(num_users): + mgr.create_user(f"target{i}", f"pw{i}") + + assert len(mgr.users) == num_users + 1 + + def delete(i): + return mgr.delete_user(f"target{i}", "admin") + + with ThreadPoolExecutor(max_workers=10) as pool: + futures = [pool.submit(delete, i) for i in range(num_users)] + results = [f.result() for f in as_completed(futures)] + + assert all(results) + assert len(mgr.users) == 1 + with open(mgr.auth_path, "r") as f: + data = json.load(f) + assert len(data["users"]) == 1 + assert "admin" in data["users"] + + +class TestConcurrentRenameUser: + """Concurrent renames must not lose or duplicate users.""" + + def test_parallel_renames_no_lost_users(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + mgr.create_user("admin", "adminpw", is_admin=True) + num_users = 20 + for i in range(num_users): + mgr.create_user(f"old{i}", f"pw{i}") + + def rename(i): + return mgr.rename_user(f"old{i}", f"new{i}", "admin") + + with ThreadPoolExecutor(max_workers=10) as pool: + futures = [pool.submit(rename, i) for i in range(num_users)] + results = [f.result() for f in as_completed(futures)] + + assert all(results) + for i in range(num_users): + assert f"new{i}" in mgr.users + assert f"old{i}" not in mgr.users + + assert len(mgr.users) == num_users + 1 + + +class TestConcurrentMixedOperations: + """Mixed create/delete/rename at the same time.""" + + def test_mixed_operations_no_corruption(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + mgr.create_user("admin", "adminpw", is_admin=True) + + for i in range(20): + mgr.create_user(f"existing{i}", f"pw{i}") + + def create_batch(): + for i in range(20): + mgr.create_user(f"newuser{i}", f"pw{i}") + + def delete_batch(): + for i in range(10): + mgr.delete_user(f"existing{i}", "admin") + + def rename_batch(): + for i in range(10, 20): + mgr.rename_user(f"existing{i}", f"renamed{i}", "admin") + + threads = [ + threading.Thread(target=create_batch), + threading.Thread(target=delete_batch), + threading.Thread(target=rename_batch), + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert "admin" in mgr.users + for i in range(10): + assert f"existing{i}" not in mgr.users + for i in range(10, 20): + assert f"renamed{i}" in mgr.users + assert f"existing{i}" not in mgr.users + for i in range(20): + assert f"newuser{i}" in mgr.users + + with open(mgr.auth_path, "r") as f: + data = json.load(f) + assert set(data["users"].keys()) == set(mgr.users.keys()) + + +class TestDiskConsistency: + """Verify auth.json is never in a corrupt state during concurrent writes.""" + + def test_file_always_valid_json_during_concurrent_ops(self, tmp_path): + mgr = _fresh_auth_manager(tmp_path) + mgr.create_user("admin", "adminpw", is_admin=True) + + stop_event = threading.Event() + corruption_found = [] + + def reader(): + while not stop_event.is_set(): + try: + with open(mgr.auth_path, "r") as f: + content = f.read() + json.loads(content) + except json.JSONDecodeError as e: + corruption_found.append(str(e)) + break + except FileNotFoundError: + pass + time.sleep(0.001) + + def writer(): + for i in range(50): + mgr.create_user(f"stress{i}", f"pw{i}") + + reader_thread = threading.Thread(target=reader) + writer_thread = threading.Thread(target=writer) + + reader_thread.start() + writer_thread.start() + writer_thread.join() + stop_event.set() + reader_thread.join() + + assert not corruption_found, f"Corrupt JSON detected: {corruption_found[0]}"