Chat: use cached endpoint model ids before probing

This commit is contained in:
red person
2026-06-02 15:00:58 +03:00
committed by GitHub
parent 5029c8570e
commit fd89d098a1
2 changed files with 82 additions and 2 deletions

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional from typing import Any, Optional
@@ -11,6 +12,7 @@ from core.models import ChatMessage
from core.database import SessionLocal from core.database import SessionLocal
from core.database import Session as DBSession, ModelEndpoint from core.database import Session as DBSession, ModelEndpoint
from src.llm_core import normalize_model_id from src.llm_core import normalize_model_id
from src.endpoint_resolver import normalize_base
from src.context_compactor import maybe_compact, trim_for_context from src.context_compactor import maybe_compact, trim_for_context
from src.auth_helpers import get_current_user from src.auth_helpers import get_current_user
from src.prompt_security import untrusted_context_message from src.prompt_security import untrusted_context_message
@@ -337,6 +339,63 @@ def resolve_session_auth(sess, session_id: str):
logger.warning(f"Failed to resolve session headers: {e}") logger.warning(f"Failed to resolve session headers: {e}")
def _match_cached_model_id(requested: str, models) -> Optional[str]:
    """Return the cached model ID that corresponds to *requested*, if any.

    An exact string match is preferred. Failing that, the first cached ID
    whose ``os.path.basename`` (after stripping trailing slashes) equals the
    basename of *requested* is returned.

    Args:
        requested: Model ID the caller asked for.
        models: Iterable of cached model IDs; falsy entries are ignored and
            the rest are compared by their ``str()`` form.

    Returns:
        The matching ID, or ``None`` when either argument is empty or nothing
        matches.
    """
    if not (requested and models):
        return None

    model_ids = [str(entry) for entry in models if entry]
    if requested in model_ids:
        return requested

    # Fall back to comparing only the last path component of each ID.
    req_base = os.path.basename(requested.rstrip("/"))
    return next(
        (
            model_id
            for model_id in model_ids
            if os.path.basename(model_id.rstrip("/")) == req_base
        ),
        None,
    )
def _normalize_model_id_from_cache(sess) -> Optional[str]:
    """Look up the session's model ID in stored endpoint model lists.

    Reads ``endpoint_url`` and ``model`` from *sess*, then scans the enabled
    ``ModelEndpoint`` rows whose normalized ``base_url`` equals the session's
    normalized endpoint URL. The first row whose ``cached_models`` produces a
    match via ``_match_cached_model_id`` decides the result.

    Returns:
        The matched cached model ID, or ``None`` when the session has no
        endpoint URL or model, when no cached entry matches, or when an
        error occurs while querying or matching (logged at debug level).
    """
    endpoint_url = getattr(sess, "endpoint_url", "") or ""
    requested = getattr(sess, "model", "") or ""
    if not (endpoint_url and requested):
        return None

    try:
        session_base = normalize_base(endpoint_url)
    except Exception:
        # Normalization failed; compare on the raw URL minus trailing slashes.
        session_base = endpoint_url.rstrip("/")
    if not session_base:
        return None

    db = SessionLocal()
    try:
        # `== True` builds a query filter expression here, not a boolean test.
        enabled_endpoints = (
            db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
        )
        for endpoint in enabled_endpoints:
            try:
                differs = (
                    normalize_base(getattr(endpoint, "base_url", "") or "")
                    != session_base
                )
            except Exception:
                continue
            if differs:
                continue

            cached = getattr(endpoint, "cached_models", None)
            if not cached:
                continue
            if isinstance(cached, str):
                # Stored as a JSON string; skip this endpoint if it is unparsable.
                try:
                    cached = json.loads(cached)
                except Exception:
                    continue

            hit = _match_cached_model_id(requested, cached)
            if hit:
                return hit
    except Exception as e:
        logger.debug("Cached model normalization skipped: %s", e)
    finally:
        db.close()
    return None
async def build_chat_context( async def build_chat_context(
sess, sess,
request, request,
@@ -437,8 +496,9 @@ async def build_chat_context(
for transcript in preprocessed.youtube_transcripts: for transcript in preprocessed.youtube_transcripts:
preface.append(untrusted_context_message("youtube transcript", transcript)) preface.append(untrusted_context_message("youtube transcript", transcript))
# Normalize model ID # Normalize model ID. Prefer cached endpoint models so group chat does not
norm = normalize_model_id(sess.endpoint_url, sess.model) # re-hit slow local /models endpoints on every participant turn.
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
if norm: if norm:
sess.model = norm sess.model = norm

View File

@@ -0,0 +1,20 @@
from pathlib import Path
# Directory one level above the folder containing this test file; the tests
# below resolve source paths relative to it.
ROOT = Path(__file__).resolve().parents[1]
def test_chat_context_uses_cached_models_before_live_model_probe():
    """Check that chat_helpers.py consults cached models before the live probe.

    This is a source-text check: it reads ``routes/chat_helpers.py`` and
    asserts that the cache helper exists, that it references
    ``cached_models``, and that the call site tries the cache helper before
    ``normalize_model_id``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")
    assert "def _normalize_model_id_from_cache" in source
    assert "cached_models" in source
    assert "norm = _normalize_model_id_from_cache(sess) or normalize_model_id" in source
def test_cached_model_match_keeps_basename_normalization():
    """Check that the cached-model matcher still compares by basename.

    This is a source-text check: it reads ``routes/chat_helpers.py`` and
    asserts that ``_match_cached_model_id`` exists and contains both
    ``os.path.basename`` comparisons verbatim.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")
    assert "def _match_cached_model_id" in source
    assert "os.path.basename(requested.rstrip(\"/\"))" in source
    assert "os.path.basename(model_id.rstrip(\"/\")) == req_base" in source