diff --git a/routes/chat_routes.py b/routes/chat_routes.py index b2e0de0..df3ae7a 100644 --- a/routes/chat_routes.py +++ b/routes/chat_routes.py @@ -43,6 +43,7 @@ logger = logging.getLogger(__name__) # Track active streams for partial-save safety net _active_streams: Dict[str, dict] = {} +_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image") def _stream_set(session_id: str, **fields) -> None: @@ -98,6 +99,59 @@ def _clear_orphaned_session_endpoint(sess) -> bool: db.close() +def _endpoint_cache_contains_model(endpoint, model: str) -> bool: + """Return True when a populated endpoint model cache includes ``model``. + + Empty/malformed caches are treated as unknown rather than a negative match + so older image endpoints without cached models still work. + """ + raw = getattr(endpoint, "cached_models", None) + if not raw: + return True + try: + models = json.loads(raw) if isinstance(raw, str) else raw + except Exception: + return True + if not isinstance(models, list) or not models: + return True + wanted = (model or "").strip() + return wanted in {str(item).strip() for item in models} + + +def _is_image_generation_session(sess) -> bool: + """Whether this chat session should bypass text chat and generate images. + + Model-name prefixes are explicit image models. Endpoint type is only used + when the current session endpoint actually matches that image endpoint, and + when a populated endpoint model cache includes the selected model. This + prevents an image endpoint on the same host from misrouting ordinary text + models into the image-generation path. + """ + model = (getattr(sess, "model", "") or "").strip() + if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES): + return True + + endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip() + if not endpoint_url: + return False + + db = SessionLocal() + try: + endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + for endpoint in endpoints: + if (getattr(endpoint, "model_type", None) or "llm") != "image": + continue + if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""): + continue + if _endpoint_cache_contains_model(endpoint, model): + return True + except Exception: + return False + finally: + db.close() + return False + + def _recover_empty_session_model(sess, session_id: str) -> bool: """Re-populate sess.model from the matching endpoint's cached models. @@ -726,28 +780,7 @@ def setup_chat_routes( _model_info["character_name"] = ctx.preset.character_name yield f'data: {json.dumps(_model_info)}\n\n' - # Detect image models and route directly to image generation - _IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image") - _is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES) - - # Also check if the endpoint is registered as an image-type endpoint - if not _is_image_model: - try: - from src.endpoint_resolver import normalize_base as _nb - _ep_base = _nb(sess.endpoint_url) - _db = SessionLocal() - try: - _is_image_model = _db.query(ModelEndpoint).filter( - ModelEndpoint.model_type == "image", - ModelEndpoint.is_enabled == True, - ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]), - ).first() is not None - finally: - _db.close() - except Exception: - pass - - if _is_image_model: + if _is_image_generation_session(sess): from src.settings import get_setting if not get_setting("image_gen_enabled", True): yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n' diff --git a/tests/test_chat_image_routing.py b/tests/test_chat_image_routing.py new file mode 100644 index 0000000..dc2a869 --- /dev/null +++ b/tests/test_chat_image_routing.py @@ -0,0 +1,78 @@ +import json +from types import SimpleNamespace + +from routes import chat_routes + + +class _FakeQuery: + def __init__(self, rows): + self.rows = rows + + def filter(self, *conditions): + return self + + def all(self): + return list(self.rows) + + +class _FakeDb: + def __init__(self, rows): + self.rows = rows + self.closed = False + + def query(self, model): + return _FakeQuery(self.rows) + + def close(self): + self.closed = True + + +def _session(model="qwen3.5:latest", endpoint_url="http://localhost:11434/v1/chat/completions"): + return SimpleNamespace(model=model, endpoint_url=endpoint_url) + + +def _endpoint(base_url, model_type="image", models=None): + cached_models = None if models is None else json.dumps(models) + return SimpleNamespace( + base_url=base_url, + model_type=model_type, + is_enabled=True, + cached_models=cached_models, + ) + + +def test_image_model_prefix_routes_to_image_generation_without_endpoint_lookup(monkeypatch): + def fail_if_called(): + raise AssertionError("prefixed image models should not need a DB lookup") + + monkeypatch.setattr(chat_routes, "SessionLocal", fail_if_called) + + assert chat_routes._is_image_generation_session(_session(model="dall-e-3")) + + +def test_image_endpoint_does_not_catch_text_model_on_different_path(monkeypatch): + db = _FakeDb([ + _endpoint("http://localhost:11434/v1/images", models=["sdxl-local"]), + ]) + monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db) + + assert not chat_routes._is_image_generation_session(_session()) + assert db.closed + + +def test_image_endpoint_cache_must_contain_selected_model(monkeypatch): + db = _FakeDb([ + _endpoint("http://localhost:11434/v1", models=["sdxl-local"]), + ]) + monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db) + + assert not chat_routes._is_image_generation_session(_session(model="qwen3.5:latest")) + + +def test_matching_image_endpoint_routes_selected_image_model(monkeypatch): + db = _FakeDb([ + _endpoint("http://localhost:11434/v1", models=["sdxl-local"]), + ]) + monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db) + + assert chat_routes._is_image_generation_session(_session(model="sdxl-local"))