From fd89d098a11b01def8ea3d4129310209fcf06e81 Mon Sep 17 00:00:00 2001 From: red person Date: Tue, 2 Jun 2026 15:00:58 +0300 Subject: [PATCH] Chat: use cached endpoint model ids before probing --- routes/chat_helpers.py | 64 ++++++++++++++++++- tests/test_chat_cached_model_normalization.py | 20 ++++++ 2 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 tests/test_chat_cached_model_normalization.py diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py index 06d886d..9b79ee0 100644 --- a/routes/chat_helpers.py +++ b/routes/chat_helpers.py @@ -3,6 +3,7 @@ import asyncio import json import logging +import os import re from dataclasses import dataclass, field from typing import Any, Optional @@ -11,6 +12,7 @@ from core.models import ChatMessage from core.database import SessionLocal from core.database import Session as DBSession, ModelEndpoint from src.llm_core import normalize_model_id +from src.endpoint_resolver import normalize_base from src.context_compactor import maybe_compact, trim_for_context from src.auth_helpers import get_current_user from src.prompt_security import untrusted_context_message @@ -337,6 +339,63 @@ def resolve_session_auth(sess, session_id: str): logger.warning(f"Failed to resolve session headers: {e}") +def _match_cached_model_id(requested: str, models) -> Optional[str]: + if not requested or not models: + return None + model_ids = [str(m) for m in models if m] + if requested in model_ids: + return requested + + req_base = os.path.basename(requested.rstrip("/")) + for model_id in model_ids: + if os.path.basename(model_id.rstrip("/")) == req_base: + return model_id + return None + + +def _normalize_model_id_from_cache(sess) -> Optional[str]: + """Use stored endpoint model IDs before falling back to a live /models probe.""" + endpoint_url = getattr(sess, "endpoint_url", "") or "" + requested = getattr(sess, "model", "") or "" + if not endpoint_url or not requested: + return None + + try: + session_base = normalize_base(endpoint_url) + except Exception: + session_base = endpoint_url.rstrip("/") + if not session_base: + return None + + db = SessionLocal() + try: + endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + for ep in endpoints: + try: + if normalize_base(getattr(ep, "base_url", "") or "") != session_base: + continue + except Exception: + continue + + raw_models = getattr(ep, "cached_models", None) + if not raw_models: + continue + try: + models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models + except Exception: + continue + + matched = _match_cached_model_id(requested, models) + if matched: + return matched + except Exception as e: + logger.debug("Cached model normalization skipped: %s", e) + finally: + db.close() + + return None + + async def build_chat_context( sess, request, @@ -437,8 +496,9 @@ async def build_chat_context( for transcript in preprocessed.youtube_transcripts: preface.append(untrusted_context_message("youtube transcript", transcript)) - # Normalize model ID - norm = normalize_model_id(sess.endpoint_url, sess.model) + # Normalize model ID. Prefer cached endpoint models so group chat does not + # re-hit slow local /models endpoints on every participant turn. + norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model) if norm: sess.model = norm diff --git a/tests/test_chat_cached_model_normalization.py b/tests/test_chat_cached_model_normalization.py new file mode 100644 index 0000000..b601f87 --- /dev/null +++ b/tests/test_chat_cached_model_normalization.py @@ -0,0 +1,20 @@ +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] + + +def test_chat_context_uses_cached_models_before_live_model_probe(): + source = (ROOT / "routes" / "chat_helpers.py").read_text() + + assert "def _normalize_model_id_from_cache" in source + assert "cached_models" in source + assert "norm = _normalize_model_id_from_cache(sess) or normalize_model_id" in source + + +def test_cached_model_match_keeps_basename_normalization(): + source = (ROOT / "routes" / "chat_helpers.py").read_text() + + assert "def _match_cached_model_id" in source + assert "os.path.basename(requested.rstrip(\"/\"))" in source + assert "os.path.basename(model_id.rstrip(\"/\")) == req_base" in source