Chat: use cached endpoint model ids before probing
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -11,6 +12,7 @@ from core.models import ChatMessage
|
|||||||
from core.database import SessionLocal
|
from core.database import SessionLocal
|
||||||
from core.database import Session as DBSession, ModelEndpoint
|
from core.database import Session as DBSession, ModelEndpoint
|
||||||
from src.llm_core import normalize_model_id
|
from src.llm_core import normalize_model_id
|
||||||
|
from src.endpoint_resolver import normalize_base
|
||||||
from src.context_compactor import maybe_compact, trim_for_context
|
from src.context_compactor import maybe_compact, trim_for_context
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
from src.prompt_security import untrusted_context_message
|
from src.prompt_security import untrusted_context_message
|
||||||
@@ -337,6 +339,63 @@ def resolve_session_auth(sess, session_id: str):
|
|||||||
logger.warning(f"Failed to resolve session headers: {e}")
|
logger.warning(f"Failed to resolve session headers: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _match_cached_model_id(requested: str, models) -> Optional[str]:
|
||||||
|
if not requested or not models:
|
||||||
|
return None
|
||||||
|
model_ids = [str(m) for m in models if m]
|
||||||
|
if requested in model_ids:
|
||||||
|
return requested
|
||||||
|
|
||||||
|
req_base = os.path.basename(requested.rstrip("/"))
|
||||||
|
for model_id in model_ids:
|
||||||
|
if os.path.basename(model_id.rstrip("/")) == req_base:
|
||||||
|
return model_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
||||||
|
"""Use stored endpoint model IDs before falling back to a live /models probe."""
|
||||||
|
endpoint_url = getattr(sess, "endpoint_url", "") or ""
|
||||||
|
requested = getattr(sess, "model", "") or ""
|
||||||
|
if not endpoint_url or not requested:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
session_base = normalize_base(endpoint_url)
|
||||||
|
except Exception:
|
||||||
|
session_base = endpoint_url.rstrip("/")
|
||||||
|
if not session_base:
|
||||||
|
return None
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||||
|
for ep in endpoints:
|
||||||
|
try:
|
||||||
|
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_models = getattr(ep, "cached_models", None)
|
||||||
|
if not raw_models:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matched = _match_cached_model_id(requested, models)
|
||||||
|
if matched:
|
||||||
|
return matched
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Cached model normalization skipped: %s", e)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def build_chat_context(
|
async def build_chat_context(
|
||||||
sess,
|
sess,
|
||||||
request,
|
request,
|
||||||
@@ -437,8 +496,9 @@ async def build_chat_context(
|
|||||||
for transcript in preprocessed.youtube_transcripts:
|
for transcript in preprocessed.youtube_transcripts:
|
||||||
preface.append(untrusted_context_message("youtube transcript", transcript))
|
preface.append(untrusted_context_message("youtube transcript", transcript))
|
||||||
|
|
||||||
# Normalize model ID
|
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||||
norm = normalize_model_id(sess.endpoint_url, sess.model)
|
# re-hit slow local /models endpoints on every participant turn.
|
||||||
|
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||||
if norm:
|
if norm:
|
||||||
sess.model = norm
|
sess.model = norm
|
||||||
|
|
||||||
|
|||||||
20
tests/test_chat_cached_model_normalization.py
Normal file
20
tests/test_chat_cached_model_normalization.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_context_uses_cached_models_before_live_model_probe():
|
||||||
|
source = (ROOT / "routes" / "chat_helpers.py").read_text()
|
||||||
|
|
||||||
|
assert "def _normalize_model_id_from_cache" in source
|
||||||
|
assert "cached_models" in source
|
||||||
|
assert "norm = _normalize_model_id_from_cache(sess) or normalize_model_id" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_cached_model_match_keeps_basename_normalization():
|
||||||
|
source = (ROOT / "routes" / "chat_helpers.py").read_text()
|
||||||
|
|
||||||
|
assert "def _match_cached_model_id" in source
|
||||||
|
assert "os.path.basename(requested.rstrip(\"/\"))" in source
|
||||||
|
assert "os.path.basename(model_id.rstrip(\"/\")) == req_base" in source
|
||||||
Reference in New Issue
Block a user