diff --git a/routes/embedding_routes.py b/routes/embedding_routes.py index dbe075a..a5ef4c0 100644 --- a/routes/embedding_routes.py +++ b/routes/embedding_routes.py @@ -49,19 +49,35 @@ def _model_cache_name(hf_source: str) -> str: return "models--" + hf_source.replace("/", "--") +def _model_cache_path(hf_source: str) -> Path: + """Return a confined cache path for a fastembed HF source.""" + root = Path(_cache_dir()).expanduser().resolve() + raw_path = root / _model_cache_name(hf_source) + if raw_path.is_symlink(): + raise ValueError("Model cache path must not be a symlink") + path = raw_path.resolve(strict=False) + try: + path.relative_to(root) + except ValueError: + raise ValueError("Model cache path escapes cache root") + return path + + def _is_downloaded(hf_source: str) -> bool: """Check if a model is already cached.""" - cache = _cache_dir() - model_dir = os.path.join(cache, _model_cache_name(hf_source)) - if not os.path.isdir(model_dir): + try: + model_dir = _model_cache_path(hf_source) + except ValueError: + return False + if not model_dir.is_dir(): return False # Check for actual model files (not just empty dir) - snapshots = os.path.join(model_dir, "snapshots") - if os.path.isdir(snapshots): - return any(os.listdir(snapshots)) + snapshots = model_dir / "snapshots" + if snapshots.is_dir(): + return any(snapshots.iterdir()) # Also check for blobs (older cache format) - blobs = os.path.join(model_dir, "blobs") - return os.path.isdir(blobs) and any(os.listdir(blobs)) + blobs = model_dir / "blobs" + return blobs.is_dir() and any(blobs.iterdir()) def _active_model() -> str: @@ -119,8 +135,10 @@ def setup_embedding_routes(): cached_size = None if downloaded and hf_src: - model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src)) - cached_size = _dir_size_mb(model_path) + try: + cached_size = _dir_size_mb(str(_model_cache_path(hf_src))) + except ValueError: + cached_size = None result.append({ "model": m["model"], @@ -217,8 +235,11 @@ def setup_embedding_routes(): if not hf_src: raise HTTPException(400, "No cache source for this model") - model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src)) - if not os.path.isdir(model_path): + try: + model_path = _model_cache_path(hf_src) + except ValueError as e: + raise HTTPException(400, str(e)) + if not model_path.is_dir(): return {"deleted": False, "message": "Model not cached"} shutil.rmtree(model_path) diff --git a/tests/test_embedding_cache_confinement.py b/tests/test_embedding_cache_confinement.py new file mode 100644 index 0000000..0cf93d4 --- /dev/null +++ b/tests/test_embedding_cache_confinement.py @@ -0,0 +1,75 @@ +import sys +import types + +import pytest +from fastapi import HTTPException + +import routes.embedding_routes as embedding_routes + + +def _install_fastembed_stub(monkeypatch): + fastembed = types.ModuleType("fastembed") + + class TextEmbedding: + @staticmethod + def list_supported_models(): + return [{"model": "test-model", "sources": {"hf": "org/test-model"}}] + + fastembed.TextEmbedding = TextEmbedding + monkeypatch.setitem(sys.modules, "fastembed", fastembed) + + +def _route_endpoint(path: str, method: str): + router = embedding_routes.setup_embedding_routes() + for route in router.routes: + if route.path == path and method in route.methods: + return route.endpoint + raise AssertionError(f"route not found: {method} {path}") + + +def test_model_cache_path_resolves_under_cache_root(tmp_path, monkeypatch): + monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(tmp_path / "cache")) + + path = embedding_routes._model_cache_path("org/test-model") + + assert path == (tmp_path / "cache" / "models--org--test-model").resolve() + + +def test_model_cache_path_rejects_top_level_symlink_escape(tmp_path, monkeypatch): + cache = tmp_path / "cache" + outside = tmp_path / "outside" + cache.mkdir() + outside.mkdir() + monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(cache)) + link = cache / "models--org--test-model" + try: + link.symlink_to(outside, target_is_directory=True) + except (AttributeError, NotImplementedError, OSError) as exc: + pytest.skip(f"symlinks unavailable: {exc}") + + with pytest.raises(ValueError): + embedding_routes._model_cache_path("org/test-model") + assert embedding_routes._is_downloaded("org/test-model") is False + + +def test_delete_model_rejects_symlink_cache_dir(tmp_path, monkeypatch): + cache = tmp_path / "cache" + outside = tmp_path / "outside" + cache.mkdir() + outside.mkdir() + (outside / "keep.txt").write_text("outside", encoding="utf-8") + monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(cache)) + monkeypatch.setattr(embedding_routes, "_active_model", lambda: "other-model") + _install_fastembed_stub(monkeypatch) + link = cache / "models--org--test-model" + try: + link.symlink_to(outside, target_is_directory=True) + except (AttributeError, NotImplementedError, OSError) as exc: + pytest.skip(f"symlinks unavailable: {exc}") + delete_model = _route_endpoint("/api/embeddings/models/{model_name:path}", "DELETE") + + with pytest.raises(HTTPException) as exc: + delete_model("test-model") + + assert exc.value.status_code == 400 + assert (outside / "keep.txt").exists()