Use LM Studio-reported vision capability for image passthrough (#1130)

Read a model's capabilities.vision flag from LM Studio's native /api/v1/models so vision finetunes whose names lack a vision keyword still receive images, falling back to the name heuristic when the endpoint doesn't report it. The probe is short-TTL cached and restricted to local/LAN hosts, so remote/cloud endpoints are never contacted.
This commit is contained in:
RosenTomov
2026-06-02 17:01:04 +03:00
committed by GitHub
parent 18a445ba22
commit a493fb49b0
3 changed files with 203 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ from src.constants import (
UPLOAD_DIR,
)
from core.models import ChatMessage
from src.chat_helpers import extract_urls, is_vision_model
from src.chat_helpers import extract_urls, model_supports_vision
from src.document_processor import build_user_content, analyze_image_with_vl_result
from src.youtube_handler import (
is_youtube_url,
@@ -146,7 +146,9 @@ class ChatHandler:
# Analyze images — skip if vision disabled, or if main model is vision-capable
from src.settings import get_setting
vision_enabled = get_setting("vision_enabled", True)
main_is_vision = is_vision_model(sess.model or "")
main_is_vision = await asyncio.to_thread(
model_supports_vision, sess.model or "", getattr(sess, "endpoint_url", "") or ""
)
# Resolve uploads once with the session owner. Attachment IDs are
# bearer-like references; never trust them without an owner check.

View File

@@ -4,10 +4,14 @@
import re
import os
import json
import time
import ipaddress
import logging
import httpx
from urllib.parse import urlparse
from fastapi import HTTPException
from fastapi import UploadFile
from typing import List
from typing import List, Optional
logger = logging.getLogger(__name__)
@@ -55,6 +59,96 @@ def is_vision_model(model_name: str) -> bool:
return bool(_VISION_VL_RE.search(m))
_PROVIDER_FINGERPRINT_TTL = 60.0
# (host, port) -> (models_list | None, expiry); list = LM Studio, None = not LM Studio.
_lmstudio_models_cache: dict = {}
def _is_local_host(host: Optional[str]) -> bool:
"""True for loopback/LAN/Tailscale hosts (never public domains)."""
host = (host or "").lower()
if not host:
return False
if host in {"localhost", "host.docker.internal"} or host.endswith(".local"):
return True
try:
ip = ipaddress.ip_address(host)
except ValueError:
return "." not in host
if ip.is_loopback or ip.is_private or ip.is_link_local:
return True
return ip in ipaddress.ip_network("100.64.0.0/10")
def _probe_lmstudio_models(url: str) -> Optional[list]:
"""Return LM Studio's native /api/v1/models list, or None when the endpoint
isn't LM Studio or is unreachable (short-TTL cached; transient errors uncached)."""
parsed = urlparse(url)
host = parsed.hostname or ""
key = (host, parsed.port)
now = time.time()
cached = _lmstudio_models_cache.get(key)
if cached is not None and cached[1] > now:
return cached[0]
authority = host if parsed.port is None else f"{host}:{parsed.port}"
probe_url = f"{parsed.scheme or 'http'}://{authority}/api/v1/models"
try:
r = httpx.get(probe_url, timeout=1.0)
except Exception:
return None
try:
data = r.json() if r.is_success else {}
except Exception:
data = {}
models = data.get("models")
valid = (
isinstance(models, list) and bool(models)
and isinstance(models[0], dict)
and "key" in models[0] and "architecture" in models[0]
)
models = models if valid else None
_lmstudio_models_cache[key] = (models, now + _PROVIDER_FINGERPRINT_TTL)
return models
def lmstudio_supports_vision(url: str, model: str) -> Optional[bool]:
"""Read `model`'s capabilities.vision flag from LM Studio, or None when the
endpoint isn't LM Studio or doesn't report it (so callers fall back)."""
if not model:
return None
# Never probe a remote provider; LM Studio is always a local/LAN host.
if not _is_local_host(urlparse(url).hostname):
return None
models = _probe_lmstudio_models(url)
if not models:
return None
want = model.strip().lower()
for m in models:
if not isinstance(m, dict):
continue
names = {str(m.get("key", "")).lower(), str(m.get("display_name", "")).lower()}
if want in names:
caps = m.get("capabilities")
if isinstance(caps, dict) and "vision" in caps:
return bool(caps.get("vision"))
return None
return None
def model_supports_vision(model_name: str, endpoint_url: str = "") -> bool:
"""Whether a model accepts images, using the endpoint's reported
capability when available (LM Studio) and falling back to name-based
detection otherwise."""
if endpoint_url:
try:
advertised = lmstudio_supports_vision(endpoint_url, model_name or "")
except Exception:
advertised = None
if advertised is not None:
return advertised
return is_vision_model(model_name)
def validate_message(message: str) -> str:
"""Validate message input."""
if not message: