TTSService._put_cache writes .mp3 for MP3 audio (ID3/MPEG-framed bytes) and
.wav otherwise, and the rest of the class treats both as cache entries
(_get_cache iterates (".mp3", ".wav"); eviction globs "*.*"). But
get_stats() enumerated the cache with `glob("*.wav")` only, so both
cache_entries and cache_size_mb undercounted — reporting 0 whenever the
cache held MP3 files, which is the common case for most TTS providers.
Glob both extensions so the reported stats match what's actually cached.
tests/test_tts_cache_stats.py writes an MP3-headed blob via _put_cache and
asserts get_stats() reports one entry with non-zero size. Fails before this
change.
284 lines
9.4 KiB
Python
284 lines
9.4 KiB
Python
# src/tts_service.py
|
|
"""Multi-provider TTS service — dispatches to local Kokoro, OpenAI-compatible API, or browser."""
|
|
|
|
import io
|
|
import wave
|
|
import logging
|
|
import hashlib
|
|
import httpx
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TTSService:
|
|
"""Multi-provider TTS service.
|
|
|
|
Reads provider config from data/settings.json on each call.
|
|
Providers:
|
|
"disabled" — no TTS
|
|
"browser" — client-side Web Speech API (no server synthesis)
|
|
"local" — Kokoro-82M on GPU
|
|
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
|
"""
|
|
|
|
def __init__(self, cache_dir: str = "data/tts_cache"):
|
|
self.cache_dir = Path(cache_dir)
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
self._kokoro = None # lazy-init
|
|
|
|
# ── Settings ──
|
|
|
|
def _load_settings(self) -> dict:
|
|
from src.settings import load_settings
|
|
saved = load_settings()
|
|
return {
|
|
"tts_enabled": saved.get("tts_enabled", True),
|
|
"tts_provider": saved.get("tts_provider", "disabled"),
|
|
"tts_model": saved.get("tts_model", "tts-1"),
|
|
"tts_voice": saved.get("tts_voice", "alloy"),
|
|
"tts_speed": saved.get("tts_speed", "1"),
|
|
}
|
|
|
|
@property
|
|
def available(self) -> bool:
|
|
settings = self._load_settings()
|
|
if settings.get("tts_enabled") is False:
|
|
return False
|
|
provider = settings["tts_provider"]
|
|
if provider == "disabled":
|
|
return False
|
|
if provider == "browser":
|
|
return True # handled client-side
|
|
if provider == "local":
|
|
kokoro = self._get_kokoro()
|
|
return kokoro is not None and kokoro.available
|
|
if provider.startswith("endpoint:"):
|
|
return True # assume reachable; errors surface at synthesis time
|
|
return False
|
|
|
|
# ── Cache ──
|
|
|
|
def _cache_key(self, text: str, provider: str, model: str, voice: str, speed: float = 1.0) -> str:
|
|
raw = f"{provider}|{model}|{voice}|{speed}|{text}"
|
|
return hashlib.sha256(raw.encode()).hexdigest()
|
|
|
|
def _get_cached(self, key: str) -> Optional[bytes]:
|
|
for ext in (".mp3", ".wav"):
|
|
path = self.cache_dir / f"{key}{ext}"
|
|
if path.exists():
|
|
return path.read_bytes()
|
|
return None
|
|
|
|
def _put_cache(self, key: str, data: bytes):
|
|
ext = ".mp3" if (len(data) >= 3 and (data[:3] == b'ID3' or (data[0] == 0xff and (data[1] & 0xe0) == 0xe0))) else ".wav"
|
|
(self.cache_dir / f"{key}{ext}").write_bytes(data)
|
|
|
|
def clear_cache(self):
|
|
count = 0
|
|
for f in self.cache_dir.glob("*.*"):
|
|
f.unlink()
|
|
count += 1
|
|
logger.info(f"Cleared {count} cached TTS files")
|
|
|
|
# ── Kokoro (local) ──
|
|
|
|
def _get_kokoro(self):
|
|
if self._kokoro is None:
|
|
self._kokoro = _KokoroPipeline()
|
|
return self._kokoro
|
|
|
|
# ── API endpoint ──
|
|
|
|
def _synthesize_api(self, text: str, endpoint_id: str, model: str, voice: str, speed: float = 1.0) -> Optional[bytes]:
|
|
from src.database import SessionLocal, ModelEndpoint
|
|
|
|
db = SessionLocal()
|
|
try:
|
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
|
if not ep:
|
|
logger.error(f"TTS endpoint {endpoint_id} not found")
|
|
return None
|
|
base_url = ep.base_url.rstrip("/")
|
|
api_key = ep.api_key
|
|
finally:
|
|
db.close()
|
|
|
|
url = base_url + "/audio/speech"
|
|
headers = {"Content-Type": "application/json"}
|
|
if api_key:
|
|
headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
payload = {
|
|
"model": model,
|
|
"input": text,
|
|
"voice": voice,
|
|
"response_format": "mp3",
|
|
"speed": speed,
|
|
}
|
|
|
|
try:
|
|
r = httpx.post(url, json=payload, headers=headers, timeout=60)
|
|
r.raise_for_status()
|
|
logger.info(f"API TTS: {len(r.content)} bytes from {base_url}")
|
|
return r.content
|
|
except Exception as e:
|
|
logger.error(f"API TTS synthesis failed: {e}")
|
|
return None
|
|
|
|
# ── Public interface ──
|
|
|
|
def synthesize(self, text: str, use_cache: bool = True) -> Optional[bytes]:
|
|
settings = self._load_settings()
|
|
if settings.get("tts_enabled") is False:
|
|
return None
|
|
provider = settings["tts_provider"]
|
|
model = settings["tts_model"]
|
|
voice = settings["tts_voice"]
|
|
speed = float(settings.get("tts_speed", "1"))
|
|
|
|
if provider in ("disabled", "browser"):
|
|
return None
|
|
|
|
if len(text) > 5000:
|
|
text = text[:5000]
|
|
|
|
if use_cache:
|
|
key = self._cache_key(text, provider, model, voice, speed)
|
|
cached = self._get_cached(key)
|
|
if cached:
|
|
logger.info(f"TTS cache hit ({len(text)} chars)")
|
|
return cached
|
|
|
|
audio_data = None
|
|
|
|
if provider == "local":
|
|
kokoro = self._get_kokoro()
|
|
if kokoro and kokoro.available:
|
|
audio_data = kokoro.synthesize_raw(text, voice)
|
|
else:
|
|
logger.warning("Kokoro TTS not available")
|
|
return None
|
|
elif provider.startswith("endpoint:"):
|
|
endpoint_id = provider.split(":", 1)[1]
|
|
audio_data = self._synthesize_api(text, endpoint_id, model, voice, speed)
|
|
else:
|
|
logger.error(f"Unknown TTS provider: {provider}")
|
|
return None
|
|
|
|
if audio_data and use_cache:
|
|
key = self._cache_key(text, provider, model, voice, speed)
|
|
self._put_cache(key, audio_data)
|
|
|
|
return audio_data
|
|
|
|
def synthesize_to_base64(self, text: str) -> Optional[str]:
|
|
import base64
|
|
audio = self.synthesize(text)
|
|
if audio:
|
|
return base64.b64encode(audio).decode("utf-8")
|
|
return None
|
|
|
|
def set_voice(self, voice: str):
|
|
"""Legacy no-op — voice is now managed via admin settings."""
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
settings = self._load_settings()
|
|
provider = settings["tts_provider"]
|
|
tts_enabled = settings.get("tts_enabled", True)
|
|
|
|
cache_files = list(self.cache_dir.glob("*.wav")) + list(self.cache_dir.glob("*.mp3"))
|
|
cache_size = sum(f.stat().st_size for f in cache_files)
|
|
|
|
is_available = self.available and tts_enabled
|
|
stats = {
|
|
"available": is_available,
|
|
"ready": is_available,
|
|
"provider": provider,
|
|
"model": settings["tts_model"],
|
|
"voice": settings["tts_voice"],
|
|
"speed": float(settings.get("tts_speed", "1")),
|
|
"cache_entries": len(cache_files),
|
|
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
|
|
}
|
|
|
|
if provider == "local":
|
|
kokoro = self._get_kokoro()
|
|
stats["model"] = "Kokoro-82M (GPU)" if (kokoro and kokoro.available) else "Kokoro (not loaded)"
|
|
elif provider == "browser":
|
|
stats["model"] = "Browser (Web Speech API)"
|
|
elif provider.startswith("endpoint:"):
|
|
stats["endpoint_id"] = provider.split(":", 1)[1]
|
|
|
|
return stats
|
|
|
|
|
|
class _KokoroPipeline:
|
|
"""Encapsulates the Kokoro-82M local GPU pipeline."""
|
|
|
|
def __init__(self):
|
|
self.pipeline = None
|
|
self.available = False
|
|
self.device = None
|
|
self._init()
|
|
|
|
def _init(self):
|
|
try:
|
|
import torch
|
|
from kokoro import KPipeline
|
|
|
|
if not torch.cuda.is_available():
|
|
logger.warning("CUDA not available for Kokoro TTS")
|
|
return
|
|
|
|
self.device = torch.device("cuda:0")
|
|
with torch.cuda.device(0):
|
|
self.pipeline = KPipeline(lang_code="a")
|
|
if hasattr(self.pipeline, "model"):
|
|
self.pipeline.model = self.pipeline.model.to(self.device)
|
|
self.available = True
|
|
logger.info("Kokoro-82M TTS pipeline loaded")
|
|
except ImportError as e:
|
|
logger.warning(f"Kokoro TTS not available: {e}")
|
|
logger.warning("Install with: pip install kokoro soundfile")
|
|
except Exception as e:
|
|
logger.error(f"Kokoro init failed: {e}", exc_info=True)
|
|
|
|
def synthesize_raw(self, text: str, voice: str = "af_heart") -> Optional[bytes]:
|
|
if not self.available:
|
|
return None
|
|
try:
|
|
import torch
|
|
import numpy as np
|
|
|
|
with torch.cuda.device(self.device):
|
|
chunks = []
|
|
for _, _, audio in self.pipeline(text, voice=voice):
|
|
chunks.append(audio)
|
|
|
|
if not chunks:
|
|
return None
|
|
|
|
full = np.concatenate(chunks)
|
|
buf = io.BytesIO()
|
|
with wave.open(buf, "wb") as wf:
|
|
wf.setnchannels(1)
|
|
wf.setsampwidth(2)
|
|
wf.setframerate(24000)
|
|
wf.writeframes((full * 32767).astype(np.int16).tobytes())
|
|
return buf.getvalue()
|
|
except Exception as e:
|
|
logger.error(f"Kokoro synthesis failed: {e}", exc_info=True)
|
|
return None
|
|
|
|
|
|
# Module-level singleton
|
|
_tts_service = None
|
|
|
|
def get_tts_service() -> TTSService:
|
|
global _tts_service
|
|
if _tts_service is None:
|
|
_tts_service = TTSService()
|
|
return _tts_service
|