319 lines
11 KiB
Python
319 lines
11 KiB
Python
# routes/embedding_routes.py
|
|
"""Routes for managing local fastembed embedding models and custom endpoints."""
|
|
import os
|
|
import json
|
|
import shutil
|
|
import logging
|
|
import asyncio
|
|
from pathlib import Path
|
|
from fastapi import APIRouter, HTTPException, Form
|
|
from core.constants import BASE_DIR
|
|
|
|
logger = logging.getLogger(__name__)

# JSON file in which a user-configured custom embedding endpoint is persisted.
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")

# Downloads currently in flight, keyed by fastembed model name (value is True).
_downloading: dict = dict()
|
|
|
|
# Curated recommendations covering the size/quality tiers.
# Each note gives embedding dimension, approximate download size, and why.
RECOMMENDED_MODELS = {
    # 384d, 90MB — fast & tiny, good default
    "sentence-transformers/all-MiniLM-L6-v2",
    # 384d, 67MB — smallest, solid quality
    "BAAI/bge-small-en-v1.5",
    # 768d, 130MB — quantized, great bang/buck
    "nomic-ai/nomic-embed-text-v1.5-Q",
    # 768d, 210MB — balanced mid-range
    "BAAI/bge-base-en-v1.5",
    # 768d, 430MB — strong performer
    "snowflake/snowflake-arctic-embed-m",
    # 1024d, 1.2GB — highest quality
    "BAAI/bge-large-en-v1.5",
}
|
|
|
|
|
|
def _cache_dir() -> str:
|
|
"""Get the fastembed cache directory.
|
|
|
|
Defaults to a persistent path under the repo's data/ dir. The old
|
|
default lived in /tmp, which many systems wipe on reboot — forcing a
|
|
full re-download of the embedding model after every restart.
|
|
"""
|
|
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
|
if env:
|
|
return env
|
|
return os.path.join(
|
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
|
"data", "fastembed_cache",
|
|
)
|
|
|
|
|
|
def _model_cache_name(hf_source: str) -> str:
|
|
"""Convert HF source like 'qdrant/all-MiniLM-L6-v2-onnx' to cache dir name."""
|
|
return "models--" + hf_source.replace("/", "--")
|
|
|
|
|
|
def _is_downloaded(hf_source: str) -> bool:
    """Return True if ``hf_source`` looks fully present in the fastembed cache.

    The model directory is ``<cache>/models--<org>--<name>/`` (see
    ``_model_cache_name``). It counts as downloaded when all of these hold:

    * ``blobs/`` holds no ``*.incomplete`` file. Such a file is a partial
      download, either interrupted or still running.
    * If ``snapshots/`` exists, some entry in it is non-empty. An empty
      ``snapshots/<rev>/`` directory alone, as an aborted download can leave
      behind, does not count.
    * If there is no ``snapshots/`` (older cache format), ``blobs/`` is
      non-empty.

    Directories that cannot be listed are treated as empty, so an unreadable
    cache reports "not downloaded" instead of raising.
    """

    def _listdir(path: str) -> list:
        # Missing or unreadable directories behave as empty.
        try:
            return os.listdir(path)
        except OSError:
            return []

    model_dir = os.path.join(_cache_dir(), _model_cache_name(hf_source))
    if not os.path.isdir(model_dir):
        return False

    blob_names = _listdir(os.path.join(model_dir, "blobs"))
    if any(name.endswith(".incomplete") for name in blob_names):
        return False

    snapshots = os.path.join(model_dir, "snapshots")
    if os.path.isdir(snapshots):
        for rev in _listdir(snapshots):
            rev_path = os.path.join(snapshots, rev)
            # Non-directory entries are unexpected in this layout; count them
            # as content, as the previous any(os.listdir(...)) check did.
            if not os.path.isdir(rev_path) or _listdir(rev_path):
                return True
        return False

    # Older cache format without snapshots/: any finished blob counts.
    return bool(blob_names)
