Chat: route image sessions only to matching image endpoints

Co-authored-by: ghreprimand <203024559+ghreprimand@users.noreply.github.com>
This commit is contained in:
ghreprimand
2026-06-02 06:52:03 -05:00
committed by GitHub
parent 064c1ace91
commit 4cec31d988
2 changed files with 133 additions and 22 deletions

View File

@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
# Track active streams for partial-save safety net
_active_streams: Dict[str, dict] = {}
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
def _stream_set(session_id: str, **fields) -> None:
@@ -98,6 +99,59 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
db.close()
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
"""Return True when a populated endpoint model cache includes ``model``.
Empty/malformed caches are treated as unknown rather than a negative match
so older image endpoints without cached models still work.
"""
raw = getattr(endpoint, "cached_models", None)
if not raw:
return True
try:
models = json.loads(raw) if isinstance(raw, str) else raw
except Exception:
return True
if not isinstance(models, list) or not models:
return True
wanted = (model or "").strip()
return wanted in {str(item).strip() for item in models}
def _is_image_generation_session(sess) -> bool:
"""Whether this chat session should bypass text chat and generate images.
Model-name prefixes are explicit image models. Endpoint type is only used
when the current session endpoint actually matches that image endpoint, and
when a populated endpoint model cache includes the selected model. This
prevents an image endpoint on the same host from misrouting ordinary text
models into the image-generation path.
"""
model = (getattr(sess, "model", "") or "").strip()
if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES):
return True
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
if not endpoint_url:
return False
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for endpoint in endpoints:
if (getattr(endpoint, "model_type", None) or "llm") != "image":
continue
if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""):
continue
if _endpoint_cache_contains_model(endpoint, model):
return True
except Exception:
return False
finally:
db.close()
return False
def _recover_empty_session_model(sess, session_id: str) -> bool:
"""Re-populate sess.model from the matching endpoint's cached models.
@@ -726,28 +780,7 @@ def setup_chat_routes(
_model_info["character_name"] = ctx.preset.character_name
yield f'data: {json.dumps(_model_info)}\n\n'
# Detect image models and route directly to image generation
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
# Also check if the endpoint is registered as an image-type endpoint
if not _is_image_model:
try:
from src.endpoint_resolver import normalize_base as _nb
_ep_base = _nb(sess.endpoint_url)
_db = SessionLocal()
try:
_is_image_model = _db.query(ModelEndpoint).filter(
ModelEndpoint.model_type == "image",
ModelEndpoint.is_enabled == True,
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
).first() is not None
finally:
_db.close()
except Exception:
pass
if _is_image_model:
if _is_image_generation_session(sess):
from src.settings import get_setting
if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'

View File

@@ -0,0 +1,78 @@
import json
from types import SimpleNamespace
from routes import chat_routes
class _FakeQuery:
def __init__(self, rows):
self.rows = rows
def filter(self, *conditions):
return self
def all(self):
return list(self.rows)
class _FakeDb:
def __init__(self, rows):
self.rows = rows
self.closed = False
def query(self, model):
return _FakeQuery(self.rows)
def close(self):
self.closed = True
def _session(model="qwen3.5:latest", endpoint_url="http://localhost:11434/v1/chat/completions"):
return SimpleNamespace(model=model, endpoint_url=endpoint_url)
def _endpoint(base_url, model_type="image", models=None):
cached_models = None if models is None else json.dumps(models)
return SimpleNamespace(
base_url=base_url,
model_type=model_type,
is_enabled=True,
cached_models=cached_models,
)
def test_image_model_prefix_routes_to_image_generation_without_endpoint_lookup(monkeypatch):
def fail_if_called():
raise AssertionError("prefixed image models should not need a DB lookup")
monkeypatch.setattr(chat_routes, "SessionLocal", fail_if_called)
assert chat_routes._is_image_generation_session(_session(model="dall-e-3"))
def test_image_endpoint_does_not_catch_text_model_on_different_path(monkeypatch):
db = _FakeDb([
_endpoint("http://localhost:11434/v1/images", models=["sdxl-local"]),
])
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
assert not chat_routes._is_image_generation_session(_session())
assert db.closed
def test_image_endpoint_cache_must_contain_selected_model(monkeypatch):
db = _FakeDb([
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
])
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
assert not chat_routes._is_image_generation_session(_session(model="qwen3.5:latest"))
def test_matching_image_endpoint_routes_selected_image_model(monkeypatch):
db = _FakeDb([
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
])
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
assert chat_routes._is_image_generation_session(_session(model="sdxl-local"))