Chat: use cached endpoint model ids before probing
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
@@ -11,6 +12,7 @@ from core.models import ChatMessage
|
||||
from core.database import SessionLocal
|
||||
from core.database import Session as DBSession, ModelEndpoint
|
||||
from src.llm_core import normalize_model_id
|
||||
from src.endpoint_resolver import normalize_base
|
||||
from src.context_compactor import maybe_compact, trim_for_context
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.prompt_security import untrusted_context_message
|
||||
@@ -337,6 +339,63 @@ def resolve_session_auth(sess, session_id: str):
|
||||
logger.warning(f"Failed to resolve session headers: {e}")
|
||||
|
||||
|
||||
def _match_cached_model_id(requested: str, models) -> Optional[str]:
    """Return the cached model ID that corresponds to ``requested``, if any.

    An exact match wins. Otherwise the first cached ID whose final path
    component (text after the last ``/``, trailing slashes ignored) equals
    the final path component of ``requested`` is returned.

    Args:
        requested: Model ID to look up.
        models: Iterable of cached model IDs. Falsy entries are ignored and
            the rest are compared by their ``str()`` form. A bare string is
            treated as a single model ID.

    Returns:
        The matching ID exactly as it appears in the cache, or ``None`` when
        nothing matches or either argument is empty.
    """
    if not requested or not models:
        return None
    if isinstance(models, str):
        # A bare string would otherwise be iterated character by character,
        # letting a one-letter request "match" a letter of the cached ID.
        models = [models]

    model_ids = [str(m) for m in models if m]
    if requested in model_ids:
        return requested

    req_base = os.path.basename(requested.rstrip("/"))
    if not req_base:
        # ``requested`` consisted only of slashes; an empty basename must not
        # be allowed to match a similarly degenerate cached entry.
        return None
    for model_id in model_ids:
        if os.path.basename(model_id.rstrip("/")) == req_base:
            return model_id
    return None
|
||||
|
||||
|
||||
def _normalize_model_id_from_cache(sess) -> Optional[str]:
    """Use stored endpoint model IDs before falling back to a live /models probe.

    Reads ``sess.endpoint_url`` and ``sess.model``, then scans the enabled
    ``ModelEndpoint`` rows whose normalized ``base_url`` equals the session's
    normalized endpoint URL. For each such row the ``cached_models`` value
    (a JSON string or an already-decoded object) is searched with
    ``_match_cached_model_id``.

    Args:
        sess: Object exposing ``endpoint_url`` and ``model`` attributes;
            missing or empty attributes make the lookup a no-op.

    Returns:
        The first matching cached model ID, or ``None`` when nothing matches
        or the lookup fails. Failures are logged at debug level and never
        raised, so the caller can fall back to another strategy.
    """
    endpoint_url = getattr(sess, "endpoint_url", "") or ""
    requested = getattr(sess, "model", "") or ""
    if not endpoint_url or not requested:
        return None

    try:
        session_base = normalize_base(endpoint_url)
    except Exception:
        # Best effort: fall back to the raw URL without trailing slashes.
        session_base = endpoint_url.rstrip("/")
    if not session_base:
        return None

    db = SessionLocal()
    try:
        endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()  # noqa: E712
        for ep in endpoints:
            try:
                if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
                    continue
            except Exception:
                # An endpoint whose URL cannot be normalized cannot match.
                continue

            raw_models = getattr(ep, "cached_models", None)
            if not raw_models:
                continue

            try:
                models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models
                matched = _match_cached_model_id(requested, models)
            except Exception:
                # A malformed cache (invalid JSON, or a decoded value that is
                # not iterable) only disqualifies this endpoint; keep
                # scanning the remaining ones instead of aborting the lookup.
                continue

            if matched:
                return matched
    except Exception as e:
        logger.debug("Cached model normalization skipped: %s", e)
    finally:
        db.close()

    return None
|
||||
|
||||
|
||||
async def build_chat_context(
|
||||
sess,
|
||||
request,
|
||||
@@ -437,8 +496,9 @@ async def build_chat_context(
|
||||
for transcript in preprocessed.youtube_transcripts:
|
||||
preface.append(untrusted_context_message("youtube transcript", transcript))
|
||||
|
||||
# Normalize model ID
|
||||
norm = normalize_model_id(sess.endpoint_url, sess.model)
|
||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||
# re-hit slow local /models endpoints on every participant turn.
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
|
||||
20
tests/test_chat_cached_model_normalization.py
Normal file
20
tests/test_chat_cached_model_normalization.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def test_chat_context_uses_cached_models_before_live_model_probe():
    """Check that chat_helpers consults cached endpoint models first.

    This is a source-text check: it reads ``routes/chat_helpers.py`` and
    asserts that the cache helper exists and is tried before
    ``normalize_model_id`` in the ``norm = ...`` assignment.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")

    assert "def _normalize_model_id_from_cache" in source
    assert "cached_models" in source
    assert "norm = _normalize_model_id_from_cache(sess) or normalize_model_id" in source
|
||||
|
||||
|
||||
def test_cached_model_match_keeps_basename_normalization():
    """Check that cached-model matching still compares path basenames.

    This is a source-text check: it reads ``routes/chat_helpers.py`` and
    asserts that ``_match_cached_model_id`` exists and contains both
    basename expressions verbatim.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    source = (ROOT / "routes" / "chat_helpers.py").read_text(encoding="utf-8")

    assert "def _match_cached_model_id" in source
    assert "os.path.basename(requested.rstrip(\"/\"))" in source
    assert "os.path.basename(model_id.rstrip(\"/\")) == req_base" in source
|
||||
Reference in New Issue
Block a user