diff --git a/routes/model_routes.py b/routes/model_routes.py index 28ec1d0..0fbf8e5 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -28,6 +28,33 @@ from src.auth_helpers import _auth_disabled, owner_filter logger = logging.getLogger(__name__) +_SPEECH_ENDPOINT_SETTINGS = ( + ("tts_provider", "tts_model", "tts-1", "Text to Speech"), + ("stt_provider", "stt_model", "base", "Speech to Text"), +) + + +def _speech_settings_using_endpoint(settings: dict, ep_id: str) -> list: + """Return speech settings that reference a model endpoint.""" + endpoint_ref = f"endpoint:{ep_id}" + return [ + label + for provider_key, _, _, label in _SPEECH_ENDPOINT_SETTINGS + if (settings.get(provider_key) or "") == endpoint_ref + ] + + +def _clear_speech_settings_for_endpoint(settings: dict, ep_id: str) -> list: + """Reset speech settings that reference a model endpoint.""" + endpoint_ref = f"endpoint:{ep_id}" + cleared = [] + for provider_key, model_key, default_model, label in _SPEECH_ENDPOINT_SETTINGS: + if (settings.get(provider_key) or "") == endpoint_ref: + settings[provider_key] = "disabled" + settings[model_key] = default_model + cleared.append(label) + return cleared + # Loopback hosts a user might type for a local model server (LM Studio, # llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the @@ -1442,9 +1469,7 @@ def setup_model_routes(model_discovery): for ep_key, (_, label) in _EP_SETTING_FIELDS.items(): if (settings.get(ep_key) or "") == ep_id: affected.append(label) - tts_prov = settings.get("tts_provider") or "" - if tts_prov == f"endpoint:{ep_id}": - affected.append("Text to Speech") + affected.extend(_speech_settings_using_endpoint(settings, ep_id)) return affected def _clear_settings_for_endpoint(ep_id: str) -> list: @@ -1456,11 +1481,7 @@ def setup_model_routes(model_discovery): settings[ep_key] = "" settings[model_key] = "" cleared.append(label) - tts_prov = settings.get("tts_provider") or "" - if tts_prov == f"endpoint:{ep_id}": - settings["tts_provider"] = "disabled" - settings["tts_model"] = "tts-1" - cleared.append("Text to Speech") + cleared.extend(_clear_speech_settings_for_endpoint(settings, ep_id)) if cleared: _save_settings(settings) return cleared diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index d9ca5e3..b251049 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -33,11 +33,40 @@ from routes.model_routes import ( _classify_endpoint, _probe_endpoint, _truthy, + _speech_settings_using_endpoint, + _clear_speech_settings_for_endpoint, _PROVIDER_CURATED, ) from src.llm_core import ANTHROPIC_MODELS +# ── speech endpoint settings ── + +def test_speech_endpoint_dependents_include_stt(): + settings = {"stt_provider": "endpoint:voice"} + assert _speech_settings_using_endpoint(settings, "voice") == ["Speech to Text"] + + +def test_clear_speech_endpoint_settings_resets_tts_and_stt(): + settings = { + "tts_provider": "endpoint:voice", + "tts_model": "custom-tts", + "stt_provider": "endpoint:voice", + "stt_model": "custom-stt", + } + + assert _clear_speech_settings_for_endpoint(settings, "voice") == [ + "Text to Speech", + "Speech to Text", + ] + assert settings == { + "tts_provider": "disabled", + "tts_model": "tts-1", + "stt_provider": "disabled", + "stt_model": "base", + } + + # ── _match_provider_curated ── class TestMatchProviderCurated: