diff --git a/routes/model_routes.py b/routes/model_routes.py index 0fbf8e5..a7e364d 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -33,6 +33,19 @@ _SPEECH_ENDPOINT_SETTINGS = ( ("stt_provider", "stt_model", "base", "Speech to Text"), ) +_ENDPOINT_SETTING_FIELDS = { + "default_endpoint_id": ("default_model", "Default Model"), + "utility_endpoint_id": ("utility_model", "Utility Model"), + "research_endpoint_id": ("research_model", "Deep Research"), + "task_endpoint_id": ("task_model", "Background Tasks"), +} + +_ENDPOINT_FALLBACK_FIELDS = { + "default_model_fallbacks": "Default Model Fallbacks", + "utility_model_fallbacks": "Utility Model Fallbacks", + "vision_model_fallbacks": "Vision Model Fallbacks", +} + def _speech_settings_using_endpoint(settings: dict, ep_id: str) -> list: """Return speech settings that reference a model endpoint.""" @@ -56,6 +69,58 @@ def _clear_speech_settings_for_endpoint(settings: dict, ep_id: str) -> list: return cleared +def _endpoint_settings_using_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list: + """Return labels for settings and fallback chains that reference an endpoint.""" + affected = [] + for ep_key, (_, label) in _ENDPOINT_SETTING_FIELDS.items(): + if (settings.get(ep_key) or "") == ep_id: + affected.append(label) + for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items(): + chain = settings.get(fallback_key) or [] + if any(isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id for entry in chain): + affected.append(label) + if include_speech: + affected.extend(_speech_settings_using_endpoint(settings, ep_id)) + return affected + + +def _clear_endpoint_settings_for_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list: + """Remove an endpoint from direct settings and model fallback chains.""" + cleared = [] + for ep_key, (model_key, label) in _ENDPOINT_SETTING_FIELDS.items(): + if (settings.get(ep_key) or "") == ep_id: + settings[ep_key] = "" + settings[model_key] = "" + cleared.append(label) + for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items(): + chain = settings.get(fallback_key) + if not isinstance(chain, list): + continue + kept = [ + entry for entry in chain + if not (isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id) + ] + if len(kept) != len(chain): + settings[fallback_key] = kept + cleared.append(label) + if include_speech: + cleared.extend(_clear_speech_settings_for_endpoint(settings, ep_id)) + return cleared + + +def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int: + """Remove endpoint references from scoped or legacy-flat user preferences.""" + if not isinstance(all_prefs, dict): + return 0 + users = all_prefs.get("_users") + pref_sets = users.values() if isinstance(users, dict) else [all_prefs] + cleared_users = 0 + for prefs in pref_sets: + if isinstance(prefs, dict) and _clear_endpoint_settings_for_endpoint(prefs, ep_id): + cleared_users += 1 + return cleared_users + + # Loopback hosts a user might type for a local model server (LM Studio, # llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the # host the server actually runs on. @@ -1454,38 +1519,31 @@ def setup_model_routes(model_discovery): finally: db.close() - # ── Settings fields that store an endpoint ID ── - _EP_SETTING_FIELDS = { - "default_endpoint_id": ("default_model", "Default Model"), - "utility_endpoint_id": ("utility_model", "Utility Model"), - "research_endpoint_id": ("research_model", "Deep Research"), - "task_endpoint_id": ("task_model", "Background Tasks"), - } - def _settings_using_endpoint(ep_id: str) -> list: """Return human-readable labels for settings that reference this endpoint.""" - settings = _load_settings() - affected = [] - for ep_key, (_, label) in _EP_SETTING_FIELDS.items(): - if (settings.get(ep_key) or "") == ep_id: - affected.append(label) - affected.extend(_speech_settings_using_endpoint(settings, ep_id)) - return affected + return _endpoint_settings_using_endpoint(_load_settings(), ep_id, include_speech=True) def _clear_settings_for_endpoint(ep_id: str) -> list: """Clear all settings that reference this endpoint. Returns list of cleared labels.""" settings = _load_settings() - cleared = [] - for ep_key, (model_key, label) in _EP_SETTING_FIELDS.items(): - if (settings.get(ep_key) or "") == ep_id: - settings[ep_key] = "" - settings[model_key] = "" - cleared.append(label) - cleared.extend(_clear_speech_settings_for_endpoint(settings, ep_id)) + cleared = _clear_endpoint_settings_for_endpoint(settings, ep_id, include_speech=True) if cleared: _save_settings(settings) return cleared + def _clear_user_prefs_for_endpoint(ep_id: str) -> int: + """Clear per-user endpoint selections and fallback chains.""" + try: + from routes.prefs_routes import _load as _load_prefs, _save as _save_prefs + all_prefs = _load_prefs() + cleared_users = _clear_user_pref_endpoint_refs(all_prefs, ep_id) + if cleared_users: + _save_prefs(all_prefs) + return cleared_users + except Exception as e: + logger.warning("Failed to clear user prefs for endpoint %s: %s", ep_id, e) + return 0 + def _session_uses_endpoint_url(session_url: str, base_url: str) -> bool: if not session_url or not base_url: return False @@ -1550,6 +1608,7 @@ def setup_model_routes(model_discovery): raise HTTPException(404, "Endpoint not found") # Clean up any settings that reference this endpoint cleared = _clear_settings_for_endpoint(ep_id) + cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id) cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url) cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url) db.delete(ep) @@ -1559,6 +1618,7 @@ def setup_model_routes(model_discovery): return { "deleted": True, "cleared_settings": cleared, + "cleared_user_preferences": cleared_user_preferences, "cleared_sessions": cleared_sessions, "cleared_loaded_sessions": cleared_loaded_sessions, } diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index b251049..be767e4 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -35,6 +35,9 @@ from routes.model_routes import ( _truthy, _speech_settings_using_endpoint, _clear_speech_settings_for_endpoint, + _endpoint_settings_using_endpoint, + _clear_endpoint_settings_for_endpoint, + _clear_user_pref_endpoint_refs, _PROVIDER_CURATED, ) from src.llm_core import ANTHROPIC_MODELS @@ -67,6 +70,74 @@ def test_clear_speech_endpoint_settings_resets_tts_and_stt(): } +def test_endpoint_cleanup_removes_primary_and_fallback_references(): + settings = { + "default_endpoint_id": "dead", + "default_model": "primary", + "default_model_fallbacks": [ + {"endpoint_id": "dead", "model": "fallback-a"}, + {"endpoint_id": "keep", "model": "fallback-b"}, + ], + "utility_model_fallbacks": [{"endpoint_id": "dead", "model": "utility"}], + "vision_model_fallbacks": [{"endpoint_id": "dead", "model": "vision"}], + "stt_provider": "endpoint:dead", + "stt_model": "whisper", + } + + assert _endpoint_settings_using_endpoint(settings, "dead", include_speech=True) == [ + "Default Model", + "Default Model Fallbacks", + "Utility Model Fallbacks", + "Vision Model Fallbacks", + "Speech to Text", + ] + assert _clear_endpoint_settings_for_endpoint(settings, "dead", include_speech=True) == [ + "Default Model", + "Default Model Fallbacks", + "Utility Model Fallbacks", + "Vision Model Fallbacks", + "Speech to Text", + ] + assert settings["default_endpoint_id"] == "" + assert settings["default_model"] == "" + assert settings["default_model_fallbacks"] == [ + {"endpoint_id": "keep", "model": "fallback-b"}, + ] + assert settings["utility_model_fallbacks"] == [] + assert settings["vision_model_fallbacks"] == [] + assert settings["stt_provider"] == "disabled" + assert settings["stt_model"] == "base" + + +def test_endpoint_cleanup_updates_scoped_and_legacy_user_prefs(): + scoped = { + "_users": { + "alice": { + "utility_endpoint_id": "dead", + "utility_model": "utility", + "vision_model_fallbacks": [{"endpoint_id": "dead", "model": "vision"}], + }, + "bob": { + "default_endpoint_id": "keep", + "default_model": "chat", + }, + }, + } + assert _clear_user_pref_endpoint_refs(scoped, "dead") == 1 + assert scoped["_users"]["alice"] == { + "utility_endpoint_id": "", + "utility_model": "", + "vision_model_fallbacks": [], + } + assert scoped["_users"]["bob"]["default_endpoint_id"] == "keep" + + legacy = { + "default_model_fallbacks": [{"endpoint_id": "dead", "model": "chat"}], + } + assert _clear_user_pref_endpoint_refs(legacy, "dead") == 1 + assert legacy["default_model_fallbacks"] == [] + + # ── _match_provider_curated ── class TestMatchProviderCurated: