fix(models): clear stale speech endpoint settings (#1196)
This commit is contained in:
@@ -28,6 +28,33 @@ from src.auth_helpers import _auth_disabled, owner_filter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SPEECH_ENDPOINT_SETTINGS = (
|
||||
("tts_provider", "tts_model", "tts-1", "Text to Speech"),
|
||||
("stt_provider", "stt_model", "base", "Speech to Text"),
|
||||
)
|
||||
|
||||
|
||||
def _speech_settings_using_endpoint(settings: dict, ep_id: str) -> list:
|
||||
"""Return speech settings that reference a model endpoint."""
|
||||
endpoint_ref = f"endpoint:{ep_id}"
|
||||
return [
|
||||
label
|
||||
for provider_key, _, _, label in _SPEECH_ENDPOINT_SETTINGS
|
||||
if (settings.get(provider_key) or "") == endpoint_ref
|
||||
]
|
||||
|
||||
|
||||
def _clear_speech_settings_for_endpoint(settings: dict, ep_id: str) -> list:
|
||||
"""Reset speech settings that reference a model endpoint."""
|
||||
endpoint_ref = f"endpoint:{ep_id}"
|
||||
cleared = []
|
||||
for provider_key, model_key, default_model, label in _SPEECH_ENDPOINT_SETTINGS:
|
||||
if (settings.get(provider_key) or "") == endpoint_ref:
|
||||
settings[provider_key] = "disabled"
|
||||
settings[model_key] = default_model
|
||||
cleared.append(label)
|
||||
return cleared
|
||||
|
||||
|
||||
# Loopback hosts a user might type for a local model server (LM Studio,
|
||||
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
|
||||
@@ -1442,9 +1469,7 @@ def setup_model_routes(model_discovery):
|
||||
for ep_key, (_, label) in _EP_SETTING_FIELDS.items():
|
||||
if (settings.get(ep_key) or "") == ep_id:
|
||||
affected.append(label)
|
||||
tts_prov = settings.get("tts_provider") or ""
|
||||
if tts_prov == f"endpoint:{ep_id}":
|
||||
affected.append("Text to Speech")
|
||||
affected.extend(_speech_settings_using_endpoint(settings, ep_id))
|
||||
return affected
|
||||
|
||||
def _clear_settings_for_endpoint(ep_id: str) -> list:
|
||||
@@ -1456,11 +1481,7 @@ def setup_model_routes(model_discovery):
|
||||
settings[ep_key] = ""
|
||||
settings[model_key] = ""
|
||||
cleared.append(label)
|
||||
tts_prov = settings.get("tts_provider") or ""
|
||||
if tts_prov == f"endpoint:{ep_id}":
|
||||
settings["tts_provider"] = "disabled"
|
||||
settings["tts_model"] = "tts-1"
|
||||
cleared.append("Text to Speech")
|
||||
cleared.extend(_clear_speech_settings_for_endpoint(settings, ep_id))
|
||||
if cleared:
|
||||
_save_settings(settings)
|
||||
return cleared
|
||||
|
||||
@@ -33,11 +33,40 @@ from routes.model_routes import (
|
||||
_classify_endpoint,
|
||||
_probe_endpoint,
|
||||
_truthy,
|
||||
_speech_settings_using_endpoint,
|
||||
_clear_speech_settings_for_endpoint,
|
||||
_PROVIDER_CURATED,
|
||||
)
|
||||
from src.llm_core import ANTHROPIC_MODELS
|
||||
|
||||
|
||||
# ── speech endpoint settings ──
|
||||
|
||||
def test_speech_endpoint_dependents_include_stt():
|
||||
settings = {"stt_provider": "endpoint:voice"}
|
||||
assert _speech_settings_using_endpoint(settings, "voice") == ["Speech to Text"]
|
||||
|
||||
|
||||
def test_clear_speech_endpoint_settings_resets_tts_and_stt():
|
||||
settings = {
|
||||
"tts_provider": "endpoint:voice",
|
||||
"tts_model": "custom-tts",
|
||||
"stt_provider": "endpoint:voice",
|
||||
"stt_model": "custom-stt",
|
||||
}
|
||||
|
||||
assert _clear_speech_settings_for_endpoint(settings, "voice") == [
|
||||
"Text to Speech",
|
||||
"Speech to Text",
|
||||
]
|
||||
assert settings == {
|
||||
"tts_provider": "disabled",
|
||||
"tts_model": "tts-1",
|
||||
"stt_provider": "disabled",
|
||||
"stt_model": "base",
|
||||
}
|
||||
|
||||
|
||||
# ── _match_provider_curated ──
|
||||
|
||||
class TestMatchProviderCurated:
|
||||
|
||||
Reference in New Issue
Block a user