Chat: route image sessions only to matching image endpoints
Co-authored-by: ghreprimand <203024559+ghreprimand@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Track active streams for partial-save safety net
|
||||
_active_streams: Dict[str, dict] = {}
|
||||
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
||||
|
||||
|
||||
def _stream_set(session_id: str, **fields) -> None:
|
||||
@@ -98,6 +99,59 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
db.close()
|
||||
|
||||
|
||||
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
||||
"""Return True when a populated endpoint model cache includes ``model``.
|
||||
|
||||
Empty/malformed caches are treated as unknown rather than a negative match
|
||||
so older image endpoints without cached models still work.
|
||||
"""
|
||||
raw = getattr(endpoint, "cached_models", None)
|
||||
if not raw:
|
||||
return True
|
||||
try:
|
||||
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
return True
|
||||
if not isinstance(models, list) or not models:
|
||||
return True
|
||||
wanted = (model or "").strip()
|
||||
return wanted in {str(item).strip() for item in models}
|
||||
|
||||
|
||||
def _is_image_generation_session(sess) -> bool:
|
||||
"""Whether this chat session should bypass text chat and generate images.
|
||||
|
||||
Model-name prefixes are explicit image models. Endpoint type is only used
|
||||
when the current session endpoint actually matches that image endpoint, and
|
||||
when a populated endpoint model cache includes the selected model. This
|
||||
prevents an image endpoint on the same host from misrouting ordinary text
|
||||
models into the image-generation path.
|
||||
"""
|
||||
model = (getattr(sess, "model", "") or "").strip()
|
||||
if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES):
|
||||
return True
|
||||
|
||||
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||
if not endpoint_url:
|
||||
return False
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for endpoint in endpoints:
|
||||
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
||||
continue
|
||||
if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""):
|
||||
continue
|
||||
if _endpoint_cache_contains_model(endpoint, model):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
return False
|
||||
|
||||
|
||||
def _recover_empty_session_model(sess, session_id: str) -> bool:
|
||||
"""Re-populate sess.model from the matching endpoint's cached models.
|
||||
|
||||
@@ -726,28 +780,7 @@ def setup_chat_routes(
|
||||
_model_info["character_name"] = ctx.preset.character_name
|
||||
yield f'data: {json.dumps(_model_info)}\n\n'
|
||||
|
||||
# Detect image models and route directly to image generation
|
||||
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
||||
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
|
||||
|
||||
# Also check if the endpoint is registered as an image-type endpoint
|
||||
if not _is_image_model:
|
||||
try:
|
||||
from src.endpoint_resolver import normalize_base as _nb
|
||||
_ep_base = _nb(sess.endpoint_url)
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
_is_image_model = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.model_type == "image",
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
|
||||
).first() is not None
|
||||
finally:
|
||||
_db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _is_image_model:
|
||||
if _is_image_generation_session(sess):
|
||||
from src.settings import get_setting
|
||||
if not get_setting("image_gen_enabled", True):
|
||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||
|
||||
78
tests/test_chat_image_routing.py
Normal file
78
tests/test_chat_image_routing.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
from routes import chat_routes
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
|
||||
def filter(self, *conditions):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _FakeDb:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.closed = False
|
||||
|
||||
def query(self, model):
|
||||
return _FakeQuery(self.rows)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _session(model="qwen3.5:latest", endpoint_url="http://localhost:11434/v1/chat/completions"):
|
||||
return SimpleNamespace(model=model, endpoint_url=endpoint_url)
|
||||
|
||||
|
||||
def _endpoint(base_url, model_type="image", models=None):
|
||||
cached_models = None if models is None else json.dumps(models)
|
||||
return SimpleNamespace(
|
||||
base_url=base_url,
|
||||
model_type=model_type,
|
||||
is_enabled=True,
|
||||
cached_models=cached_models,
|
||||
)
|
||||
|
||||
|
||||
def test_image_model_prefix_routes_to_image_generation_without_endpoint_lookup(monkeypatch):
|
||||
def fail_if_called():
|
||||
raise AssertionError("prefixed image models should not need a DB lookup")
|
||||
|
||||
monkeypatch.setattr(chat_routes, "SessionLocal", fail_if_called)
|
||||
|
||||
assert chat_routes._is_image_generation_session(_session(model="dall-e-3"))
|
||||
|
||||
|
||||
def test_image_endpoint_does_not_catch_text_model_on_different_path(monkeypatch):
|
||||
db = _FakeDb([
|
||||
_endpoint("http://localhost:11434/v1/images", models=["sdxl-local"]),
|
||||
])
|
||||
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||
|
||||
assert not chat_routes._is_image_generation_session(_session())
|
||||
assert db.closed
|
||||
|
||||
|
||||
def test_image_endpoint_cache_must_contain_selected_model(monkeypatch):
|
||||
db = _FakeDb([
|
||||
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
|
||||
])
|
||||
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||
|
||||
assert not chat_routes._is_image_generation_session(_session(model="qwen3.5:latest"))
|
||||
|
||||
|
||||
def test_matching_image_endpoint_routes_selected_image_model(monkeypatch):
|
||||
db = _FakeDb([
|
||||
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
|
||||
])
|
||||
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||
|
||||
assert chat_routes._is_image_generation_session(_session(model="sdxl-local"))
|
||||
Reference in New Issue
Block a user