STTService._transcribe_local writes the audio to a NamedTemporaryFile (delete=False) and only unlinks it on the success path, before the except. If model.transcribe() raises (corrupt audio, model/runtime error, etc.) the function logs, returns None, and leaves the .webm temp file behind — so every failed local transcription leaks a file in the system temp dir. Initialize tmp_path = None up front and move the unlink into a finally block so the temp file is cleaned up whether transcription succeeds or raises. tests/test_stt_leak.py stubs the whisper model to raise during transcribe, runs _transcribe_local, and asserts it returns None and leaves no new .webm file in the temp dir. Fails before this change.
209 lines
7.6 KiB
Python
209 lines
7.6 KiB
Python
# services/stt/stt_service.py
|
|
"""Multi-provider Speech-to-Text service — dispatches to local Whisper, OpenAI-compatible API, or browser."""
|
|
|
|
import io
|
|
import logging
|
|
import httpx
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class STTService:
|
|
"""Multi-provider STT service.
|
|
|
|
Reads provider config from data/settings.json on each call.
|
|
Providers:
|
|
"disabled" — no STT
|
|
"browser" — client-side Web Speech API (no server transcription)
|
|
"local" — faster-whisper on CPU/GPU
|
|
"endpoint:<id>" — OpenAI-compatible /audio/transcriptions via ModelEndpoint
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._whisper_model = None # lazy-init
|
|
|
|
# ── Settings ──
|
|
|
|
def _load_settings(self) -> dict:
|
|
from src.settings import load_settings
|
|
saved = load_settings()
|
|
return {
|
|
"stt_enabled": saved.get("stt_enabled", False),
|
|
"stt_provider": saved.get("stt_provider", "disabled"),
|
|
"stt_model": saved.get("stt_model", "base"),
|
|
"stt_language": saved.get("stt_language", ""),
|
|
}
|
|
|
|
@property
|
|
def available(self) -> bool:
|
|
settings = self._load_settings()
|
|
if settings.get("stt_enabled") is False:
|
|
return False
|
|
provider = settings["stt_provider"]
|
|
if provider == "disabled":
|
|
return False
|
|
if provider == "browser":
|
|
return True # handled client-side
|
|
if provider == "local":
|
|
return self._get_whisper() is not None
|
|
if provider.startswith("endpoint:"):
|
|
return True # assume reachable
|
|
return False
|
|
|
|
# ── Local Whisper ──
|
|
|
|
def _get_whisper(self):
|
|
if self._whisper_model is None:
|
|
try:
|
|
from faster_whisper import WhisperModel
|
|
except ImportError:
|
|
logger.warning("faster-whisper not installed. Install with: pip install faster-whisper")
|
|
return None
|
|
try:
|
|
settings = self._load_settings()
|
|
model_size = settings.get("stt_model", "base")
|
|
# faster-whisper runs on CTranslate2, not torch. torch is only
|
|
# used (optionally) to detect a CUDA device for acceleration —
|
|
# if it's missing or unusable we just run on CPU. Keeping this
|
|
# probe separate (and tolerant of any failure, e.g. a broken
|
|
# CUDA/torch install that raises OSError on import) means a
|
|
# torch-less or torch-broken machine still does CPU
|
|
# transcription instead of failing with a misleading
|
|
# "faster-whisper not installed" error.
|
|
try:
|
|
import torch
|
|
use_cuda = torch.cuda.is_available()
|
|
except Exception:
|
|
use_cuda = False
|
|
device = "cuda" if use_cuda else "cpu"
|
|
compute_type = "float16" if device == "cuda" else "int8"
|
|
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
|
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
|
|
except Exception as e:
|
|
logger.error(f"Failed to load whisper model: {e}")
|
|
return None
|
|
return self._whisper_model
|
|
|
|
def _transcribe_local(self, audio_bytes: bytes, language: str = "") -> Optional[str]:
|
|
model = self._get_whisper()
|
|
if not model:
|
|
return None
|
|
tmp_path = None
|
|
try:
|
|
# Write to temp file (faster-whisper needs a file path or file-like)
|
|
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
|
|
tmp.write(audio_bytes)
|
|
tmp_path = tmp.name
|
|
|
|
kwargs = {}
|
|
if language:
|
|
kwargs["language"] = language
|
|
|
|
segments, info = model.transcribe(tmp_path, **kwargs)
|
|
text = " ".join(seg.text.strip() for seg in segments)
|
|
|
|
logger.info(f"Local STT: {len(text)} chars, lang={info.language}, prob={info.language_probability:.2f}")
|
|
return text
|
|
except Exception as e:
|
|
logger.error(f"Local STT transcription failed: {e}", exc_info=True)
|
|
return None
|
|
finally:
|
|
if tmp_path:
|
|
Path(tmp_path).unlink(missing_ok=True)
|
|
|
|
# ── API endpoint ──
|
|
|
|
def _transcribe_api(self, audio_bytes: bytes, endpoint_id: str, model: str, language: str = "") -> Optional[str]:
|
|
from src.database import SessionLocal, ModelEndpoint
|
|
|
|
db = SessionLocal()
|
|
try:
|
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
|
if not ep:
|
|
logger.error(f"STT endpoint {endpoint_id} not found")
|
|
return None
|
|
base_url = ep.base_url.rstrip("/")
|
|
api_key = ep.api_key
|
|
finally:
|
|
db.close()
|
|
|
|
url = base_url + "/audio/transcriptions"
|
|
headers = {}
|
|
if api_key:
|
|
headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
files = {"file": ("audio.webm", io.BytesIO(audio_bytes), "audio/webm")}
|
|
data = {"model": model or "whisper-1"}
|
|
if language:
|
|
data["language"] = language
|
|
|
|
try:
|
|
r = httpx.post(url, headers=headers, files=files, data=data, timeout=60)
|
|
r.raise_for_status()
|
|
result = r.json()
|
|
text = result.get("text", "")
|
|
logger.info(f"API STT: {len(text)} chars from {base_url}")
|
|
return text
|
|
except Exception as e:
|
|
logger.error(f"API STT transcription failed: {e}")
|
|
return None
|
|
|
|
# ── Public interface ──
|
|
|
|
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
|
|
settings = self._load_settings()
|
|
if settings.get("stt_enabled") is False:
|
|
return None
|
|
provider = settings["stt_provider"]
|
|
model = settings["stt_model"]
|
|
language = settings.get("stt_language", "")
|
|
|
|
if provider in ("disabled", "browser"):
|
|
return None
|
|
|
|
if provider == "local":
|
|
return self._transcribe_local(audio_bytes, language)
|
|
elif provider.startswith("endpoint:"):
|
|
endpoint_id = provider.split(":", 1)[1]
|
|
return self._transcribe_api(audio_bytes, endpoint_id, model, language)
|
|
else:
|
|
logger.error(f"Unknown STT provider: {provider}")
|
|
return None
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
settings = self._load_settings()
|
|
provider = settings["stt_provider"]
|
|
stt_enabled = settings.get("stt_enabled", False)
|
|
# If toggle is off, report as disabled
|
|
effective_provider = provider if stt_enabled else "disabled"
|
|
|
|
stats = {
|
|
"available": self.available and stt_enabled,
|
|
"provider": effective_provider,
|
|
"model": settings["stt_model"],
|
|
"language": settings.get("stt_language", ""),
|
|
}
|
|
|
|
if provider == "local":
|
|
whisper = self._get_whisper()
|
|
stats["model_loaded"] = whisper is not None
|
|
elif provider == "browser":
|
|
stats["model"] = "Browser (Web Speech API)"
|
|
elif provider.startswith("endpoint:"):
|
|
stats["endpoint_id"] = provider.split(":", 1)[1]
|
|
|
|
return stats
|
|
|
|
|
|
# Module-level singleton
|
|
_stt_service = None
|
|
|
|
def get_stt_service() -> STTService:
|
|
global _stt_service
|
|
if _stt_service is None:
|
|
_stt_service = STTService()
|
|
return _stt_service
|