Constrain embedding model cache paths (#2849)
This commit is contained in:
@@ -49,19 +49,35 @@ def _model_cache_name(hf_source: str) -> str:
|
|||||||
return "models--" + hf_source.replace("/", "--")
|
return "models--" + hf_source.replace("/", "--")
|
||||||
|
|
||||||
|
|
||||||
|
def _model_cache_path(hf_source: str) -> Path:
|
||||||
|
"""Return a confined cache path for a fastembed HF source."""
|
||||||
|
root = Path(_cache_dir()).expanduser().resolve()
|
||||||
|
raw_path = root / _model_cache_name(hf_source)
|
||||||
|
if raw_path.is_symlink():
|
||||||
|
raise ValueError("Model cache path must not be a symlink")
|
||||||
|
path = raw_path.resolve(strict=False)
|
||||||
|
try:
|
||||||
|
path.relative_to(root)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Model cache path escapes cache root")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
def _is_downloaded(hf_source: str) -> bool:
|
def _is_downloaded(hf_source: str) -> bool:
|
||||||
"""Check if a model is already cached."""
|
"""Check if a model is already cached."""
|
||||||
cache = _cache_dir()
|
try:
|
||||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
model_dir = _model_cache_path(hf_source)
|
||||||
if not os.path.isdir(model_dir):
|
except ValueError:
|
||||||
|
return False
|
||||||
|
if not model_dir.is_dir():
|
||||||
return False
|
return False
|
||||||
# Check for actual model files (not just empty dir)
|
# Check for actual model files (not just empty dir)
|
||||||
snapshots = os.path.join(model_dir, "snapshots")
|
snapshots = model_dir / "snapshots"
|
||||||
if os.path.isdir(snapshots):
|
if snapshots.is_dir():
|
||||||
return any(os.listdir(snapshots))
|
return any(snapshots.iterdir())
|
||||||
# Also check for blobs (older cache format)
|
# Also check for blobs (older cache format)
|
||||||
blobs = os.path.join(model_dir, "blobs")
|
blobs = model_dir / "blobs"
|
||||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
return blobs.is_dir() and any(blobs.iterdir())
|
||||||
|
|
||||||
|
|
||||||
def _active_model() -> str:
|
def _active_model() -> str:
|
||||||
@@ -119,8 +135,10 @@ def setup_embedding_routes():
|
|||||||
|
|
||||||
cached_size = None
|
cached_size = None
|
||||||
if downloaded and hf_src:
|
if downloaded and hf_src:
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
cached_size = _dir_size_mb(model_path)
|
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
|
||||||
|
except ValueError:
|
||||||
|
cached_size = None
|
||||||
|
|
||||||
result.append({
|
result.append({
|
||||||
"model": m["model"],
|
"model": m["model"],
|
||||||
@@ -217,8 +235,11 @@ def setup_embedding_routes():
|
|||||||
if not hf_src:
|
if not hf_src:
|
||||||
raise HTTPException(400, "No cache source for this model")
|
raise HTTPException(400, "No cache source for this model")
|
||||||
|
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
if not os.path.isdir(model_path):
|
model_path = _model_cache_path(hf_src)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
if not model_path.is_dir():
|
||||||
return {"deleted": False, "message": "Model not cached"}
|
return {"deleted": False, "message": "Model not cached"}
|
||||||
|
|
||||||
shutil.rmtree(model_path)
|
shutil.rmtree(model_path)
|
||||||
|
|||||||
75
tests/test_embedding_cache_confinement.py
Normal file
75
tests/test_embedding_cache_confinement.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import routes.embedding_routes as embedding_routes
|
||||||
|
|
||||||
|
|
||||||
|
def _install_fastembed_stub(monkeypatch):
|
||||||
|
fastembed = types.ModuleType("fastembed")
|
||||||
|
|
||||||
|
class TextEmbedding:
|
||||||
|
@staticmethod
|
||||||
|
def list_supported_models():
|
||||||
|
return [{"model": "test-model", "sources": {"hf": "org/test-model"}}]
|
||||||
|
|
||||||
|
fastembed.TextEmbedding = TextEmbedding
|
||||||
|
monkeypatch.setitem(sys.modules, "fastembed", fastembed)
|
||||||
|
|
||||||
|
|
||||||
|
def _route_endpoint(path: str, method: str):
|
||||||
|
router = embedding_routes.setup_embedding_routes()
|
||||||
|
for route in router.routes:
|
||||||
|
if route.path == path and method in route.methods:
|
||||||
|
return route.endpoint
|
||||||
|
raise AssertionError(f"route not found: {method} {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_cache_path_resolves_under_cache_root(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(tmp_path / "cache"))
|
||||||
|
|
||||||
|
path = embedding_routes._model_cache_path("org/test-model")
|
||||||
|
|
||||||
|
assert path == (tmp_path / "cache" / "models--org--test-model").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_cache_path_rejects_top_level_symlink_escape(tmp_path, monkeypatch):
|
||||||
|
cache = tmp_path / "cache"
|
||||||
|
outside = tmp_path / "outside"
|
||||||
|
cache.mkdir()
|
||||||
|
outside.mkdir()
|
||||||
|
monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(cache))
|
||||||
|
link = cache / "models--org--test-model"
|
||||||
|
try:
|
||||||
|
link.symlink_to(outside, target_is_directory=True)
|
||||||
|
except (AttributeError, NotImplementedError, OSError) as exc:
|
||||||
|
pytest.skip(f"symlinks unavailable: {exc}")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding_routes._model_cache_path("org/test-model")
|
||||||
|
assert embedding_routes._is_downloaded("org/test-model") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_model_rejects_symlink_cache_dir(tmp_path, monkeypatch):
|
||||||
|
cache = tmp_path / "cache"
|
||||||
|
outside = tmp_path / "outside"
|
||||||
|
cache.mkdir()
|
||||||
|
outside.mkdir()
|
||||||
|
(outside / "keep.txt").write_text("outside", encoding="utf-8")
|
||||||
|
monkeypatch.setattr(embedding_routes, "_cache_dir", lambda: str(cache))
|
||||||
|
monkeypatch.setattr(embedding_routes, "_active_model", lambda: "other-model")
|
||||||
|
_install_fastembed_stub(monkeypatch)
|
||||||
|
link = cache / "models--org--test-model"
|
||||||
|
try:
|
||||||
|
link.symlink_to(outside, target_is_directory=True)
|
||||||
|
except (AttributeError, NotImplementedError, OSError) as exc:
|
||||||
|
pytest.skip(f"symlinks unavailable: {exc}")
|
||||||
|
delete_model = _route_endpoint("/api/embeddings/models/{model_name:path}", "DELETE")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
delete_model("test-model")
|
||||||
|
|
||||||
|
assert exc.value.status_code == 400
|
||||||
|
assert (outside / "keep.txt").exists()
|
||||||
Reference in New Issue
Block a user