Chat: use cached endpoint model ids before probing

This commit is contained in:
red person
2026-06-02 15:00:58 +03:00
committed by GitHub
parent 5029c8570e
commit fd89d098a1
2 changed files with 82 additions and 2 deletions

View File

@@ -3,6 +3,7 @@
import asyncio
import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Any, Optional
@@ -11,6 +12,7 @@ from core.models import ChatMessage
from core.database import SessionLocal
from core.database import Session as DBSession, ModelEndpoint
from src.llm_core import normalize_model_id
from src.endpoint_resolver import normalize_base
from src.context_compactor import maybe_compact, trim_for_context
from src.auth_helpers import get_current_user
from src.prompt_security import untrusted_context_message
@@ -337,6 +339,63 @@ def resolve_session_auth(sess, session_id: str):
logger.warning(f"Failed to resolve session headers: {e}")
def _match_cached_model_id(requested: str, models) -> Optional[str]:
if not requested or not models:
return None
model_ids = [str(m) for m in models if m]
if requested in model_ids:
return requested
req_base = os.path.basename(requested.rstrip("/"))
for model_id in model_ids:
if os.path.basename(model_id.rstrip("/")) == req_base:
return model_id
return None
def _normalize_model_id_from_cache(sess) -> Optional[str]:
"""Use stored endpoint model IDs before falling back to a live /models probe."""
endpoint_url = getattr(sess, "endpoint_url", "") or ""
requested = getattr(sess, "model", "") or ""
if not endpoint_url or not requested:
return None
try:
session_base = normalize_base(endpoint_url)
except Exception:
session_base = endpoint_url.rstrip("/")
if not session_base:
return None
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for ep in endpoints:
try:
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
continue
except Exception:
continue
raw_models = getattr(ep, "cached_models", None)
if not raw_models:
continue
try:
models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models
except Exception:
continue
matched = _match_cached_model_id(requested, models)
if matched:
return matched
except Exception as e:
logger.debug("Cached model normalization skipped: %s", e)
finally:
db.close()
return None
async def build_chat_context(
sess,
request,
@@ -437,8 +496,9 @@ async def build_chat_context(
for transcript in preprocessed.youtube_transcripts:
preface.append(untrusted_context_message("youtube transcript", transcript))
# Normalize model ID
norm = normalize_model_id(sess.endpoint_url, sess.model)
# Normalize model ID. Prefer cached endpoint models so group chat does not
# re-hit slow local /models endpoints on every participant turn.
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
if norm:
sess.model = norm

View File

@@ -0,0 +1,20 @@
from pathlib import Path
# Directory one level above this file's folder (parents[1] of the resolved
# path); the tests below locate routes/chat_helpers.py relative to it.
ROOT = Path(__file__).resolve().parents[1]
def test_chat_context_uses_cached_models_before_live_model_probe():
    """Check that chat_helpers.py tries the cached-model helper before the live probe.

    This is a source-text check: it reads routes/chat_helpers.py and looks
    for the helper definition, the ``cached_models`` attribute name, and the
    exact ``cache or live probe`` fallback expression.
    """
    # Explicit encoding so the read does not depend on the platform locale.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")
    assert "def _normalize_model_id_from_cache" in source
    assert "cached_models" in source
    assert "norm = _normalize_model_id_from_cache(sess) or normalize_model_id" in source
def test_cached_model_match_keeps_basename_normalization():
    """Check that the cached-model matcher still compares IDs by basename.

    This is a source-text check: it reads routes/chat_helpers.py and looks
    for the matcher definition and the two ``os.path.basename`` expressions
    it relies on.
    """
    # Explicit encoding so the read does not depend on the platform locale.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")
    assert "def _match_cached_model_id" in source
    assert "os.path.basename(requested.rstrip(\"/\"))" in source
    assert "os.path.basename(model_id.rstrip(\"/\")) == req_base" in source