diff --git a/README.md b/README.md index d02c139..5e7d3d8 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,20 @@ RENDER_GID=989 For NVIDIA/AMD GPU support, also read the comments in the selected overlay file: docker/gpu.nvidia.yml or docker/gpu.amd.yml. +**Stack-management UIs (Portainer, Coolify, Dockhand, etc.).** These tools +often accept only a single Compose file and do not reliably honor `COMPOSE_FILE` +or multiple `-f` overlays. CLI users should keep using the `COMPOSE_FILE` +overlay workflow above. For stack UIs, point the stack at one of the standalone +files instead, which bundle the base stack plus the GPU settings: + +- `docker-compose.gpu-nvidia.yml` — still requires the NVIDIA Container Toolkit + on the host. +- `docker-compose.gpu-amd.yml` — still requires host ROCm/kfd/DRI setup, the + `video`/`render` group membership, and `RENDER_GID` when needed. + +The base `docker-compose.yml` plus the `docker/gpu.*.yml` overlays remain the +source of truth; the standalone files mirror them for single-file deployments. + Verify after enabling either overlay: ```bash diff --git a/core/database.py b/core/database.py index 293a303..d530171 100644 --- a/core/database.py +++ b/core/database.py @@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base): is_enabled = Column(Boolean, default=True) hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list) + pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models) model_type = Column(String, nullable=True, default="llm") # "llm" or "image" # Whether models on this endpoint accept OpenAI-style function # schemas + emit `tool_calls`. 
Auto-detected at Cookbook auto- @@ -856,6 +857,24 @@ def _migrate_add_cached_models_column(): except Exception as e: logging.getLogger(__name__).warning(f"cached_models migration failed: {e}") +def _migrate_add_pinned_models_column(): + """Add pinned_models column to model_endpoints if it doesn't exist.""" + import sqlite3 + db_path = DATABASE_URL.replace("sqlite:///", "") + if not os.path.exists(db_path): + return + try: + conn = sqlite3.connect(db_path) + cursor = conn.execute("PRAGMA table_info(model_endpoints)") + columns = [row[1] for row in cursor.fetchall()] + if columns and "pinned_models" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT") + conn.commit() + logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints") + conn.close() + except Exception as e: + logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}") + def _migrate_add_notes_sort_order(): """Add sort_order, image_url, repeat columns to notes if they don't exist.""" import sqlite3 @@ -1511,6 +1530,7 @@ def init_db(): Base.metadata.create_all(bind=engine) _migrate_add_hidden_models_column() _migrate_add_cached_models_column() + _migrate_add_pinned_models_column() _migrate_add_notes_sort_order() _migrate_add_model_type_column() _migrate_add_model_endpoint_owner_column() diff --git a/docker-compose.gpu-amd.yml b/docker-compose.gpu-amd.yml new file mode 100644 index 0000000..47e0c85 --- /dev/null +++ b/docker-compose.gpu-amd.yml @@ -0,0 +1,164 @@ +# Standalone AMD ROCm GPU Compose file for stack-management UIs (Portainer, +# Coolify, Dockhand, etc.) that accept only a single Compose file and do not +# reliably honor COMPOSE_FILE or multiple `-f` overlays. +# +# This is equivalent to: docker-compose.yml + docker/gpu.amd.yml. +# The base docker-compose.yml plus the docker/gpu.amd.yml overlay remain the +# source of truth — CLI users should keep using the COMPOSE_FILE overlay +# workflow. 
Keep this file in sync with both when either changes. +# +# Requires ROCm drivers on the host (kfd + DRI devices) and the host user +# running Docker in the `video` and `render` groups. Set RENDER_GID to your +# host's numeric render group id when needed. See docker/gpu.amd.yml for details. +services: + odysseus: + build: . + ports: + - "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000" + volumes: + - ./data:/app/data:z + - ./logs:/app/logs:z + # Cookbook remote-server SSH identity. Odysseus can generate a key here; + # add the shown public key to each remote server's authorized_keys. + - ./data/ssh:/app/.ssh:z + # Cookbook local model cache. Inside Docker, "Local" means the Odysseus + # container, so persist its HuggingFace cache under ./data/huggingface. + - ./data/huggingface:/app/.cache/huggingface:z + # Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.) + # land under /app/.local for the odysseus user. Persist them so a + # container recreate does not silently remove installed serve engines. + - ./data/local:/app/.local:z + extra_hosts: + # Lets the container reach local services on the Docker host, including + # Ollama at http://host.docker.internal:11434. 
+ - "host.docker.internal:host-gateway" + environment: + - LLM_HOST=${LLM_HOST:-localhost} + - LLM_HOSTS=${LLM_HOSTS:-} + - OPENAI_API_KEY=${OPENAI_API_KEY:-} + - OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-} + - RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-} + - HF_TOKEN=${HF_TOKEN:-} + - HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-} + - SEARXNG_INSTANCE=http://searxng:8080 + - CHROMADB_HOST=chromadb + - CHROMADB_PORT=8000 + - DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db} + - AUTH_ENABLED=${AUTH_ENABLED:-true} + - LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false} + - ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin} + - ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-} + - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1} + - SECURE_COOKIES=${SECURE_COOKIES:-false} + - EMBEDDING_URL=${EMBEDDING_URL:-} + - EMBEDDING_MODEL=${EMBEDDING_MODEL:-} + - FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2} + - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-} + - CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24} + - ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1} + - ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1} + - ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost} + - DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY:-} + - GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-} + - TAVILY_API_KEY=${TAVILY_API_KEY:-} + - SERPER_API_KEY=${SERPER_API_KEY:-} + # PUID / PGID — the user/group the container drops to before + # running uvicorn (entrypoint also chowns /app/data + /app/logs + # to match, so bind-mounted files stay editable from the host). + # 1000 is the default first user on most Linux installs. 
If your + # host user has a different id, override here or via .env, e.g.: + # PUID=1001 + # PGID=1001 + # Find yours with: id -u / id -g + - PUID=${PUID:-1000} + - PGID=${PGID:-1000} + depends_on: + searxng: + condition: service_healthy + chromadb: + condition: service_started + restart: unless-stopped + # AMD ROCm overlay (from docker/gpu.amd.yml). + devices: + - /dev/kfd + - /dev/dri + group_add: + - video + - ${RENDER_GID:-render} + + chromadb: + image: docker.io/chromadb/chroma:latest + ports: + - "${CHROMADB_BIND:-127.0.0.1}:8100:8000" + volumes: + - chromadb-data:/chroma/chroma + environment: + - ANONYMIZED_TELEMETRY=FALSE + restart: unless-stopped + + searxng: + # Pinned, not :latest — odysseus waits on searxng's healthcheck + # (depends_on: condition: service_healthy), so a broken upstream `latest` + # tag blocks the whole app from starting. 2026.6.2 crashes on boot with + # `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414). + # Bump this deliberately after verifying a newer tag boots clean. + image: docker.io/searxng/searxng:2026.5.31-7159b8aed + entrypoint: + - /bin/sh + - -c + - | + set -eu + if [ ! 
-s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then + secret="$${SEARXNG_SECRET:-}" + if [ -z "$$secret" ]; then + secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')" + fi + sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml + fi + exec /usr/local/searxng/entrypoint.sh + ports: + - "127.0.0.1:8080:8080" + volumes: + - searxng-data:/etc/searxng + - ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z + environment: + - SEARXNG_BASE_URL=http://localhost:8080/ + - SEARXNG_SECRET=${SEARXNG_SECRET:-} + # The official searxng image runs as the non-root `searxng` user, but its + # entrypoint still needs to chown /etc/searxng on first boot, drop privs via + # su-exec, and (with our wrapper above) write settings.yml into the named + # volume. Without these capabilities the wrapper aborts at the redirection + # with EACCES and the container fails its healthcheck with permission + # errors during setup. Mirrors the cap set recommended by the upstream + # searxng-docker compose file. See issue #721. 
+ cap_drop: + - ALL + cap_add: + - CHOWN + - SETGID + - SETUID + - DAC_OVERRIDE + healthcheck: + test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""] + interval: 5s + timeout: 6s + retries: 20 + start_period: 10s + restart: unless-stopped + + ntfy: + image: docker.io/binwiederhier/ntfy + command: serve + ports: + - "${NTFY_BIND:-127.0.0.1}:8091:80" + volumes: + - ntfy-cache:/var/cache/ntfy + environment: + - NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091} + restart: unless-stopped + +volumes: + searxng-data: + chromadb-data: + ntfy-cache: diff --git a/docker-compose.gpu-nvidia.yml b/docker-compose.gpu-nvidia.yml new file mode 100644 index 0000000..36ca10e --- /dev/null +++ b/docker-compose.gpu-nvidia.yml @@ -0,0 +1,167 @@ +# Standalone NVIDIA GPU Compose file for stack-management UIs (Portainer, +# Coolify, Dockhand, etc.) that accept only a single Compose file and do not +# reliably honor COMPOSE_FILE or multiple `-f` overlays. +# +# This is equivalent to: docker-compose.yml + docker/gpu.nvidia.yml. +# The base docker-compose.yml plus the docker/gpu.nvidia.yml overlay remain +# the source of truth — CLI users should keep using the COMPOSE_FILE overlay +# workflow. Keep this file in sync with both when either changes. +# +# Requires the NVIDIA Container Toolkit on the host. See docker/gpu.nvidia.yml +# for setup details. +services: + odysseus: + build: . + ports: + - "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000" + volumes: + - ./data:/app/data:z + - ./logs:/app/logs:z + # Cookbook remote-server SSH identity. Odysseus can generate a key here; + # add the shown public key to each remote server's authorized_keys. + - ./data/ssh:/app/.ssh:z + # Cookbook local model cache. Inside Docker, "Local" means the Odysseus + # container, so persist its HuggingFace cache under ./data/huggingface. 
+ - ./data/huggingface:/app/.cache/huggingface:z + # Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.) + # land under /app/.local for the odysseus user. Persist them so a + # container recreate does not silently remove installed serve engines. + - ./data/local:/app/.local:z + extra_hosts: + # Lets the container reach local services on the Docker host, including + # Ollama at http://host.docker.internal:11434. + - "host.docker.internal:host-gateway" + environment: + - LLM_HOST=${LLM_HOST:-localhost} + - LLM_HOSTS=${LLM_HOSTS:-} + - OPENAI_API_KEY=${OPENAI_API_KEY:-} + - OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-} + - RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-} + - HF_TOKEN=${HF_TOKEN:-} + - HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-} + - SEARXNG_INSTANCE=http://searxng:8080 + - CHROMADB_HOST=chromadb + - CHROMADB_PORT=8000 + - DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db} + - AUTH_ENABLED=${AUTH_ENABLED:-true} + - LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false} + - ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin} + - ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-} + - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1} + - SECURE_COOKIES=${SECURE_COOKIES:-false} + - EMBEDDING_URL=${EMBEDDING_URL:-} + - EMBEDDING_MODEL=${EMBEDDING_MODEL:-} + - FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2} + - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-} + - CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24} + - ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1} + - ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1} + - ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost} + - DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY:-} + - GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-} + - TAVILY_API_KEY=${TAVILY_API_KEY:-} + - SERPER_API_KEY=${SERPER_API_KEY:-} + # PUID / PGID — the user/group the container drops to before + # running uvicorn (entrypoint also chowns 
/app/data + /app/logs + # to match, so bind-mounted files stay editable from the host). + # 1000 is the default first user on most Linux installs. If your + # host user has a different id, override here or via .env, e.g.: + # PUID=1001 + # PGID=1001 + # Find yours with: id -u / id -g + - PUID=${PUID:-1000} + - PGID=${PGID:-1000} + # NVIDIA overlay (from docker/gpu.nvidia.yml). + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=compute,utility + depends_on: + searxng: + condition: service_healthy + chromadb: + condition: service_started + restart: unless-stopped + # NVIDIA overlay (from docker/gpu.nvidia.yml). + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + + chromadb: + image: docker.io/chromadb/chroma:latest + ports: + - "${CHROMADB_BIND:-127.0.0.1}:8100:8000" + volumes: + - chromadb-data:/chroma/chroma + environment: + - ANONYMIZED_TELEMETRY=FALSE + restart: unless-stopped + + searxng: + # Pinned, not :latest — odysseus waits on searxng's healthcheck + # (depends_on: condition: service_healthy), so a broken upstream `latest` + # tag blocks the whole app from starting. 2026.6.2 crashes on boot with + # `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414). + # Bump this deliberately after verifying a newer tag boots clean. + image: docker.io/searxng/searxng:2026.5.31-7159b8aed + entrypoint: + - /bin/sh + - -c + - | + set -eu + if [ ! 
-s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then + secret="$${SEARXNG_SECRET:-}" + if [ -z "$$secret" ]; then + secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')" + fi + sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml + fi + exec /usr/local/searxng/entrypoint.sh + ports: + - "127.0.0.1:8080:8080" + volumes: + - searxng-data:/etc/searxng + - ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z + environment: + - SEARXNG_BASE_URL=http://localhost:8080/ + - SEARXNG_SECRET=${SEARXNG_SECRET:-} + # The official searxng image runs as the non-root `searxng` user, but its + # entrypoint still needs to chown /etc/searxng on first boot, drop privs via + # su-exec, and (with our wrapper above) write settings.yml into the named + # volume. Without these capabilities the wrapper aborts at the redirection + # with EACCES and the container fails its healthcheck with permission + # errors during setup. Mirrors the cap set recommended by the upstream + # searxng-docker compose file. See issue #721. 
+ cap_drop: + - ALL + cap_add: + - CHOWN + - SETGID + - SETUID + - DAC_OVERRIDE + healthcheck: + test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""] + interval: 5s + timeout: 6s + retries: 20 + start_period: 10s + restart: unless-stopped + + ntfy: + image: docker.io/binwiederhier/ntfy + command: serve + ports: + - "${NTFY_BIND:-127.0.0.1}:8091:80" + volumes: + - ntfy-cache:/var/cache/ntfy + environment: + - NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091} + restart: unless-stopped + +volumes: + searxng-data: + chromadb-data: + ntfy-cache: diff --git a/routes/model_routes.py b/routes/model_routes.py index 0135d1c..f4153b0 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> return "No models found for that provider/key." -def _visible_models(cached_models, hidden_models): - """Filter cached model IDs by hidden_models. Returns list of visible IDs.""" - all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or []) +def _normalize_model_ids(value): + """Coerce a model-ID input into a clean, ordered list of strings. + + Accepts a list, a JSON-encoded list string, or a comma/newline separated + string (handy for form or backend API input). Trims whitespace, drops + empty and non-string values, and de-duplicates preserving first-seen order. 
+ """ + if value is None: + return [] + items = value + if isinstance(value, str): + text = value.strip() + if not text: + return [] + try: + parsed = json.loads(text) + except Exception: + parsed = None + items = parsed if isinstance(parsed, list) else re.split(r"[,\n]", text) + if not isinstance(items, list): + return [] + out, seen = [], set() + for item in items: + if not isinstance(item, str): + continue + s = item.strip() + if not s or s in seen: + continue + seen.add(s) + out.append(s) + return out + + +def _merge_model_ids(*lists): + """Concatenate model-ID lists, de-duplicating and preserving order.""" + out, seen = [], set() + for ids in lists: + for m in (ids or []): + if not isinstance(m, str) or m in seen: + continue + seen.add(m) + out.append(m) + return out + + +def _visible_models(cached_models, hidden_models, pinned_models=None): + """Merge cached + pinned model IDs, then filter out hidden ones. + + Pinned IDs are admin-entered and may not appear in cached_models (e.g. + cloud deployment IDs the provider does not list in /v1/models). Returns an + ordered, de-duplicated list of visible IDs. + """ + # Normalize each input so JSON strings, lists, comma/newline strings, and + # malformed strings are all handled without raising. 
+ merged = _merge_model_ids( + _normalize_model_ids(cached_models), + _normalize_model_ids(pinned_models), + ) if not hidden_models: - return all_models - hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or [])) - return [m for m in all_models if m not in hidden] + return merged + hidden = set(_normalize_model_ids(hidden_models)) + return [m for m in merged if m not in hidden] def setup_model_routes(model_discovery): @@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery): hidden = set(json.loads(r.hidden_models)) except Exception: pass - visible = [m for m in all_models if m not in hidden] - status = "online" if all_models else "offline" + pinned = _normalize_model_ids(getattr(r, "pinned_models", None)) + visible = _visible_models(all_models, r.hidden_models, pinned) + # Endpoint counts as reachable if it has any model — including + # admin-pinned IDs that a probe would never surface. + status = "online" if (all_models or pinned) else "offline" ping = None - if not all_models and r.is_enabled: + if not all_models and not pinned and r.is_enabled: ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) if ping.get("reachable"): status = "empty" @@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery): "has_key": bool(r.api_key), "is_enabled": r.is_enabled, "models": visible, + "pinned_models": pinned, "hidden_count": len(hidden), "online": status != "offline", "status": status, @@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery): require_models: str = Form("false"), model_type: str = Form("llm"), supports_tools: str = Form(""), # "true"/"false"/"" (unknown) + pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline container_local: str = Form("false"), # Default `shared=true` → endpoints are visible to all users (the # app's historical behaviour). 
Admins can pass `shared=false` to @@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery): .first() ) if existing: + # Persist any incoming pinned IDs onto the existing row. An + # empty/omitted form field must not wipe previously pinned IDs. + _incoming_pinned = _normalize_model_ids(pinned_models) + if _incoming_pinned: + _merged_pinned = _merge_model_ids( + _normalize_model_ids(getattr(existing, "pinned_models", None)), + _incoming_pinned, + ) + existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None + _db_dedup.commit() + _invalidate_models_cache() + _existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None)) return { "id": existing.id, "name": existing.name, "base_url": existing.base_url, - "models": json.loads(existing.cached_models) if existing.cached_models else [], + "models": _visible_models( + getattr(existing, "cached_models", None), + getattr(existing, "hidden_models", None), + existing.pinned_models, + ), + "pinned_models": _existing_pinned, "online": True, "status": "online", "existing": True, @@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery): try: _st_raw = (supports_tools or "").strip().lower() _st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None) + _pinned = _normalize_model_ids(pinned_models) # Stamp owner so the picker only shows this endpoint to the admin # who added it. 
Pass `shared=true` to mark it null-owner (visible # to all users), preserving the pre-fix "everyone sees everything" @@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery): is_enabled=True, model_type=model_type.strip() if model_type else "llm", cached_models=json.dumps(model_ids) if model_ids else None, + pinned_models=json.dumps(_pinned) if _pinned else None, supports_tools=_st, owner=_owner_val, ) @@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery): "id": ep_id, "name": name.strip(), "base_url": base_url, - "models": model_ids, - "online": bool(model_ids) or bool(ping.get("reachable")), - "status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"), + "models": _merge_model_ids(model_ids, _pinned), + "pinned_models": _pinned, + "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")), + "status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"), "ping_error": ping.get("error") if ping else None, } @@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery): hidden = set(json.loads(ep.hidden_models)) except Exception: pass - # Try live probe, fall back to cached + # Try live probe, fall back to cached. Pinned IDs are admin-entered + # and persist regardless of probe results — never overwritten here. 
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) if all_models: ep.cached_models = json.dumps(all_models) @@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery): all_models = json.loads(ep.cached_models) except Exception: pass + pinned = _normalize_model_ids(getattr(ep, "pinned_models", None)) + pinned_set = set(pinned) return [ - {"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden} - for m in all_models + { + "id": m, + "display": m.split("/")[-1], + "is_hidden": m in hidden, + "is_pinned": m in pinned_set, + } + for m in _merge_model_ids(all_models, pinned) ] finally: db.close() @router.patch("/model-endpoints/{ep_id}/models") async def update_hidden_models(ep_id: str, request: Request): - """Bulk update hidden models list for an endpoint. + """Bulk update hidden and/or pinned model lists for an endpoint. - Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]} + Expects JSON body with optional keys: + {"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]} + Each key is updated only when present, so callers can patch one list + without clobbering the other. """ require_admin(request) db = SessionLocal() @@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery): if not ep: raise HTTPException(404, "Endpoint not found") body = await request.json() - hidden = body.get("hidden", []) - if not isinstance(hidden, list): - raise HTTPException(400, "hidden must be a list of model IDs") - ep.hidden_models = json.dumps(hidden) if hidden else None + if not isinstance(body, dict): + raise HTTPException(400, "Body must be a JSON object") + if "hidden" in body: + hidden = body.get("hidden") + if not isinstance(hidden, list): + raise HTTPException(400, "hidden must be a list of model IDs") + ep.hidden_models = json.dumps(hidden) if hidden else None + # Accept either "pinned" or "pinned_models" for the manual IDs list. 
+ if "pinned_models" in body or "pinned" in body: + pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned"))) + ep.pinned_models = json.dumps(pinned) if pinned else None db.commit() _invalidate_models_cache() - return {"id": ep_id, "hidden_count": len(hidden)} + hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0 + pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0 + return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count} finally: db.close() @@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery): return {"endpoint_id": "", "endpoint_url": "", "model": ""} base = _normalize_base(ep.base_url) chat_url = build_chat_url(base) - if not model and getattr(ep, "cached_models", None): + if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)): try: - visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None)) + visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None)) if visible: model = visible[0] except Exception: @@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery): ep.name = body["name"].strip() or ep.name if "model_type" in body and isinstance(body["model_type"], str): ep.model_type = body["model_type"].strip() or ep.model_type + if "pinned_models" in body: + _pinned = _normalize_model_ids(body["pinned_models"]) + ep.pinned_models = json.dumps(_pinned) if _pinned else None # Rotating an API key used to require DELETE+POST, which wiped # endpoint_url/model from every session referencing the old base # URL. 
Allow in-place updates so the admin can change the key @@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery): "name": ep.name, "model_type": ep.model_type, "base_url": ep.base_url, + "pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)), } finally: db.close() diff --git a/routes/webhook_routes.py b/routes/webhook_routes.py index de20f39..d1372be 100644 --- a/routes/webhook_routes.py +++ b/routes/webhook_routes.py @@ -9,7 +9,9 @@ import httpx from fastapi import APIRouter, HTTPException, Request, Form from pydantic import BaseModel, Field -from core.database import SessionLocal, Webhook +from core.database import SessionLocal, Webhook, ModelEndpoint +from src.auth_helpers import owner_filter +from src.url_security import validate_public_http_url from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events logger = logging.getLogger(__name__) @@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000 from core.middleware import require_admin as _require_admin -def _first_enabled_endpoint(db, owner): - """First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus - legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint - is per-user (core/database.py — "when non-null, the model picker only shows - the endpoint to that user"), and the sync-chat fallback uses the row's - decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token - (e.g. a paired mobile device) fall back onto ANOTHER user's private - endpoint and silently spend that owner's API key / quota — and reach - whatever internal base_url they configured. Mirrors the owner_filter scoping - in routes/model_routes.py and companion/routes.py. A null/empty owner is a - no-op (single-user / legacy mode), preserving the original behaviour. +def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]): + """First enabled ModelEndpoint visible to token_owner — their own rows plus + legacy null-owner ("shared") rows. 
Owner-scoped: an unscoped .first() would + let a chat-scoped token fall back onto another user's private endpoint and + silently spend that owner's API key/quota. Prefer owner rows before shared + rows. When token_owner is absent, fails closed to null-owner rows only. + Does not validate base_url — admin-configured local/LAN endpoints remain allowed. """ - from core.database import ModelEndpoint - from src.auth_helpers import owner_filter - q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712 - q = owner_filter(q, ModelEndpoint, owner) - return q.first() + query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712 + if token_owner: + query = owner_filter(query, ModelEndpoint, token_owner) + return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first() + return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711 def _caller_owns_session(sess_owner, caller) -> bool: @@ -278,15 +276,21 @@ def setup_webhook_routes( api_key = body.api_key.strip() model = body.model or "deepseek-chat" - # Resolve base_url: explicit > provider name > model prefix auto-detect - base_url = body.base_url.strip().rstrip("/") if body.base_url else None - if not base_url: + # Validate only token-supplied direct base_url; auto-resolved known-provider + # URLs are not subject to extra local/LAN blocking beyond existing provider logic. + direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None + if direct_base_url: + try: + base_url = validate_public_http_url(direct_base_url) + except ValueError as e: + detail = str(e).replace("URL", "base_url", 1) + raise HTTPException(400, detail) + else: base_url = _resolve_base_url(model, body.provider) if not base_url: raise HTTPException(400, "Could not auto-detect provider. Pass base_url (e.g. 
'https://api.deepseek.com/v1') " "or provider ('deepseek', 'openai', 'groq', etc.)") base_url = normalize_base(base_url) endpoint_url = build_chat_url(base_url) @@ -306,9 +310,7 @@ def setup_webhook_routes( if not sess: db = SessionLocal() try: - # Owner-scoped: only THIS token owner's endpoints + legacy - # shared rows, never another user's private endpoint/api_key. - ep = _first_enabled_endpoint(db, token_owner) + ep = _select_api_chat_fallback_endpoint(db, token_owner) finally: db.close() diff --git a/services/search/ranking.py b/services/search/ranking.py index 23ea691..771a11a 100644 --- a/services/search/ranking.py +++ b/services/search/ranking.py @@ -2,12 +2,49 @@ import re import logging -from datetime import datetime +from datetime import datetime, timezone from typing import List, Optional from urllib.parse import urlparse logger = logging.getLogger(__name__) +_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S") + + +def _utcnow_naive() -> datetime: + """Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below, + and avoids ``datetime.utcnow()``, which is deprecated since Python 3.12 (#1116).""" + return datetime.now(timezone.utc).replace(tzinfo=None) + + +def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float: + """Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days. + + The age is measured against UTC, not local time. The previous code used + ``datetime.now()`` (local) against UTC-style published dates, so the age was + skewed by the host's UTC offset; it was also a latent crash once neighbouring + code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests. 
+ """ + if not age_str: + return 0.0 + dt = None + for fmt in _AGE_FORMATS: + try: + dt = datetime.strptime(age_str, fmt) + break + except Exception: + dt = None + if not dt: + return 0.0 + now = now or _utcnow_naive() + days_old = (now - dt).days + if days_old <= 7: + return 1.0 + if days_old >= 30: + return 0.0 + return (30 - days_old) / 23 + + _NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"} _SPORTS_HINTS = { "sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb", @@ -73,24 +110,6 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]: return 0.7 return 0.4 - def recency_score(age_str: Optional[str]) -> float: - if not age_str: - return 0.0 - for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"): - try: - dt = datetime.strptime(age_str, fmt) - break - except Exception: - dt = None - if not dt: - return 0.0 - days_old = (datetime.now() - dt).days - if days_old <= 7: - return 1.0 - if days_old >= 30: - return 0.0 - return (30 - days_old) / 23 - def news_quality_adjustment(title: str, snippet: str, url: str) -> float: if not is_news_query: return 0.0 diff --git a/src/search/ranking.py b/src/search/ranking.py index 771a11a..62e3869 100644 --- a/src/search/ranking.py +++ b/src/search/ranking.py @@ -1,151 +1,13 @@ -"""Search result ranking based on relevance, source quality, and recency.""" +"""Compatibility re-export shim for the live ranking module. -import re -import logging -from datetime import datetime, timezone -from typing import List, Optional -from urllib.parse import urlparse +The real implementation lives in :mod:`services.search.ranking`, which is what +the search runtime (services/search/core.py) imports. This module used to hold a +parallel copy; it now re-exports so the two cannot drift out of sync again. 
+""" -logger = logging.getLogger(__name__) - -_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S") - - -def _utcnow_naive() -> datetime: - """Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below, - and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116).""" - return datetime.now(timezone.utc).replace(tzinfo=None) - - -def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float: - """Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days. - - The age is measured against UTC, not local time. The previous code used - ``datetime.now()`` (local) against UTC-style published dates, so the age was - skewed by the host's UTC offset; it was also a latent crash once neighbouring - code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests. - """ - if not age_str: - return 0.0 - dt = None - for fmt in _AGE_FORMATS: - try: - dt = datetime.strptime(age_str, fmt) - break - except Exception: - dt = None - if not dt: - return 0.0 - now = now or _utcnow_naive() - days_old = (now - dt).days - if days_old <= 7: - return 1.0 - if days_old >= 30: - return 0.0 - return (30 - days_old) / 23 - - -_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"} -_SPORTS_HINTS = { - "sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb", - "fifa", "world cup", "championship", "quarterfinal", "eliminates", -} -# Word-boundary match so "sport" does not fire inside "transport"/"passport" -# and a domain like "transport.gov" is not mistaken for a sports site. 
-_SPORTS_HINT_RE = re.compile( - r"\b(?:" + "|".join(re.escape(h) for h in _SPORTS_HINTS) + r")\b" +from services.search.ranking import ( # noqa: F401 + _AGE_FORMATS, + _utcnow_naive, + rank_search_results, + recency_score, ) -_LOW_VALUE_NEWS_DOMAINS = { - "facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com", - "www.yahoo.com", "msn.com", "www.msn.com", -} -_TRUSTED_NEWS_DOMAINS = { - "apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com", - "bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca", - "ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca", - "theguardian.com", - "www.theguardian.com", "euronews.com", "www.euronews.com", - "dw.com", "www.dw.com", "government.se", "www.government.se", -} - - -def _domain(url: str) -> str: - try: - return urlparse(url).netloc.lower() - except Exception: - return "" - - -def rank_search_results(query: str, results: List[dict]) -> List[dict]: - """Rank search results by title relevance, snippet quality, domain authority, and recency.""" - query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)] - query_lc = query.lower() - is_news_query = any(term in _NEWS_HINTS for term in query_terms) - is_sports_query = bool(_SPORTS_HINT_RE.search(query_lc)) - - def title_score(title: str) -> float: - if not title: - return 0.0 - title_lc = title.lower() - matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc)) - return matches / len(query_terms) if query_terms else 0.0 - - def snippet_score(snippet: str) -> float: - if not snippet: - return 0.0 - length_factor = min(len(snippet), 200) / 200 - term_hits = sum(1 for term in query_terms if term in snippet.lower()) - term_factor = term_hits / len(query_terms) if query_terms else 0.0 - return (length_factor + term_factor) / 2 - - def domain_score(url: str) -> float: - netloc = _domain(url) - if not netloc: - return 0.0 - if netloc in _TRUSTED_NEWS_DOMAINS: - return 1.0 - if netloc.endswith(".edu") or 
netloc.endswith(".gov"): - return 1.0 - if netloc.endswith(".org"): - return 0.7 - return 0.4 - - def news_quality_adjustment(title: str, snippet: str, url: str) -> float: - if not is_news_query: - return 0.0 - text = f"{title} {snippet}".lower() - netloc = _domain(url) - adjustment = 0.0 - if netloc in _TRUSTED_NEWS_DOMAINS: - adjustment += 1.2 - if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")): - adjustment += 0.4 - if netloc in _LOW_VALUE_NEWS_DOMAINS: - adjustment -= 0.8 - if not is_sports_query and (_SPORTS_HINT_RE.search(text) or _SPORTS_HINT_RE.search(netloc)): - adjustment -= 1.5 - # A country/news query should not rank a page whose title/snippet barely - # mentions the country above actual news pages for that country. - subject_terms = [t for t in query_terms if t not in _NEWS_HINTS] - if subject_terms and not any(t in text or t in netloc for t in subject_terms): - adjustment -= 1.0 - return adjustment - - ranked = [] - for result in results: - title = result.get("title", "") - snippet = result.get("snippet", "") - url = result.get("url", "") - age = result.get("age", None) - - score = ( - 2.0 * title_score(title) - + 1.0 * snippet_score(snippet) - + 1.5 * domain_score(url) - + 1.0 * recency_score(age) - + news_quality_adjustment(title, snippet, url) - ) - ranked.append((score, result)) - - ranked.sort(key=lambda x: x[0], reverse=True) - return [r for _, r in ranked] diff --git a/src/url_security.py b/src/url_security.py new file mode 100644 index 0000000..8deb048 --- /dev/null +++ b/src/url_security.py @@ -0,0 +1,94 @@ +"""URL validation helpers for server-side outbound requests.""" + +from __future__ import annotations + +import ipaddress +import socket +from urllib.parse import urlparse + + +_INTERNAL_HOSTNAMES = { + "localhost", + "metadata", + "metadata.google.internal", +} + +_INTERNAL_SUFFIXES = ( + ".localhost", + ".local", + ".internal", + ".lan", + ".intranet", +) + +_BLOCKED_NETWORKS = ( + 
ipaddress.ip_network("0.0.0.0/8"), + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("100.64.0.0/10"), + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("::/128"), + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network("fe80::/10"), +) + + +def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]: + ips: list[ipaddress._BaseAddress] = [] + for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None): + if family in (socket.AF_INET, socket.AF_INET6): + ips.append(ipaddress.ip_address(sockaddr[0])) + return ips + + +def _blocked_ip(addr: ipaddress._BaseAddress) -> bool: + return ( + any(addr in net for net in _BLOCKED_NETWORKS) + or addr.is_private + or addr.is_loopback + or addr.is_link_local + or addr.is_multicast + or addr.is_unspecified + or addr.is_reserved + ) + + +def _host_resolves_publicly(hostname: str) -> bool: + host = hostname.strip().lower() + if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES): + return False + try: + return not _blocked_ip(ipaddress.ip_address(host)) + except ValueError: + pass + try: + addrs = _resolve_hostname_ips(host) + except OSError: + return False + return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs) + + +def is_public_http_url(url: str) -> bool: + parsed = urlparse((url or "").strip()) + if parsed.scheme not in ("http", "https") or not parsed.hostname: + return False + return _host_resolves_publicly(parsed.hostname) + + +def validate_public_http_url(url: str, *, max_length: int = 2048) -> str: + """Validate a user/API-token supplied server-side HTTP(S) endpoint. + + This is for untrusted outbound URLs, not admin-created model endpoints + that are intentionally allowed to point at private model providers. 
DNS + failures fail closed, and DNS checks reduce obvious private-network + targets but do not eliminate every DNS rebinding race by themselves. + """ + cleaned = (url or "").strip() + if len(cleaned) > max_length: + raise ValueError("URL is too long") + if not is_public_http_url(cleaned): + raise ValueError("URL must point to a public HTTP(S) endpoint") + return cleaned diff --git a/tests/test_api_chat_security.py b/tests/test_api_chat_security.py new file mode 100644 index 0000000..3b94bd5 --- /dev/null +++ b/tests/test_api_chat_security.py @@ -0,0 +1,401 @@ +import ipaddress +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + + +@pytest.mark.parametrize("url", [ + "http://127.0.0.1:8000/v1", + "http://localhost:8000/v1", + "http://10.0.0.5/v1", + "http://172.16.0.1/v1", + "http://192.168.1.2/v1", + "http://169.254.169.254/latest/meta-data/", + "http://metadata.google.internal/", + "http://[::1]:8000/v1", + "http://[fc00::1]/v1", + "http://224.0.0.1/v1", + "http://0.0.0.0/v1", + "file:///etc/passwd", +]) +def test_public_url_validator_blocks_internal_targets(url): + from src.url_security import is_public_http_url + + assert is_public_http_url(url) is False + + +def test_public_url_validator_allows_public_endpoint(monkeypatch): + from src import url_security + + monkeypatch.setattr( + url_security, + "_resolve_hostname_ips", + lambda host: [ipaddress.ip_address("93.184.216.34")], + ) + + assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1" + + +def test_public_url_validator_blocks_dns_to_private(monkeypatch): + from src import url_security + + monkeypatch.setattr( + url_security, + "_resolve_hostname_ips", + lambda host: [ipaddress.ip_address("10.0.0.5")], + ) + + with pytest.raises(ValueError): + url_security.validate_public_http_url("https://api.example.com/v1") + + +def _load_webhook_routes_for_test(monkeypatch): + # Load under a unique module name so each test gets 
a fresh module object + # rather than a cached one from a previous monkeypatch run. + core_pkg = types.ModuleType("core") + core_pkg.__path__ = [] + core_db = types.ModuleType("core.database") + core_db.SessionLocal = object + core_db.Webhook = object + core_db.ModelEndpoint = object + core_middleware = types.ModuleType("core.middleware") + core_middleware.require_admin = lambda request: None + webhook_manager = types.ModuleType("src.webhook_manager") + webhook_manager.WebhookManager = object + webhook_manager.validate_webhook_url = lambda url: url + webhook_manager.validate_events = lambda events: events + + monkeypatch.setitem(sys.modules, "core", core_pkg) + monkeypatch.setitem(sys.modules, "core.database", core_db) + monkeypatch.setitem(sys.modules, "core.middleware", core_middleware) + monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager) + + module_name = "routes.webhook_routes_under_test" + spec = importlib.util.spec_from_file_location( + module_name, + Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py", + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +class _Expr: + def __init__(self, fn): + self.fn = fn + + def __call__(self, row): + return self.fn(row) + + def __or__(self, other): + return _Expr(lambda row: self(row) or other(row)) + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return _Expr(lambda row: getattr(row, self.name) == other) + + def desc(self): + return ("desc", self.name) + + +class _ModelEndpoint: + is_enabled = _Column("is_enabled") + owner = _Column("owner") + created_at = _Column("created_at") + + +class _Endpoint: + def __init__( + self, + *, + owner, + is_enabled=True, + created_at=1, + base_url="https://api.example.com/v1", + api_key=None, + ): + self.owner = owner + self.is_enabled = is_enabled + self.created_at = created_at + self.base_url = base_url + self.api_key = api_key + + +class 
_EndpointQuery: + def __init__(self, rows): + self.rows = rows + self.filters = [] + self.orders = [] + + def filter(self, *exprs): + self.filters.extend(exprs) + return self + + def order_by(self, *exprs): + self.orders.extend(exprs) + return self + + def first(self): + rows = self.rows + for expr in self.filters: + rows = [row for row in rows if expr(row)] + # Apply sort keys right-to-left so the leftmost key ends up as the + # primary sort (stable-sort reversal idiom mirrors SQLAlchemy's + # multi-column ORDER BY behaviour). + for order in reversed(self.orders): + reverse = False + name = getattr(order, "name", None) + if isinstance(order, tuple) and order[0] == "desc": + reverse = True + name = order[1] + rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse) + if name != "owner": + rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse) + return rows[0] if rows else None + + +class _DB: + def __init__(self, rows): + self.query_obj = _EndpointQuery(rows) + self.closed = False + + def query(self, model): + assert model is _ModelEndpoint + return self.query_obj + + def close(self): + self.closed = True + + +class _ChatSession: + def __init__(self, endpoint_url, model): + self.endpoint_url = endpoint_url + self.model = model + self.headers = {} + self.history = [] + + def add_message(self, message): + self.history.append(message) + + +class _SessionManager: + def __init__(self): + self.created = [] + self.save_calls = 0 + + def create_session(self, *, session_id, name, endpoint_url, model, owner): + session = _ChatSession(endpoint_url, model) + self.created.append({ + "session_id": session_id, + "name": name, + "endpoint_url": endpoint_url, + "model": model, + "owner": owner, + "session": session, + }) + return session + + def save_sessions(self): + self.save_calls += 1 + + +class _Request: + def __init__(self, *, owner="alice"): + self.state = types.SimpleNamespace( + api_token=True, + api_token_scopes=["chat"], + 
api_token_owner=owner, + ) + + +class _WebhookManager: + async def fire(self, event, payload): + return None + + +def _install_sync_chat_stubs(monkeypatch): + # FastAPI checks for python_multipart at import time when Form is used; + # stub it so the optional dependency is not required in the test environment. + python_multipart = types.ModuleType("python_multipart") + python_multipart.__version__ = "0.0.13" + core_models = types.ModuleType("core.models") + + class _ChatMessage: + def __init__(self, role, content): + self.role = role + self.content = content + + async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None): + return "mocked response" + + endpoint_resolver = types.ModuleType("src.endpoint_resolver") + endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/") + endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions" + endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models" + endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"} + + llm_core = types.ModuleType("src.llm_core") + llm_core.llm_call_async = _llm_call_async + core_models.ChatMessage = _ChatMessage + + monkeypatch.setitem(sys.modules, "python_multipart", python_multipart) + monkeypatch.setitem(sys.modules, "core.models", core_models) + monkeypatch.setitem(sys.modules, "src.llm_core", llm_core) + monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver) + + +def _sync_chat_endpoint(webhook_routes, session_manager): + router = webhook_routes.setup_webhook_routes( + _WebhookManager(), + auth_manager=None, + session_manager=session_manager, + ) + for route in router.routes: + if route.path == "/api/v1/chat": + return route.endpoint + raise AssertionError("sync chat route not found") + + +@pytest.mark.parametrize("base_url", [ + "http://127.0.0.1:11434/v1", + "http://localhost:11434/v1", + "http://10.0.0.5/v1", + 
"http://169.254.169.254/latest/meta-data/", +]) +@pytest.mark.asyncio +async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url): + webhook_routes = _load_webhook_routes_for_test(monkeypatch) + _install_sync_chat_stubs(monkeypatch) + session_manager = _SessionManager() + sync_chat = _sync_chat_endpoint(webhook_routes, session_manager) + + body = types.SimpleNamespace( + message="hello", + api_key="test-key", + base_url=base_url, + model="test-model", + provider=None, + session=None, + ) + + with pytest.raises(webhook_routes.HTTPException) as exc: + await sync_chat(_Request(), body) + + assert exc.value.status_code == 400 + assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint" + assert session_manager.created == [] + + +@pytest.mark.asyncio +async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch): + webhook_routes = _load_webhook_routes_for_test(monkeypatch) + _install_sync_chat_stubs(monkeypatch) + + from src import url_security + + monkeypatch.setattr( + url_security, + "_resolve_hostname_ips", + lambda host: [ipaddress.ip_address("93.184.216.34")], + ) + + session_manager = _SessionManager() + sync_chat = _sync_chat_endpoint(webhook_routes, session_manager) + body = types.SimpleNamespace( + message="hello", + api_key="test-key", + base_url="https://api.example.com/v1", + model="test-model", + provider=None, + session=None, + ) + + response = await sync_chat(_Request(), body) + + assert response["response"] == "mocked response" + assert response["model"] == "test-model" + assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions" + + +def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch): + webhook_routes = _load_webhook_routes_for_test(monkeypatch) + rows = [ + _Endpoint(owner="alice", is_enabled=False, created_at=0), + _Endpoint(owner="bob", created_at=0), + _Endpoint(owner=None, created_at=1), + 
_Endpoint(owner="alice", created_at=2), + ] + + monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint) + + selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice") + + assert selected.owner == "alice" + assert selected.is_enabled is True + assert selected.created_at == 2 + + +def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch): + webhook_routes = _load_webhook_routes_for_test(monkeypatch) + rows = [ + _Endpoint(owner="alice", created_at=0), + _Endpoint(owner=None, is_enabled=False, created_at=1), + _Endpoint(owner=None, created_at=2), + ] + + monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint) + + selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None) + + assert selected.owner is None + assert selected.is_enabled is True + assert selected.created_at == 2 + + +@pytest.mark.asyncio +async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch): + webhook_routes = _load_webhook_routes_for_test(monkeypatch) + _install_sync_chat_stubs(monkeypatch) + local_endpoint = _Endpoint( + owner=None, + base_url="http://localhost:11434/v1", + api_key="configured-key", + ) + db = _DB([local_endpoint]) + calls = [] + + def _session_local(): + return db + + def _validate_public_http_url(url, *, max_length=2048): + calls.append(url) + raise AssertionError("configured fallback endpoint should not be publicly validated") + + monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint) + monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local) + monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url) + + session_manager = _SessionManager() + sync_chat = _sync_chat_endpoint(webhook_routes, session_manager) + body = types.SimpleNamespace( + message="hello", + model="local-model", + api_key=None, + base_url=None, + provider=None, + session=None, + ) + + response = await sync_chat(_Request(owner=None), body) + + assert 
response["response"] == "mocked response" + assert response["model"] == "local-model" + assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions" + assert calls == [] diff --git a/tests/test_gpu_compose_standalone.py b/tests/test_gpu_compose_standalone.py new file mode 100644 index 0000000..57bdaf3 --- /dev/null +++ b/tests/test_gpu_compose_standalone.py @@ -0,0 +1,147 @@ +"""Guards the standalone GPU compose files against drift. + +Stack-management UIs (Portainer, Coolify, Dockhand, ...) often accept only a +single compose file and do not honor COMPOSE_FILE or multiple ``-f`` overlays, +so the repo ships standalone ``docker-compose.gpu-*.yml`` files that inline the +GPU overlay. The base ``docker-compose.yml`` plus ``docker/gpu.*.yml`` overlays +remain the source of truth; these tests assert each standalone file equals the +base compose with only the matching overlay merged into the ``odysseus`` +service. No Docker / docker compose is required — everything is pure YAML. +""" + +import copy +from pathlib import Path + +import pytest +import yaml + +ROOT = Path(__file__).resolve().parents[1] + +BASE = ROOT / "docker-compose.yml" +NVIDIA_OVERLAY = ROOT / "docker" / "gpu.nvidia.yml" +AMD_OVERLAY = ROOT / "docker" / "gpu.amd.yml" +NVIDIA_STANDALONE = ROOT / "docker-compose.gpu-nvidia.yml" +AMD_STANDALONE = ROOT / "docker-compose.gpu-amd.yml" + +SERVICE = "odysseus" + + +def _load(path: Path) -> dict: + return yaml.safe_load(path.read_text(encoding="utf-8")) + + +def _deep_merge(base: dict, overlay: dict) -> dict: + """Mirror docker compose overlay semantics for the keys these files use. + + Mappings merge recursively; list-valued service fields are concatenated + (compose appends override sequences such as ``environment`` rather than + replacing them); scalars are overwritten. 
The overlays here only append to + ``environment`` and add otherwise-absent keys (``deploy``, ``devices``, + ``group_add``), so this keeps the expected merge explicit without invoking + docker compose. + """ + result = copy.deepcopy(base) + for key, value in overlay.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result[key], value) + elif isinstance(value, list) and isinstance(result.get(key), list): + result[key] = copy.deepcopy(result[key]) + copy.deepcopy(value) + else: + result[key] = copy.deepcopy(value) + return result + + +def _merge_overlay_into_base(base: dict, overlay: dict) -> dict: + """Build the expected standalone config: base + overlay on odysseus only.""" + expected = copy.deepcopy(base) + overlay_service = overlay["services"][SERVICE] + expected["services"][SERVICE] = _deep_merge( + expected["services"][SERVICE], overlay_service + ) + return expected + + +@pytest.fixture(scope="module") +def base(): + return _load(BASE) + + +# --- Equivalence: standalone == base + overlay ----------------------------- + + +def test_nvidia_standalone_equals_base_plus_overlay(base): + overlay = _load(NVIDIA_OVERLAY) + standalone = _load(NVIDIA_STANDALONE) + assert standalone == _merge_overlay_into_base(base, overlay) + + +def test_amd_standalone_equals_base_plus_overlay(base): + overlay = _load(AMD_OVERLAY) + standalone = _load(AMD_STANDALONE) + assert standalone == _merge_overlay_into_base(base, overlay) + + +# --- Non-odysseus services and volumes untouched --------------------------- + + +@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE]) +def test_non_odysseus_services_match_base(base, standalone_path): + standalone = _load(standalone_path) + for name, definition in base["services"].items(): + if name == SERVICE: + continue + assert standalone["services"][name] == definition + assert set(standalone["services"]) == set(base["services"]) + + 
+@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE]) +def test_top_level_volumes_match_base(base, standalone_path): + standalone = _load(standalone_path) + assert standalone.get("volumes") == base.get("volumes") + + +# --- odysseus = base service + only the overlay additions ------------------ + + +def test_nvidia_odysseus_adds_only_overlay(base): + standalone = _load(NVIDIA_STANDALONE) + svc = standalone["services"][SERVICE] + base_svc = base["services"][SERVICE] + + # Base environment preserved, plus exactly the two NVIDIA variables. + assert "NVIDIA_VISIBLE_DEVICES=all" in svc["environment"] + assert "NVIDIA_DRIVER_CAPABILITIES=compute,utility" in svc["environment"] + added_env = set(svc["environment"]) - set(base_svc["environment"]) + assert added_env == { + "NVIDIA_VISIBLE_DEVICES=all", + "NVIDIA_DRIVER_CAPABILITIES=compute,utility", + } + + # deploy block is new and matches the overlay's GPU reservation exactly. + assert "deploy" not in base_svc + devices = svc["deploy"]["resources"]["reservations"]["devices"] + assert devices == [ + {"driver": "nvidia", "count": "all", "capabilities": ["gpu"]} + ] + + # No AMD-only keys leaked in. + assert "devices" not in svc + assert "group_add" not in svc + + +def test_amd_odysseus_adds_only_overlay(base): + standalone = _load(AMD_STANDALONE) + svc = standalone["services"][SERVICE] + base_svc = base["services"][SERVICE] + + # Environment is unchanged from base for AMD. + assert svc["environment"] == base_svc["environment"] + + # devices and group_add are new and match the overlay exactly. + assert "devices" not in base_svc + assert "group_add" not in base_svc + assert svc["devices"] == ["/dev/kfd", "/dev/dri"] + assert svc["group_add"] == ["video", "${RENDER_GID:-render}"] + + # No NVIDIA-only keys leaked in. 
+ assert "deploy" not in svc diff --git a/tests/test_llm_core_temperature.py b/tests/test_llm_core_temperature.py index 09abf8a..00be525 100644 --- a/tests/test_llm_core_temperature.py +++ b/tests/test_llm_core_temperature.py @@ -66,3 +66,37 @@ def test_normal_model_payload_keeps_temperature(monkeypatch): payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2) assert payload["temperature"] == 0.2 assert payload["max_tokens"] == 5 + + +def test_normal_model_payload_keeps_temperature_above_one(monkeypatch): + # OpenAI/local providers may validly use temperatures above 1.0; the clamp + # is Anthropic-only and must not touch this path. + payload = _capture_openai_payload(monkeypatch, "gpt-4o", 1.2) + assert payload["temperature"] == 1.2 + + +def _anthropic_payload(temperature): + return llm_core._build_anthropic_payload( + "claude-3-5-sonnet", + [{"role": "user", "content": "Hi"}], + temperature, + max_tokens=5, + ) + + +def test_anthropic_payload_clamps_above_one(): + # Anthropic rejects temperature > 1.0 (e.g. the Nietzsche preset's 1.2). 
+ assert _anthropic_payload(1.2)["temperature"] == 1.0 + + +def test_anthropic_payload_keeps_in_range(): + assert _anthropic_payload(0.7)["temperature"] == 0.7 + + +def test_anthropic_payload_clamps_negative(): + assert _anthropic_payload(-0.5)["temperature"] == 0.0 + + +def test_anthropic_payload_none_temperature_does_not_crash(): + payload = _anthropic_payload(None) + assert payload["temperature"] is None diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index be767e4..48d6293 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -1,6 +1,9 @@ """Tests for model route helper functions — pure logic, no server needed.""" +import asyncio +import json import sys import types +from types import SimpleNamespace from unittest.mock import MagicMock import httpx @@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver from routes.model_routes import ( _match_provider_curated, _curate_models, + _visible_models, + _normalize_model_ids, _is_chat_model, _classify_endpoint, _probe_endpoint, @@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable: monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail) assert model_routes._docker_host_gateway_reachable() is False + + +# ── pinned model IDs: normalization helper ── + + +class TestNormalizeModelIds: + def test_list_passthrough_trims_and_dedupes(self): + assert _normalize_model_ids([" a ", "a", "b", ""]) == ["a", "b"] + + def test_json_string_list(self): + assert _normalize_model_ids('["x", "y", "x"]') == ["x", "y"] + + def test_comma_and_newline_string(self): + assert _normalize_model_ids("a, b\n c ,a") == ["a", "b", "c"] + + def test_none_and_empty(self): + assert _normalize_model_ids(None) == [] + assert _normalize_model_ids("") == [] + assert _normalize_model_ids(" ") == [] + + def test_non_string_values_ignored(self): + assert _normalize_model_ids([1, "ok", None, {"a": 1}]) == ["ok"] + + +# ── pinned model IDs: _visible_models merge ── + + +class 
TestVisibleModelsPinned: + def test_includes_pinned_not_in_cached(self): + visible = _visible_models(["a"], None, ["deploy-1"]) + assert visible == ["a", "deploy-1"] + + def test_cached_plus_pinned_dedup_preserves_order(self): + visible = _visible_models(["a", "b"], None, ["b", "c"]) + assert visible == ["a", "b", "c"] + + def test_hidden_can_hide_a_pinned_model(self): + visible = _visible_models(["a"], ["deploy-1"], ["deploy-1"]) + assert visible == ["a"] + + def test_accepts_json_string_inputs(self): + visible = _visible_models('["a"]', '["a"]', '["b"]') + assert visible == ["b"] + + +# ── pinned model IDs: route behaviour ── + +# Building the router exercises FastAPI's Form() routes, which require +# python-multipart. The test env ships without it, so register a minimal stub +# (mirrors tests/test_review_regressions.py) only when it's genuinely missing. +if "python_multipart" not in sys.modules: + try: + import python_multipart # noqa: F401 + except ImportError: + _mp_stub = types.ModuleType("python_multipart") + _mp_stub.__version__ = "0.0.13" + sys.modules["python_multipart"] = _mp_stub + + +class _PinnedFakeQuery: + def __init__(self, rows): + self.rows = list(rows) + + def filter(self, *conditions): + return self + + def order_by(self, *args): + return self + + def first(self): + return self.rows[0] if self.rows else None + + def all(self): + return list(self.rows) + + +class _PinnedFakeDb: + def __init__(self, rows): + self.rows = rows + self.added = [] + self.committed = 0 + + def query(self, model): + return _PinnedFakeQuery(self.rows) + + def add(self, row): + self.added.append(row) + + def commit(self): + self.committed += 1 + + def close(self): + pass + + +class _FakeCol: + """Column stand-in: every comparison/operator just returns itself so the + dedupe query expressions evaluate without a real SQLAlchemy column.""" + + __hash__ = None + + def __eq__(self, other): + return self + + def is_(self, other): + return self + + def __or__(self, other): + 
return self + + def desc(self): + return self + + +class _RecordingEndpoint: + """ModelEndpoint stand-in that stores constructor kwargs as attributes. + + Class-level fake columns let it double as the query class in the dedupe + lookup; instance attributes (set in __init__) shadow them per-row. + """ + + id = _FakeCol() + base_url = _FakeCol() + owner = _FakeCol() + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +class _PinnedFakeRequest: + def __init__(self, body=None, headers=None): + self._body = body if body is not None else {} + self.headers = headers or {} + + async def json(self): + return self._body + + +def _get_route(path, method): + from routes.model_routes import setup_model_routes + router = setup_model_routes(model_discovery=None) + for route in router.routes: + if getattr(route, "path", "") == path and method in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError(f"{method} {path} not found") + + +def _make_endpoint(**kwargs): + base = dict( + id="ep1", + name="EP", + base_url="http://localhost:9999/v1", + api_key=None, + is_enabled=True, + hidden_models=None, + cached_models=None, + pinned_models=None, + model_type="llm", + supports_tools=None, + ) + base.update(kwargs) + return SimpleNamespace(**base) + + +def test_patch_models_saves_pinned_models(monkeypatch): + ep = _make_endpoint() + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH") + + request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}) + result = asyncio.run(endpoint("ep1", request)) + + assert json.loads(ep.pinned_models) == ["deploy-1", "deploy-2"] + assert result["pinned_count"] == 2 + + +def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch): + ep = 
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
    """GET .../models still lists a pinned ID when the probe returns nothing."""
    pinned_ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([pinned_ep])
    stubs = {
        "SessionLocal": lambda: fake_db,
        "require_admin": lambda request: None,
        # The probe reports no models, so only the pinned ID can come back.
        "_probe_endpoint": lambda *a, **k: [],
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(model_routes, attr_name, stub)
    list_models = _get_route("/api/model-endpoints/{ep_id}/models", "GET")

    rows = list_models("ep1", _PinnedFakeRequest())

    assert [row["id"] for row in rows] == ["deploy-1"]
    assert rows[0]["is_pinned"] is True
+ assert json.loads(ep.pinned_models) == ["deploy-1"] + assert json.loads(ep.cached_models) == ["m1"] + + +def test_visible_models_handles_malformed_strings(): + # Non-JSON cached/pinned strings are treated as comma/newline lists and + # never raise; a malformed hidden string is normalized too. + result = _visible_models("a,b", "b", "{bad json") + assert isinstance(result, list) + assert result == ["a", "{bad json"] + assert _visible_models("", None, "") == [] + assert _visible_models("only-cached", None, None) == ["only-cached"] + + +def _create_form_kwargs(**overrides): + """Defaults for every Form() param create_model_endpoint reads directly. + + Calling the route as a plain function bypasses FastAPI form parsing, so the + Form() sentinels must be replaced with real strings. + """ + kwargs = dict( + name="", + api_key="", + skip_probe="true", # avoid any network probe in unit tests + require_models="false", + model_type="llm", + supports_tools="", + pinned_models="", + container_local="false", + shared="true", + ) + kwargs.update(overrides) + return kwargs + + +def _patch_create_deps(monkeypatch, db): + import src.auth_helpers as auth_helpers + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint) + monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b) + monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b) + monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"}) + monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u) + monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None) + + +def test_post_creates_endpoint_with_pinned_models(monkeypatch): + db = _PinnedFakeDb([]) # no existing row → fresh create path + _patch_create_deps(monkeypatch, db) + create = _get_route("/api/model-endpoints", "POST") + + 
result = create( + _PinnedFakeRequest(), + base_url="http://host:1234/v1", + **_create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2"), + ) + + assert result["pinned_models"] == ["deploy-1", "deploy-2"] + assert result["models"] == ["deploy-1", "deploy-2"] + assert result["online"] is True + # Persisted onto the created row. + assert len(db.added) == 1 + assert json.loads(db.added[0].pinned_models) == ["deploy-1", "deploy-2"] + + +def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch): + existing = _make_endpoint( + cached_models=json.dumps(["m1"]), + hidden_models=None, + pinned_models=json.dumps(["old-pin"]), + ) + db = _PinnedFakeDb([existing]) + _patch_create_deps(monkeypatch, db) + create = _get_route("/api/model-endpoints", "POST") + + result = create( + _PinnedFakeRequest(), + base_url="http://host:1234/v1", + **_create_form_kwargs(pinned_models="new-pin"), + ) + + assert result["existing"] is True + # Incoming pin merged onto the existing pins (no clobber, order preserved). + assert json.loads(existing.pinned_models) == ["old-pin", "new-pin"] + assert result["pinned_models"] == ["old-pin", "new-pin"] + # models = cached + pinned - hidden, visible merged list. + assert result["models"] == ["m1", "old-pin", "new-pin"] + # No new row created on the dedupe path. 
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
    """POSTing for an existing row without pins leaves its stored pins intact."""
    kept_pins = ["keep-me"]
    row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        pinned_models=json.dumps(kept_pins),
    )
    fake_db = _PinnedFakeDb([row])
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")
    form = _create_form_kwargs()  # pinned_models stays at its "" default

    response = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    assert json.loads(row.pinned_models) == kept_pins
    assert response["pinned_models"] == kept_pins
    # Nothing changed on the row, so no commit should have been issued.
    assert fake_db.committed == 0
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
    # Every row is privately owned, so a lookup with no owner must fail closed
    # instead of handing back another user's endpoint.
    owned = (("bob-private", "bob"), ("alice-private", "alice"))
    private_only = [_ep(name, owner) for name, owner in owned]

    chosen = _select(private_only, None)

    assert chosen is None
def test_live_rank_path_prefers_newer_result(monkeypatch):
    # Freeze "now" so the age score is deterministic. The two candidates are
    # identical except for their age, so recency alone decides the order.
    frozen_now = datetime(2026, 1, 31)
    monkeypatch.setattr(live_ranking, "_utcnow_naive", lambda: frozen_now)
    older_url = "https://example.org/a"
    newer_url = "https://example.org/b"
    candidates = [
        {"title": "Report", "url": url, "snippet": "x", "age": age}
        for url, age in ((older_url, "2026-01-01"), (newer_url, "2026-01-29"))
    ]

    ranked = rank_search_results("report", candidates)

    assert [ranked[0]["url"], ranked[1]["url"]] == [newer_url, older_url]