Use LM Studio-reported vision capability for image passthrough (#1130)
Read a model's capabilities.vision flag from LM Studio's native /api/v1/models so vision finetunes whose names lack a vision keyword still receive images, falling back to the name heuristic when the endpoint doesn't report it. The probe is short-TTL cached and restricted to local/LAN hosts, so remote/cloud endpoints are never contacted.
This commit is contained in:
@@ -14,7 +14,7 @@ from src.constants import (
|
|||||||
UPLOAD_DIR,
|
UPLOAD_DIR,
|
||||||
)
|
)
|
||||||
from core.models import ChatMessage
|
from core.models import ChatMessage
|
||||||
from src.chat_helpers import extract_urls, is_vision_model
|
from src.chat_helpers import extract_urls, model_supports_vision
|
||||||
from src.document_processor import build_user_content, analyze_image_with_vl_result
|
from src.document_processor import build_user_content, analyze_image_with_vl_result
|
||||||
from src.youtube_handler import (
|
from src.youtube_handler import (
|
||||||
is_youtube_url,
|
is_youtube_url,
|
||||||
@@ -146,7 +146,9 @@ class ChatHandler:
|
|||||||
# Analyze images — skip if vision disabled, or if main model is vision-capable
|
# Analyze images — skip if vision disabled, or if main model is vision-capable
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
vision_enabled = get_setting("vision_enabled", True)
|
vision_enabled = get_setting("vision_enabled", True)
|
||||||
main_is_vision = is_vision_model(sess.model or "")
|
main_is_vision = await asyncio.to_thread(
|
||||||
|
model_supports_vision, sess.model or "", getattr(sess, "endpoint_url", "") or ""
|
||||||
|
)
|
||||||
|
|
||||||
# Resolve uploads once with the session owner. Attachment IDs are
|
# Resolve uploads once with the session owner. Attachment IDs are
|
||||||
# bearer-like references; never trust them without an owner check.
|
# bearer-like references; never trust them without an owner check.
|
||||||
|
|||||||
@@ -4,10 +4,14 @@
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
|
import httpx
|
||||||
|
from urllib.parse import urlparse
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -55,6 +59,96 @@ def is_vision_model(model_name: str) -> bool:
|
|||||||
return bool(_VISION_VL_RE.search(m))
|
return bool(_VISION_VL_RE.search(m))
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDER_FINGERPRINT_TTL = 60.0
|
||||||
|
# (host, port) -> (models_list | None, expiry); list = LM Studio, None = not LM Studio.
|
||||||
|
_lmstudio_models_cache: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_local_host(host: Optional[str]) -> bool:
|
||||||
|
"""True for loopback/LAN/Tailscale hosts (never public domains)."""
|
||||||
|
host = (host or "").lower()
|
||||||
|
if not host:
|
||||||
|
return False
|
||||||
|
if host in {"localhost", "host.docker.internal"} or host.endswith(".local"):
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
return "." not in host
|
||||||
|
if ip.is_loopback or ip.is_private or ip.is_link_local:
|
||||||
|
return True
|
||||||
|
return ip in ipaddress.ip_network("100.64.0.0/10")
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_lmstudio_models(url: str) -> Optional[list]:
|
||||||
|
"""Return LM Studio's native /api/v1/models list, or None when the endpoint
|
||||||
|
isn't LM Studio or is unreachable (short-TTL cached; transient errors uncached)."""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname or ""
|
||||||
|
key = (host, parsed.port)
|
||||||
|
now = time.time()
|
||||||
|
cached = _lmstudio_models_cache.get(key)
|
||||||
|
if cached is not None and cached[1] > now:
|
||||||
|
return cached[0]
|
||||||
|
authority = host if parsed.port is None else f"{host}:{parsed.port}"
|
||||||
|
probe_url = f"{parsed.scheme or 'http'}://{authority}/api/v1/models"
|
||||||
|
try:
|
||||||
|
r = httpx.get(probe_url, timeout=1.0)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
data = r.json() if r.is_success else {}
|
||||||
|
except Exception:
|
||||||
|
data = {}
|
||||||
|
models = data.get("models")
|
||||||
|
valid = (
|
||||||
|
isinstance(models, list) and bool(models)
|
||||||
|
and isinstance(models[0], dict)
|
||||||
|
and "key" in models[0] and "architecture" in models[0]
|
||||||
|
)
|
||||||
|
models = models if valid else None
|
||||||
|
_lmstudio_models_cache[key] = (models, now + _PROVIDER_FINGERPRINT_TTL)
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def lmstudio_supports_vision(url: str, model: str) -> Optional[bool]:
|
||||||
|
"""Read `model`'s capabilities.vision flag from LM Studio, or None when the
|
||||||
|
endpoint isn't LM Studio or doesn't report it (so callers fall back)."""
|
||||||
|
if not model:
|
||||||
|
return None
|
||||||
|
# Never probe a remote provider; LM Studio is always a local/LAN host.
|
||||||
|
if not _is_local_host(urlparse(url).hostname):
|
||||||
|
return None
|
||||||
|
models = _probe_lmstudio_models(url)
|
||||||
|
if not models:
|
||||||
|
return None
|
||||||
|
want = model.strip().lower()
|
||||||
|
for m in models:
|
||||||
|
if not isinstance(m, dict):
|
||||||
|
continue
|
||||||
|
names = {str(m.get("key", "")).lower(), str(m.get("display_name", "")).lower()}
|
||||||
|
if want in names:
|
||||||
|
caps = m.get("capabilities")
|
||||||
|
if isinstance(caps, dict) and "vision" in caps:
|
||||||
|
return bool(caps.get("vision"))
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def model_supports_vision(model_name: str, endpoint_url: str = "") -> bool:
|
||||||
|
"""Whether a model accepts images, using the endpoint's reported
|
||||||
|
capability when available (LM Studio) and falling back to name-based
|
||||||
|
detection otherwise."""
|
||||||
|
if endpoint_url:
|
||||||
|
try:
|
||||||
|
advertised = lmstudio_supports_vision(endpoint_url, model_name or "")
|
||||||
|
except Exception:
|
||||||
|
advertised = None
|
||||||
|
if advertised is not None:
|
||||||
|
return advertised
|
||||||
|
return is_vision_model(model_name)
|
||||||
|
|
||||||
|
|
||||||
def validate_message(message: str) -> str:
|
def validate_message(message: str) -> str:
|
||||||
"""Validate message input."""
|
"""Validate message input."""
|
||||||
if not message:
|
if not message:
|
||||||
|
|||||||
104
tests/test_lmstudio_vision.py
Normal file
104
tests/test_lmstudio_vision.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Tests for LM Studio vision-capability passthrough: reading capabilities.vision
|
||||||
|
from the native /api/v1/models endpoint, with no probing of cloud providers."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src import chat_helpers
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, payload, ok=True):
|
||||||
|
self._payload = payload
|
||||||
|
self.is_success = ok
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
# lmstudio_supports_vision — reads capabilities.vision
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestLmStudioSupportsVision:
|
||||||
|
# A vision finetune whose NAME has no vision keyword — the case the
|
||||||
|
# name-based heuristic gets wrong (the issue this fixes).
|
||||||
|
PAYLOAD = {"models": [
|
||||||
|
{"key": "qwen3.6-27b-custom-finetune", "architecture": "qwen35",
|
||||||
|
"capabilities": {"vision": True, "trained_for_tool_use": True}},
|
||||||
|
{"key": "text-only-llm", "architecture": "qwen35",
|
||||||
|
"capabilities": {"vision": False}},
|
||||||
|
{"key": "no-caps-model", "architecture": "qwen35"},
|
||||||
|
]}
|
||||||
|
URL = "http://localhost:1234/v1/chat/completions"
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_cache(self):
|
||||||
|
chat_helpers._lmstudio_models_cache.clear()
|
||||||
|
yield
|
||||||
|
chat_helpers._lmstudio_models_cache.clear()
|
||||||
|
|
||||||
|
def _serve(self, monkeypatch, payload):
|
||||||
|
monkeypatch.setattr(chat_helpers.httpx, "get",
|
||||||
|
lambda url, timeout=None: _FakeResponse(payload))
|
||||||
|
|
||||||
|
def test_vision_true_from_capabilities(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, self.PAYLOAD)
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "qwen3.6-27b-custom-finetune") is True
|
||||||
|
|
||||||
|
def test_vision_false_from_capabilities(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, self.PAYLOAD)
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "text-only-llm") is False
|
||||||
|
|
||||||
|
def test_model_without_capabilities_returns_none(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, self.PAYLOAD)
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "no-caps-model") is None
|
||||||
|
|
||||||
|
def test_unknown_model_returns_none(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, self.PAYLOAD)
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "not-listed") is None
|
||||||
|
|
||||||
|
def test_non_lmstudio_endpoint_returns_none(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, {"data": [{"id": "gpt-4o"}]})
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "gpt-4o") is None
|
||||||
|
|
||||||
|
def test_empty_model_returns_none(self, monkeypatch):
|
||||||
|
self._serve(monkeypatch, self.PAYLOAD)
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(self.URL, "") is None
|
||||||
|
|
||||||
|
def test_remote_endpoint_never_probed(self, monkeypatch):
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def tracking_get(url, timeout=None):
|
||||||
|
calls["n"] += 1
|
||||||
|
return _FakeResponse(self.PAYLOAD)
|
||||||
|
|
||||||
|
monkeypatch.setattr(chat_helpers.httpx, "get", tracking_get)
|
||||||
|
# A cloud provider host must short-circuit to None with no network probe.
|
||||||
|
assert chat_helpers.lmstudio_supports_vision(
|
||||||
|
"https://api.openai.com/v1/chat/completions", "gpt-4o") is None
|
||||||
|
assert calls["n"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
# model_supports_vision — endpoint capability wins, name is fallback
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestModelSupportsVision:
|
||||||
|
"""Endpoint-aware vision check: API capability wins, name heuristic is the fallback."""
|
||||||
|
|
||||||
|
def test_api_capability_overrides_name_heuristic(self, monkeypatch):
|
||||||
|
# Name has no vision keyword, but the endpoint advertises vision=True.
|
||||||
|
monkeypatch.setattr(chat_helpers, "is_vision_model", lambda n: False)
|
||||||
|
monkeypatch.setattr(chat_helpers, "lmstudio_supports_vision", lambda url, m: True)
|
||||||
|
assert chat_helpers.model_supports_vision("qwen3.6-27b-finetune",
|
||||||
|
"http://localhost:1234/v1/chat/completions") is True
|
||||||
|
|
||||||
|
def test_falls_back_to_name_when_no_endpoint(self):
|
||||||
|
# No endpoint URL → pure name heuristic.
|
||||||
|
assert chat_helpers.model_supports_vision("llava-1.6", "") is True
|
||||||
|
assert chat_helpers.model_supports_vision("mistral-7b", "") is False
|
||||||
|
|
||||||
|
def test_falls_back_to_name_when_endpoint_unknown(self, monkeypatch):
|
||||||
|
# Endpoint doesn't advertise (None) → name heuristic decides.
|
||||||
|
monkeypatch.setattr(chat_helpers, "lmstudio_supports_vision", lambda url, m: None)
|
||||||
|
assert chat_helpers.model_supports_vision("qwen2-vl-7b", "http://host/v1") is True
|
||||||
|
assert chat_helpers.model_supports_vision("plain-llm", "http://host/v1") is False
|
||||||
Reference in New Issue
Block a user