|
|
|
|
|
|
def _active_model() -> str:
|
|
"""Get the currently configured fastembed model name."""
|
|
return os.environ.get("FASTEMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
|
|
def _dir_size_mb(path: str) -> float:
|
|
"""Get directory size in MB."""
|
|
total = 0
|
|
for dirpath, _, filenames in os.walk(path):
|
|
for f in filenames:
|
|
fp = os.path.join(dirpath, f)
|
|
try:
|
|
total += os.path.getsize(fp)
|
|
except OSError:
|
|
pass
|
|
return round(total / (1024 * 1024), 1)
|
|
|
|
|
|
def _load_custom_endpoint() -> dict:
    """Return the saved custom embedding endpoint config, or ``{}``.

    A missing ``_ENDPOINT_FILE`` means no custom endpoint is configured.

    A file that cannot be read, is not valid JSON, or does not hold a JSON
    object is logged and treated as ``{}``. Callers can therefore always call
    ``.get()`` on the result.
    """
    path = Path(_ENDPOINT_FILE)
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Ignoring unreadable embedding endpoint file %s: %s", path, exc)
        return {}
    if not isinstance(data, dict):
        logger.warning(
            "Ignoring embedding endpoint file %s: expected a JSON object, got %s",
            path,
            type(data).__name__,
        )
        return {}
    return data
|
|
|
|
|
|
def _save_custom_endpoint(data: dict):
    """Persist the custom embedding endpoint config ``data`` as JSON.

    Parent directories of ``_ENDPOINT_FILE`` are created as needed. The JSON
    is written to a temporary file in the same directory, then moved over
    ``_ENDPOINT_FILE`` with ``os.replace``. A crash mid-write therefore cannot
    leave a truncated config behind; the rename is atomic on POSIX.

    The temp file is created by ``tempfile.mkstemp``, so the resulting config
    file has owner-only (0600) permissions.
    """
    import tempfile

    target = Path(_ENDPOINT_FILE)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Serialize first so an unserializable payload fails before touching disk.
    payload = json.dumps(data, indent=2)
    fd, tmp_name = tempfile.mkstemp(
        dir=str(target.parent), prefix=f".{target.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
        os.replace(tmp_name, target)
    except BaseException:
        # Don't leave the temp file behind if writing or renaming failed.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def _require_text_embedding():
    """Import fastembed lazily and return its ``TextEmbedding`` class.

    fastembed may not be installed. In that case the requesting route gets a
    503 instead of this module failing to import.

    Raises:
        HTTPException: 503 if fastembed is not installed.
    """
    try:
        from fastembed import TextEmbedding
    except ImportError:
        raise HTTPException(503, "fastembed is not installed") from None
    return TextEmbedding


def _hf_source(entry: dict) -> str:
    """Return the Hugging Face repo id of a fastembed catalog entry ("" if none)."""
    return (entry.get("sources") or {}).get("hf") or ""


def _lookup_model(model_name: str) -> tuple:
    """Resolve ``model_name`` in fastembed's supported-model catalog.

    Returns:
        ``(TextEmbedding, entry)``: the fastembed class and the catalog dict
        for the model.

    Raises:
        HTTPException: 503 if fastembed is not installed, 404 if the model is
            not in the catalog.
    """
    text_embedding_cls = _require_text_embedding()
    for entry in text_embedding_cls.list_supported_models():
        if entry["model"] == model_name:
            return text_embedding_cls, entry
    raise HTTPException(404, f"Unknown model: {model_name}")


def _reset_embedding_backends():
    """Drop cached embedding state so the next request rebuilds it.

    Steps:
    * Clear ``src.rag_singleton.rag_instance`` and ``_last_attempt``. A failure
      to import that module propagates.
    * Best-effort: reset the HTTP-embedding "down" latch, so a newly
      configured endpoint is re-probed instead of staying on the FastEmbed
      fallback for the process lifetime.
    * Best-effort: reset the ChromaDB client, so collections are recreated
      with the new embeddings.

    Failures in the best-effort steps are logged rather than raised.
    """
    import src.rag_singleton as _rs

    _rs.rag_instance = None
    _rs._last_attempt = 0

    try:
        from src.embeddings import reset_http_embed_state

        reset_http_embed_state()
    except Exception:
        logger.warning("Could not reset HTTP embedding state", exc_info=True)

    try:
        from src.chroma_client import reset_client

        reset_client()
    except Exception:
        logger.warning("Could not reset ChromaDB client", exc_info=True)


def setup_embedding_routes():
    """Build and return the ``/api/embeddings`` router.

    Local fastembed models:
        GET    /models                   catalog with cache/download status
        POST   /models/{name}/download   download a model into the cache
        GET    /models/{name}/status     downloaded / downloading flags
        DELETE /models/{name}            delete a cached (non-active) model

    Custom HTTP embedding endpoint:
        GET    /endpoint                 current config
        POST   /endpoint                 health-check, save, apply (form: url, model)
        DELETE /endpoint                 clear and revert to local fastembed
    """
    router = APIRouter(prefix="/api/embeddings")

    @router.get("/models")
    def list_models():
        """List all available fastembed models with download status."""
        text_embedding_cls = _require_text_embedding()
        active = _active_model()
        cache = _cache_dir()

        result = []
        for m in text_embedding_cls.list_supported_models():
            name = m["model"]
            hf_src = _hf_source(m)
            downloaded = bool(hf_src) and _is_downloaded(hf_src)
            cached_size = (
                _dir_size_mb(os.path.join(cache, _model_cache_name(hf_src)))
                if downloaded
                else None
            )
            result.append({
                "model": name,
                "dim": m.get("dim"),
                "size_gb": m.get("size_in_GB", 0),
                "description": m.get("description", ""),
                "downloaded": downloaded,
                "downloading": name in _downloading,
                "active": name == active,
                "recommended": name in RECOMMENDED_MODELS,
                "cached_size_mb": cached_size,
            })

        # Sort: active first, then downloaded, then by size. `or 0` keeps the
        # sort from failing if the catalog reports a None size.
        result.sort(key=lambda x: (not x["active"], not x["downloaded"], x["size_gb"] or 0))
        return result

    @router.post("/models/{model_name:path}/download")
    async def download_model(model_name: str):
        """Download a fastembed model. Returns when complete."""
        text_embedding_cls, entry = _lookup_model(model_name)

        # Check for an in-flight download first: a running download fills the
        # cache incrementally, so _is_downloaded() could already report it.
        if model_name in _downloading:
            return {"status": "already_downloading", "model": model_name}

        hf_src = _hf_source(entry)
        if hf_src and _is_downloaded(hf_src):
            return {"status": "already_downloaded", "model": model_name}

        # No await between the membership check above and this insert, so
        # concurrent requests on the event loop cannot both start a download.
        _downloading[model_name] = True
        try:
            cache = _cache_dir()
            loop = asyncio.get_running_loop()
            # Constructing TextEmbedding fetches the model; run it in the
            # default executor so the event loop is not blocked.
            await loop.run_in_executor(
                None,
                lambda: text_embedding_cls(model_name=model_name, cache_dir=cache),
            )
        except Exception as e:
            logger.exception("Failed to download %s", model_name)
            raise HTTPException(500, f"Download failed: {e}") from e
        finally:
            _downloading.pop(model_name, None)
        return {"status": "downloaded", "model": model_name}

    @router.get("/models/{model_name:path}/status")
    def download_status(model_name: str):
        """Check download status of a model."""
        _, entry = _lookup_model(model_name)
        hf_src = _hf_source(entry)
        return {
            "model": model_name,
            "downloaded": bool(hf_src) and _is_downloaded(hf_src),
            "downloading": model_name in _downloading,
        }

    @router.delete("/models/{model_name:path}")
    def delete_model(model_name: str):
        """Delete a cached model."""
        if model_name == _active_model():
            raise HTTPException(400, "Cannot delete the active embedding model")

        if model_name in _downloading:
            raise HTTPException(400, "Model is currently downloading")

        _, entry = _lookup_model(model_name)
        hf_src = _hf_source(entry)
        if not hf_src:
            raise HTTPException(400, "No cache source for this model")

        model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
        if not os.path.isdir(model_path):
            return {"deleted": False, "message": "Model not cached"}

        try:
            shutil.rmtree(model_path)
        except OSError as e:
            logger.exception("Failed to delete cached model %s (%s)", model_name, model_path)
            raise HTTPException(500, f"Delete failed: {e}") from e
        logger.info("Deleted cached model: %s (%s)", model_name, model_path)
        return {"deleted": True, "model": model_name}

    @router.get("/endpoint")
    def get_endpoint():
        """Get the current custom embedding endpoint config."""
        saved = _load_custom_endpoint()
        current_url = os.environ.get("EMBEDDING_URL", "")
        return {
            "url": saved.get("url", current_url),
            "model": saved.get("model", os.environ.get("EMBEDDING_MODEL", "")),
            "active": bool(saved.get("url") or current_url),
        }

    @router.post("/endpoint")
    def set_endpoint(url: str = Form(...), model: str = Form("")):
        """Save a custom embedding endpoint URL."""
        url = url.strip()
        if not url:
            raise HTTPException(400, "URL is required")

        try:
            import httpx
        except ImportError:
            raise HTTPException(503, "httpx is not installed") from None

        # Quick health check: POST a one-item embedding request.
        try:
            resp = httpx.post(
                url,
                json={"input": ["test"], "model": model or "test"},
                timeout=10,
            )
            resp.raise_for_status()
        except Exception as e:
            raise HTTPException(400, f"Endpoint unreachable: {e}") from e

        # Persist and set in environment for immediate use
        data = {"url": url}
        if model:
            data["model"] = model
        _save_custom_endpoint(data)
        os.environ["EMBEDDING_URL"] = url
        if model:
            os.environ["EMBEDDING_MODEL"] = model

        _reset_embedding_backends()

        logger.info("Custom embedding endpoint set: %s", url)
        return {"success": True, "url": url, "model": model}

    @router.delete("/endpoint")
    def clear_endpoint():
        """Clear the custom endpoint and revert to local fastembed."""
        try:
            os.remove(_ENDPOINT_FILE)
        except FileNotFoundError:
            pass

        # Remove from environment
        os.environ.pop("EMBEDDING_URL", None)
        os.environ.pop("EMBEDDING_MODEL", None)

        _reset_embedding_backends()

        logger.info("Custom embedding endpoint cleared, reverting to local fastembed")
        return {"success": True}

    return router
|