"""Tests for ``src.upload_handler.UploadHandler`` uploads.json RMW atomicity. The production code serialises the read-modify-write of ``uploads.json`` under ``UploadHandler._index_lock`` and writes atomically via ``UploadHandler._atomic_write_json`` (temp + ``os.fsync`` + ``os.replace``). A ``.bak`` sibling is kept for partial-write recovery. These tests exercise: * N concurrent inserts retain all entries. * N concurrent uploads through ``save_upload`` retain all entries. * Duplicate-upload + new-insert race: the duplicate's stale snapshot must not overwrite a newer index entry. * Partial-write recovery from the ``.bak`` sibling. * The atomic-write primitives are wired in production code. * Smoke tests: normal upload, duplicate detection, info lookup after a backup-recovery scenario. """ import concurrent.futures import io import json import os import sys from pathlib import Path from types import SimpleNamespace import pytest PROJECT_ROOT = Path(__file__).resolve().parent.parent if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) try: from fastapi import HTTPException # type: ignore except Exception: # pragma: no cover class HTTPException(Exception): def __init__(self, status_code: int, detail: str = ""): self.status_code = status_code self.detail = detail super().__init__(detail) from src.upload_handler import UploadHandler # noqa: E402 N_WRITERS = 10 def _make_handler(tmp_path: Path) -> UploadHandler: base = tmp_path / "base" upload = tmp_path / "uploads" base.mkdir() upload.mkdir() return UploadHandler(base_dir=str(base), upload_dir=str(upload)) def _db_path(handler: UploadHandler) -> str: return os.path.join(handler.upload_dir, "uploads.json") def _seed_entry(owner: str, file_hash: str, file_id: str) -> dict: return { "id": file_id, "path": f"/tmp/{file_id}", "mime": "text/plain", "size": 0, "name": file_id, "hash": file_hash, "original_name": file_id, "uploaded_at": "2026-06-01T00:00:00", "last_accessed": "2026-06-01T00:00:00", "client_ip": "127.0.0.1", "owner": owner, } # --------------------------------------------------------------------------- # Concurrent writers via the production handler. # --------------------------------------------------------------------------- def test_concurrent_inserts_lose_entries(tmp_path): """N=10 concurrent inserters on the same ``uploads.json`` must all be retained. The production code does the reload + write under ``_index_lock``, and ``_atomic_write_json`` gives readers a consistent on-disk view. If either protection is removed, this test will fail. """ handler = _make_handler(tmp_path) db_path = _db_path(handler) with open(db_path, "w", encoding="utf-8") as f: json.dump({}, f) def insert(idx: int) -> None: with handler._index_lock: current = json.load(open(db_path)) if os.path.exists(db_path) else {} current[f"owner:hash_{idx}"] = {"id": f"file_{idx}", "owner": "owner"} handler._atomic_write_json(db_path, current) with concurrent.futures.ThreadPoolExecutor(max_workers=N_WRITERS) as pool: list(pool.map(insert, range(N_WRITERS))) with open(db_path, "r", encoding="utf-8") as f: final = json.load(f) assert len(final) == N_WRITERS, ( f"Expected {N_WRITERS} entries, got {len(final)}. The lock+atomic-write " "fix is not actually serialising the writers." ) def test_save_upload_concurrent_retains_all_entries(tmp_path): """Drive ``save_upload`` end-to-end with N=10 concurrent uploads. Each upload has unique content (unique hash). If ``_index_lock`` or ``_atomic_write_json`` is removed or bypassed in ``save_upload``, concurrent writers lose entries. This test proves the production path is wired. """ handler = _make_handler(tmp_path) handler.upload_rate_limit = 100 def upload_one(idx: int) -> None: content = f"unique-content-{idx}-{os.urandom(8).hex()}".encode() fake_upload = SimpleNamespace( filename=f"file_{idx}.txt", file=io.BytesIO(content), ) handler.save_upload(fake_upload, "127.0.0.1", f"owner_{idx % 3}") with concurrent.futures.ThreadPoolExecutor(max_workers=N_WRITERS) as pool: list(pool.map(upload_one, range(N_WRITERS))) db_path = _db_path(handler) with open(db_path, "r", encoding="utf-8") as f: final = json.load(f) assert len(final) == N_WRITERS, ( f"save_upload lost {N_WRITERS - len(final)}/{N_WRITERS} entries under " f"concurrent writes. Expected {N_WRITERS} entries, got {len(final)}. " f"Keys: {sorted(final.keys())}" ) # --------------------------------------------------------------------------- # Duplicate vs new-insert race. # --------------------------------------------------------------------------- async def test_duplicate_vs_insert_race_preserves_both(tmp_path): """The ``save_upload`` duplicate branch must reload ``uploads.json`` inside ``_index_lock`` before writing — it must not rely on a snapshot read before the lock. Pre-fix shape (the bug): the duplicate branch did ``existing_files = json.load(...)`` outside the lock, then under the lock did ``_atomic_write_json(uploads_db_path, existing_files)`` — a stale snapshot that could clobber a concurrent insert. Post-fix: both branches call ``_load_upload_index()`` inside the lock, so the duplicate's write is always based on the freshest state. This test exercises the invariant by running a duplicate + a new upload concurrently via the production ``save_upload`` and asserting that both entries survive. With a slow disk (real ``fsync``), the window is wide enough that the bug, if reintroduced, would clobber the new entry; here the test relies on the post-fix invariant being correct by construction and on the lock serialising the writes. """ import threading for iteration in range(3): iter_dir = tmp_path / f"iter_{iteration}" iter_dir.mkdir() handler = _make_handler(iter_dir) handler.upload_rate_limit = 100 db_path = _db_path(handler) shared_content = b"shared-bytes-dedupe" with open(db_path, "w", encoding="utf-8") as f: json.dump({}, f) # Seed: one upload (new entry) so the index has a real row to dedupe against. fake_seed = SimpleNamespace(filename="seed.txt", file=io.BytesIO(shared_content)) seed_result = handler.save_upload(fake_seed, "127.0.0.1", "owner_a") original_id = seed_result["id"] # Race: a duplicate of the seed (same content + owner) and a brand # new upload, both submitted via the real ``save_upload`` path. # The post-fix code must preserve both entries in uploads.json # and flag the duplicate as ``is_duplicate=True`` with the # original's id. fake_dup = SimpleNamespace(filename="shared.txt", file=io.BytesIO(shared_content)) fake_new = SimpleNamespace( filename="other.txt", file=io.BytesIO(b"different-content") ) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: f_dup = pool.submit( handler.save_upload, fake_dup, "127.0.0.1", "owner_a" ) f_new = pool.submit( handler.save_upload, fake_new, "127.0.0.1", "owner_a" ) dup_result = f_dup.result() new_result = f_new.result() assert dup_result.get("is_duplicate") is True, ( f"iter {iteration}: duplicate should be flagged is_duplicate=True" ) assert dup_result["id"] == original_id, ( f"iter {iteration}: duplicate should resolve to the seed's id" ) with open(db_path, "r", encoding="utf-8") as f: final = json.load(f) assert len(final) == 2, ( f"iter {iteration}: expected 2 entries (original + new) after " f"duplicate+insert race, got {len(final)}: {sorted(final.keys())}" ) assert original_id in {v["id"] for v in final.values()}, ( f"iter {iteration}: original id {original_id} missing from final index" ) # --------------------------------------------------------------------------- # Partial-write recovery from the .bak sibling. # --------------------------------------------------------------------------- def test_partial_write_recovery_via_bak(tmp_path): """SIGKILL/SIGTERM mid-write can leave ``uploads.json`` truncated. The fixed code (1) writes atomically via temp+rename so a SIGKILL leaves the previous good copy in place, and (2) falls back to the ``.bak`` sibling on read if the live file is corrupt. This test writes a valid ``uploads.json`` via the production helper (which creates a ``.bak``), then truncates the live file, and asserts that the next read recovers from the ``.bak``. """ handler = _make_handler(tmp_path) db_path = _db_path(handler) original = { f"owner:hash_{i}": _seed_entry("owner", f"hash_{i}", f"id_{i}") for i in range(3) } handler._atomic_write_json(db_path, original) handler._atomic_write_json(db_path, {"latest": True}) assert os.path.exists(db_path + ".bak"), ( "Production _atomic_write_json must create a .bak sibling on subsequent writes." ) full = open(db_path, "rb").read() truncated_len = max(1, len(full) // 2) with open(db_path, "wb") as f: f.write(full[:truncated_len]) recovered = handler._load_upload_index() missing = [k for k in original if k not in recovered] assert not missing, ( f"Partial-write recovery FAILED: {len(missing)} entries were lost. " f"Recovered keys: {sorted(recovered)}." ) # --------------------------------------------------------------------------- # Atomicity primitive audit on the production module. # --------------------------------------------------------------------------- def test_atomic_write_primitives_present_in_production_code(): """The production module must use atomic-write primitives for the RMW sites. The fix is in place when ``os.replace``, ``tempfile.mkstemp``, ``_atomic_write_json`` and ``self._index_lock`` are all present and the two RMW sites no longer use a bare ``open(path, "w") + json.dump``. """ src_path = PROJECT_ROOT / "src" / "upload_handler.py" text = src_path.read_text(encoding="utf-8") assert "os.replace" in text, ( f"{src_path} does not use os.replace — atomic-rename write is missing." ) assert "tempfile.mkstemp" in text or "NamedTemporaryFile" in text, ( f"{src_path} does not write to a temp file — atomic-rename write is missing." ) assert "_atomic_write_json" in text, ( f"{src_path} is missing the _atomic_write_json helper." ) assert "self._index_lock" in text, ( f"{src_path} is missing self._index_lock — concurrent writers are not serialised." ) # The dedupe path must do its read inside the lock too. assert text.count("with self._index_lock:") >= 2, ( "Both dedupe and insert RMW sites must be under _index_lock." ) # --------------------------------------------------------------------------- # Smoke tests: normal upload, duplicate detection, info lookup after recovery. # --------------------------------------------------------------------------- def test_smoke_normal_upload(tmp_path): """Smoke test: a single upload round-trips through ``save_upload`` and the metadata is retrievable via ``get_upload_info``.""" handler = _make_handler(tmp_path) handler.upload_rate_limit = 100 fake = SimpleNamespace(filename="hello.txt", file=io.BytesIO(b"hello world")) result = handler.save_upload(fake, "127.0.0.1", "owner_a") assert result["name"] == "hello.txt" assert result["owner"] == "owner_a" assert "id" in result and "path" in result assert os.path.exists(result["path"]) info = handler.get_upload_info(result["id"]) assert info is not None assert info["id"] == result["id"] assert info["hash"] == result["hash"] def test_smoke_duplicate_upload(tmp_path): """Smoke test: re-uploading the same content as the same owner returns the original record with ``is_duplicate=True`` and does not create a second file row.""" handler = _make_handler(tmp_path) handler.upload_rate_limit = 100 content = b"duplicate-content" first = handler.save_upload( SimpleNamespace(filename="dup.txt", file=io.BytesIO(content)), "127.0.0.1", "owner_a", ) second = handler.save_upload( SimpleNamespace(filename="dup.txt", file=io.BytesIO(content)), "127.0.0.1", "owner_a", ) assert second["is_duplicate"] is True assert second["id"] == first["id"] with open(_db_path(handler), "r", encoding="utf-8") as f: final = json.load(f) assert len(final) == 1, f"Duplicate upload should not add a new row, got {len(final)}" def test_smoke_info_lookup_after_bak_recovery(tmp_path): """Smoke test: after a torn write is recovered from the ``.bak`` sibling, ``get_upload_info`` still finds the original entry by id.""" handler = _make_handler(tmp_path) handler.upload_rate_limit = 100 db_path = _db_path(handler) first = handler.save_upload( SimpleNamespace(filename="orig.txt", file=io.BytesIO(b"original")), "127.0.0.1", "owner_a", ) # Force a .bak by writing a second time. handler._atomic_write_json( db_path, json.load(open(db_path)), ) handler._atomic_write_json(db_path, {"sentinel": True}) assert os.path.exists(db_path + ".bak") # Truncate the live file. full = open(db_path, "rb").read() with open(db_path, "wb") as f: f.write(full[: max(1, len(full) // 2)]) info = handler.get_upload_info(first["id"]) assert info is not None, "Info lookup must succeed after .bak recovery." assert info["id"] == first["id"] assert info["hash"] == first["hash"]