Reapply "Merge branch 'main' of github.com:pewdiepie-archdaemon/odysseus"

This reverts commit cc8fe2f6e3.
This commit is contained in:
pewdiepie-archdaemon
2026-06-03 22:47:00 +09:00
parent cc8fe2f6e3
commit 6861c41580
16 changed files with 1647 additions and 225 deletions

View File

@@ -189,6 +189,20 @@ RENDER_GID=989
For NVIDIA/AMD GPU support, also read the comments in the selected overlay file: docker/gpu.nvidia.yml or docker/gpu.amd.yml. For NVIDIA/AMD GPU support, also read the comments in the selected overlay file: docker/gpu.nvidia.yml or docker/gpu.amd.yml.
**Stack-management UIs (Portainer, Coolify, Dockhand, etc.).** These tools
often accept only a single Compose file and do not reliably honor `COMPOSE_FILE`
or multiple `-f` overlays. CLI users should keep using the `COMPOSE_FILE`
overlay workflow above. For stack UIs, point the stack at one of the standalone
files instead, which bundle the base stack plus the GPU settings:
- `docker-compose.gpu-nvidia.yml` — still requires the NVIDIA Container Toolkit
on the host.
- `docker-compose.gpu-amd.yml` — still requires host ROCm/kfd/DRI setup, the
`video`/`render` group membership, and `RENDER_GID` when needed.
The base `docker-compose.yml` plus the `docker/gpu.*.yml` overlays remain the
source of truth; the standalone files mirror them for single-file deployments.
Verify after enabling either overlay: Verify after enabling either overlay:
```bash ```bash

View File

@@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base):
is_enabled = Column(Boolean, default=True) is_enabled = Column(Boolean, default=True)
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list) cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
model_type = Column(String, nullable=True, default="llm") # "llm" or "image" model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
# Whether models on this endpoint accept OpenAI-style function # Whether models on this endpoint accept OpenAI-style function
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto- # schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
@@ -856,6 +857,24 @@ def _migrate_add_cached_models_column():
except Exception as e: except Exception as e:
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}") logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
def _migrate_add_pinned_models_column():
"""Add pinned_models column to model_endpoints if it doesn't exist."""
import sqlite3
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
columns = [row[1] for row in cursor.fetchall()]
if columns and "pinned_models" not in columns:
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
def _migrate_add_notes_sort_order(): def _migrate_add_notes_sort_order():
"""Add sort_order, image_url, repeat columns to notes if they don't exist.""" """Add sort_order, image_url, repeat columns to notes if they don't exist."""
import sqlite3 import sqlite3
@@ -1511,6 +1530,7 @@ def init_db():
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
_migrate_add_hidden_models_column() _migrate_add_hidden_models_column()
_migrate_add_cached_models_column() _migrate_add_cached_models_column()
_migrate_add_pinned_models_column()
_migrate_add_notes_sort_order() _migrate_add_notes_sort_order()
_migrate_add_model_type_column() _migrate_add_model_type_column()
_migrate_add_model_endpoint_owner_column() _migrate_add_model_endpoint_owner_column()

164
docker-compose.gpu-amd.yml Normal file
View File

@@ -0,0 +1,164 @@
# Standalone AMD ROCm GPU Compose file for stack-management UIs (Portainer,
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
#
# This is equivalent to: docker-compose.yml + docker/gpu.amd.yml.
# The base docker-compose.yml plus the docker/gpu.amd.yml overlay remain the
# source of truth — CLI users should keep using the COMPOSE_FILE overlay
# workflow. Keep this file in sync with both when either changes.
#
# Requires ROCm drivers on the host (kfd + DRI devices) and the host user
# running Docker in the `video` and `render` groups. Set RENDER_GID to your
# host's numeric render group id when needed. See docker/gpu.amd.yml for details.
services:
odysseus:
build: .
ports:
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
volumes:
- ./data:/app/data:z
- ./logs:/app/logs:z
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
# add the shown public key to each remote server's authorized_keys.
- ./data/ssh:/app/.ssh:z
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
# container, so persist its HuggingFace cache under ./data/huggingface.
- ./data/huggingface:/app/.cache/huggingface:z
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
# land under /app/.local for the odysseus user. Persist them so a
# container recreate does not silently remove installed serve engines.
- ./data/local:/app/.local:z
extra_hosts:
# Lets the container reach local services on the Docker host, including
# Ollama at http://host.docker.internal:11434.
- "host.docker.internal:host-gateway"
environment:
- LLM_HOST=${LLM_HOST:-localhost}
- LLM_HOSTS=${LLM_HOSTS:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
- HF_TOKEN=${HF_TOKEN:-}
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
- SEARXNG_INSTANCE=http://searxng:8080
- CHROMADB_HOST=chromadb
- CHROMADB_PORT=8000
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
- AUTH_ENABLED=${AUTH_ENABLED:-true}
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
- SERPER_API_KEY=${SERPER_API_KEY:-}
# PUID / PGID — the user/group the container drops to before
# running uvicorn (entrypoint also chowns /app/data + /app/logs
# to match, so bind-mounted files stay editable from the host).
# 1000 is the default first user on most Linux installs. If your
# host user has a different id, override here or via .env, e.g.:
# PUID=1001
# PGID=1001
# Find yours with: id -u / id -g
- PUID=${PUID:-1000}
- PGID=${PGID:-1000}
depends_on:
searxng:
condition: service_healthy
chromadb:
condition: service_started
restart: unless-stopped
# AMD ROCm overlay (from docker/gpu.amd.yml).
devices:
- /dev/kfd
- /dev/dri
group_add:
- video
- ${RENDER_GID:-render}
chromadb:
image: docker.io/chromadb/chroma:latest
ports:
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
volumes:
- chromadb-data:/chroma/chroma
environment:
- ANONYMIZED_TELEMETRY=FALSE
restart: unless-stopped
searxng:
# Pinned, not :latest — odysseus waits on searxng's healthcheck
# (depends_on: condition: service_healthy), so a broken upstream `latest`
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
# Bump this deliberately after verifying a newer tag boots clean.
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
entrypoint:
- /bin/sh
- -c
- |
set -eu
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
secret="$${SEARXNG_SECRET:-}"
if [ -z "$$secret" ]; then
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
fi
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
fi
exec /usr/local/searxng/entrypoint.sh
ports:
- "127.0.0.1:8080:8080"
volumes:
- searxng-data:/etc/searxng
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
environment:
- SEARXNG_BASE_URL=http://localhost:8080/
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
# The official searxng image runs as the non-root `searxng` user, but its
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
# su-exec, and (with our wrapper above) write settings.yml into the named
# volume. Without these capabilities the wrapper aborts at the redirection
# with EACCES and the container fails its healthcheck with permission
# errors during setup. Mirrors the cap set recommended by the upstream
# searxng-docker compose file. See issue #721.
cap_drop:
- ALL
cap_add:
- CHOWN
- SETGID
- SETUID
- DAC_OVERRIDE
healthcheck:
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
interval: 5s
timeout: 6s
retries: 20
start_period: 10s
restart: unless-stopped
ntfy:
image: docker.io/binwiederhier/ntfy
command: serve
ports:
- "${NTFY_BIND:-127.0.0.1}:8091:80"
volumes:
- ntfy-cache:/var/cache/ntfy
environment:
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
restart: unless-stopped
volumes:
searxng-data:
chromadb-data:
ntfy-cache:

View File

@@ -0,0 +1,167 @@
# Standalone NVIDIA GPU Compose file for stack-management UIs (Portainer,
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
#
# This is equivalent to: docker-compose.yml + docker/gpu.nvidia.yml.
# The base docker-compose.yml plus the docker/gpu.nvidia.yml overlay remain
# the source of truth — CLI users should keep using the COMPOSE_FILE overlay
# workflow. Keep this file in sync with both when either changes.
#
# Requires the NVIDIA Container Toolkit on the host. See docker/gpu.nvidia.yml
# for setup details.
services:
odysseus:
build: .
ports:
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
volumes:
- ./data:/app/data:z
- ./logs:/app/logs:z
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
# add the shown public key to each remote server's authorized_keys.
- ./data/ssh:/app/.ssh:z
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
# container, so persist its HuggingFace cache under ./data/huggingface.
- ./data/huggingface:/app/.cache/huggingface:z
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
# land under /app/.local for the odysseus user. Persist them so a
# container recreate does not silently remove installed serve engines.
- ./data/local:/app/.local:z
extra_hosts:
# Lets the container reach local services on the Docker host, including
# Ollama at http://host.docker.internal:11434.
- "host.docker.internal:host-gateway"
environment:
- LLM_HOST=${LLM_HOST:-localhost}
- LLM_HOSTS=${LLM_HOSTS:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
- HF_TOKEN=${HF_TOKEN:-}
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
- SEARXNG_INSTANCE=http://searxng:8080
- CHROMADB_HOST=chromadb
- CHROMADB_PORT=8000
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
- AUTH_ENABLED=${AUTH_ENABLED:-true}
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
- SERPER_API_KEY=${SERPER_API_KEY:-}
# PUID / PGID — the user/group the container drops to before
# running uvicorn (entrypoint also chowns /app/data + /app/logs
# to match, so bind-mounted files stay editable from the host).
# 1000 is the default first user on most Linux installs. If your
# host user has a different id, override here or via .env, e.g.:
# PUID=1001
# PGID=1001
# Find yours with: id -u / id -g
- PUID=${PUID:-1000}
- PGID=${PGID:-1000}
# NVIDIA overlay (from docker/gpu.nvidia.yml).
- NVIDIA_VISIBLE_DEVICES=all
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
depends_on:
searxng:
condition: service_healthy
chromadb:
condition: service_started
restart: unless-stopped
# NVIDIA overlay (from docker/gpu.nvidia.yml).
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
chromadb:
image: docker.io/chromadb/chroma:latest
ports:
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
volumes:
- chromadb-data:/chroma/chroma
environment:
- ANONYMIZED_TELEMETRY=FALSE
restart: unless-stopped
searxng:
# Pinned, not :latest — odysseus waits on searxng's healthcheck
# (depends_on: condition: service_healthy), so a broken upstream `latest`
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
# Bump this deliberately after verifying a newer tag boots clean.
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
entrypoint:
- /bin/sh
- -c
- |
set -eu
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
secret="$${SEARXNG_SECRET:-}"
if [ -z "$$secret" ]; then
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
fi
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
fi
exec /usr/local/searxng/entrypoint.sh
ports:
- "127.0.0.1:8080:8080"
volumes:
- searxng-data:/etc/searxng
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
environment:
- SEARXNG_BASE_URL=http://localhost:8080/
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
# The official searxng image runs as the non-root `searxng` user, but its
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
# su-exec, and (with our wrapper above) write settings.yml into the named
# volume. Without these capabilities the wrapper aborts at the redirection
# with EACCES and the container fails its healthcheck with permission
# errors during setup. Mirrors the cap set recommended by the upstream
# searxng-docker compose file. See issue #721.
cap_drop:
- ALL
cap_add:
- CHOWN
- SETGID
- SETUID
- DAC_OVERRIDE
healthcheck:
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
interval: 5s
timeout: 6s
retries: 20
start_period: 10s
restart: unless-stopped
ntfy:
image: docker.io/binwiederhier/ntfy
command: serve
ports:
- "${NTFY_BIND:-127.0.0.1}:8091:80"
volumes:
- ntfy-cache:/var/cache/ntfy
environment:
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
restart: unless-stopped
volumes:
searxng-data:
chromadb-data:
ntfy-cache:

View File

@@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
return "No models found for that provider/key." return "No models found for that provider/key."
def _visible_models(cached_models, hidden_models): def _normalize_model_ids(value):
"""Filter cached model IDs by hidden_models. Returns list of visible IDs.""" """Coerce a model-ID input into a clean, ordered list of strings.
all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or [])
Accepts a list, a JSON-encoded list string, or a comma/newline separated
string (handy for form or backend API input). Trims whitespace, drops
empty and non-string values, and de-duplicates preserving first-seen order.
"""
if value is None:
return []
items = value
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
parsed = json.loads(text)
except Exception:
parsed = None
items = parsed if isinstance(parsed, list) else re.split(r"[,\n]", text)
if not isinstance(items, list):
return []
out, seen = [], set()
for item in items:
if not isinstance(item, str):
continue
s = item.strip()
if not s or s in seen:
continue
seen.add(s)
out.append(s)
return out
def _merge_model_ids(*lists):
"""Concatenate model-ID lists, de-duplicating and preserving order."""
out, seen = [], set()
for ids in lists:
for m in (ids or []):
if not isinstance(m, str) or m in seen:
continue
seen.add(m)
out.append(m)
return out
def _visible_models(cached_models, hidden_models, pinned_models=None):
"""Merge cached + pinned model IDs, then filter out hidden ones.
Pinned IDs are admin-entered and may not appear in cached_models (e.g.
cloud deployment IDs the provider does not list in /v1/models). Returns an
ordered, de-duplicated list of visible IDs.
"""
# Normalize each input so JSON strings, lists, comma/newline strings, and
# malformed strings are all handled without raising.
merged = _merge_model_ids(
_normalize_model_ids(cached_models),
_normalize_model_ids(pinned_models),
)
if not hidden_models: if not hidden_models:
return all_models return merged
hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or [])) hidden = set(_normalize_model_ids(hidden_models))
return [m for m in all_models if m not in hidden] return [m for m in merged if m not in hidden]
def setup_model_routes(model_discovery): def setup_model_routes(model_discovery):
@@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(r.hidden_models)) hidden = set(json.loads(r.hidden_models))
except Exception: except Exception:
pass pass
visible = [m for m in all_models if m not in hidden] pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
status = "online" if all_models else "offline" visible = _visible_models(all_models, r.hidden_models, pinned)
# Endpoint counts as reachable if it has any model — including
# admin-pinned IDs that a probe would never surface.
status = "online" if (all_models or pinned) else "offline"
ping = None ping = None
if not all_models and r.is_enabled: if not all_models and not pinned and r.is_enabled:
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"): if ping.get("reachable"):
status = "empty" status = "empty"
@@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery):
"has_key": bool(r.api_key), "has_key": bool(r.api_key),
"is_enabled": r.is_enabled, "is_enabled": r.is_enabled,
"models": visible, "models": visible,
"pinned_models": pinned,
"hidden_count": len(hidden), "hidden_count": len(hidden),
"online": status != "offline", "online": status != "offline",
"status": status, "status": status,
@@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery):
require_models: str = Form("false"), require_models: str = Form("false"),
model_type: str = Form("llm"), model_type: str = Form("llm"),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown) supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
container_local: str = Form("false"), container_local: str = Form("false"),
# Default `shared=true` → endpoints are visible to all users (the # Default `shared=true` → endpoints are visible to all users (the
# app's historical behaviour). Admins can pass `shared=false` to # app's historical behaviour). Admins can pass `shared=false` to
@@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery):
.first() .first()
) )
if existing: if existing:
# Persist any incoming pinned IDs onto the existing row. An
# empty/omitted form field must not wipe previously pinned IDs.
_incoming_pinned = _normalize_model_ids(pinned_models)
if _incoming_pinned:
_merged_pinned = _merge_model_ids(
_normalize_model_ids(getattr(existing, "pinned_models", None)),
_incoming_pinned,
)
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
_db_dedup.commit()
_invalidate_models_cache()
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
return { return {
"id": existing.id, "id": existing.id,
"name": existing.name, "name": existing.name,
"base_url": existing.base_url, "base_url": existing.base_url,
"models": json.loads(existing.cached_models) if existing.cached_models else [], "models": _visible_models(
getattr(existing, "cached_models", None),
getattr(existing, "hidden_models", None),
existing.pinned_models,
),
"pinned_models": _existing_pinned,
"online": True, "online": True,
"status": "online", "status": "online",
"existing": True, "existing": True,
@@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery):
try: try:
_st_raw = (supports_tools or "").strip().lower() _st_raw = (supports_tools or "").strip().lower()
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None) _st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
_pinned = _normalize_model_ids(pinned_models)
# Stamp owner so the picker only shows this endpoint to the admin # Stamp owner so the picker only shows this endpoint to the admin
# who added it. Pass `shared=true` to mark it null-owner (visible # who added it. Pass `shared=true` to mark it null-owner (visible
# to all users), preserving the pre-fix "everyone sees everything" # to all users), preserving the pre-fix "everyone sees everything"
@@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery):
is_enabled=True, is_enabled=True,
model_type=model_type.strip() if model_type else "llm", model_type=model_type.strip() if model_type else "llm",
cached_models=json.dumps(model_ids) if model_ids else None, cached_models=json.dumps(model_ids) if model_ids else None,
pinned_models=json.dumps(_pinned) if _pinned else None,
supports_tools=_st, supports_tools=_st,
owner=_owner_val, owner=_owner_val,
) )
@@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery):
"id": ep_id, "id": ep_id,
"name": name.strip(), "name": name.strip(),
"base_url": base_url, "base_url": base_url,
"models": model_ids, "models": _merge_model_ids(model_ids, _pinned),
"online": bool(model_ids) or bool(ping.get("reachable")), "pinned_models": _pinned,
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"), "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None, "ping_error": ping.get("error") if ping else None,
} }
@@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(ep.hidden_models)) hidden = set(json.loads(ep.hidden_models))
except Exception: except Exception:
pass pass
# Try live probe, fall back to cached # Try live probe, fall back to cached. Pinned IDs are admin-entered
# and persist regardless of probe results — never overwritten here.
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
if all_models: if all_models:
ep.cached_models = json.dumps(all_models) ep.cached_models = json.dumps(all_models)
@@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery):
all_models = json.loads(ep.cached_models) all_models = json.loads(ep.cached_models)
except Exception: except Exception:
pass pass
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
pinned_set = set(pinned)
return [ return [
{"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden} {
for m in all_models "id": m,
"display": m.split("/")[-1],
"is_hidden": m in hidden,
"is_pinned": m in pinned_set,
}
for m in _merge_model_ids(all_models, pinned)
] ]
finally: finally:
db.close() db.close()
@router.patch("/model-endpoints/{ep_id}/models") @router.patch("/model-endpoints/{ep_id}/models")
async def update_hidden_models(ep_id: str, request: Request): async def update_hidden_models(ep_id: str, request: Request):
"""Bulk update hidden models list for an endpoint. """Bulk update hidden and/or pinned model lists for an endpoint.
Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]} Expects JSON body with optional keys:
{"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]}
Each key is updated only when present, so callers can patch one list
without clobbering the other.
""" """
require_admin(request) require_admin(request)
db = SessionLocal() db = SessionLocal()
@@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery):
if not ep: if not ep:
raise HTTPException(404, "Endpoint not found") raise HTTPException(404, "Endpoint not found")
body = await request.json() body = await request.json()
hidden = body.get("hidden", []) if not isinstance(body, dict):
if not isinstance(hidden, list): raise HTTPException(400, "Body must be a JSON object")
raise HTTPException(400, "hidden must be a list of model IDs") if "hidden" in body:
ep.hidden_models = json.dumps(hidden) if hidden else None hidden = body.get("hidden")
if not isinstance(hidden, list):
raise HTTPException(400, "hidden must be a list of model IDs")
ep.hidden_models = json.dumps(hidden) if hidden else None
# Accept either "pinned" or "pinned_models" for the manual IDs list.
if "pinned_models" in body or "pinned" in body:
pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned")))
ep.pinned_models = json.dumps(pinned) if pinned else None
db.commit() db.commit()
_invalidate_models_cache() _invalidate_models_cache()
return {"id": ep_id, "hidden_count": len(hidden)} hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0
pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0
return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count}
finally: finally:
db.close() db.close()
@@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery):
return {"endpoint_id": "", "endpoint_url": "", "model": ""} return {"endpoint_id": "", "endpoint_url": "", "model": ""}
base = _normalize_base(ep.base_url) base = _normalize_base(ep.base_url)
chat_url = build_chat_url(base) chat_url = build_chat_url(base)
if not model and getattr(ep, "cached_models", None): if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)):
try: try:
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None)) visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None))
if visible: if visible:
model = visible[0] model = visible[0]
except Exception: except Exception:
@@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery):
ep.name = body["name"].strip() or ep.name ep.name = body["name"].strip() or ep.name
if "model_type" in body and isinstance(body["model_type"], str): if "model_type" in body and isinstance(body["model_type"], str):
ep.model_type = body["model_type"].strip() or ep.model_type ep.model_type = body["model_type"].strip() or ep.model_type
if "pinned_models" in body:
_pinned = _normalize_model_ids(body["pinned_models"])
ep.pinned_models = json.dumps(_pinned) if _pinned else None
# Rotating an API key used to require DELETE+POST, which wiped # Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base # endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key # URL. Allow in-place updates so the admin can change the key
@@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery):
"name": ep.name, "name": ep.name,
"model_type": ep.model_type, "model_type": ep.model_type,
"base_url": ep.base_url, "base_url": ep.base_url,
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
} }
finally: finally:
db.close() db.close()

View File

@@ -9,7 +9,9 @@ import httpx
from fastapi import APIRouter, HTTPException, Request, Form from fastapi import APIRouter, HTTPException, Request, Form
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.database import SessionLocal, Webhook from core.database import SessionLocal, Webhook, ModelEndpoint
from src.auth_helpers import owner_filter
from src.url_security import validate_public_http_url
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000
from core.middleware import require_admin as _require_admin from core.middleware import require_admin as _require_admin
def _first_enabled_endpoint(db, owner): def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]):
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus """First enabled ModelEndpoint visible to token_owner — their own rows plus
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint legacy null-owner ("shared") rows. Owner-scoped: an unscoped .first() would
is per-user (core/database.py — "when non-null, the model picker only shows let a chat-scoped token fall back onto another user's private endpoint and
the endpoint to that user"), and the sync-chat fallback uses the row's silently spend that owner's API key/quota. Prefer owner rows before shared
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token rows. Fails closed to null-owner rows only when token_owner is absent.
(e.g. a paired mobile device) fall back onto ANOTHER user's private Does not validate base_url — admin-configured local/LAN endpoints remain allowed.
endpoint and silently spend that owner's API key / quota — and reach
whatever internal base_url they configured. Mirrors the owner_filter scoping
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
no-op (single-user / legacy mode), preserving the original behaviour.
""" """
from core.database import ModelEndpoint query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
from src.auth_helpers import owner_filter if token_owner:
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712 query = owner_filter(query, ModelEndpoint, token_owner)
q = owner_filter(q, ModelEndpoint, owner) return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first()
return q.first() return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711
def _caller_owns_session(sess_owner, caller) -> bool: def _caller_owns_session(sess_owner, caller) -> bool:
@@ -278,15 +276,21 @@ def setup_webhook_routes(
api_key = body.api_key.strip() api_key = body.api_key.strip()
model = body.model or "deepseek-chat" model = body.model or "deepseek-chat"
# Resolve base_url: explicit > provider name > model prefix auto-detect # Validate only token-supplied direct base_url; auto-resolved known-provider
base_url = body.base_url.strip().rstrip("/") if body.base_url else None # URLs are not subject to extra local/LAN blocking beyond existing provider logic.
if not base_url: direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None
if direct_base_url:
try:
base_url = validate_public_http_url(direct_base_url)
except ValueError as e:
detail = str(e).replace("URL", "base_url", 1)
raise HTTPException(400, detail)
else:
base_url = _resolve_base_url(model, body.provider) base_url = _resolve_base_url(model, body.provider)
if not base_url: if not base_url:
raise HTTPException(400, raise HTTPException(400,
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') " "Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
"or provider ('deepseek', 'openai', 'groq', etc.)") "or provider ('deepseek', 'openai', 'groq', etc.)")
base_url = normalize_base(base_url) base_url = normalize_base(base_url)
endpoint_url = build_chat_url(base_url) endpoint_url = build_chat_url(base_url)
@@ -306,9 +310,7 @@ def setup_webhook_routes(
if not sess: if not sess:
db = SessionLocal() db = SessionLocal()
try: try:
# Owner-scoped: only THIS token owner's endpoints + legacy ep = _select_api_chat_fallback_endpoint(db, token_owner)
# shared rows, never another user's private endpoint/api_key.
ep = _first_enabled_endpoint(db, token_owner)
finally: finally:
db.close() db.close()

View File

@@ -2,12 +2,49 @@
import re import re
import logging import logging
from datetime import datetime from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
def _utcnow_naive() -> datetime:
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
The age is measured against UTC, not local time. The previous code used
``datetime.now()`` (local) against UTC-style published dates, so the age was
skewed by the host's UTC offset; it was also a latent crash once neighbouring
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
"""
if not age_str:
return 0.0
dt = None
for fmt in _AGE_FORMATS:
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
now = now or _utcnow_naive()
days_old = (now - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"} _NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
_SPORTS_HINTS = { _SPORTS_HINTS = {
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb", "sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
@@ -73,24 +110,6 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
return 0.7 return 0.7
return 0.4 return 0.4
def recency_score(age_str: Optional[str]) -> float:
if not age_str:
return 0.0
for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
days_old = (datetime.now() - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
def news_quality_adjustment(title: str, snippet: str, url: str) -> float: def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
if not is_news_query: if not is_news_query:
return 0.0 return 0.0

View File

@@ -1,151 +1,13 @@
"""Search result ranking based on relevance, source quality, and recency.""" """Compatibility re-export shim for the live ranking module.
import re The real implementation lives in :mod:`services.search.ranking`, which is what
import logging the search runtime (services/search/core.py) imports. This module used to hold a
from datetime import datetime, timezone parallel copy; it now re-exports so the two cannot drift out of sync again.
from typing import List, Optional """
from urllib.parse import urlparse
logger = logging.getLogger(__name__) from services.search.ranking import ( # noqa: F401
_AGE_FORMATS,
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S") _utcnow_naive,
rank_search_results,
recency_score,
def _utcnow_naive() -> datetime:
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
The age is measured against UTC, not local time. The previous code used
``datetime.now()`` (local) against UTC-style published dates, so the age was
skewed by the host's UTC offset; it was also a latent crash once neighbouring
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
"""
if not age_str:
return 0.0
dt = None
for fmt in _AGE_FORMATS:
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
now = now or _utcnow_naive()
days_old = (now - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
_SPORTS_HINTS = {
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
}
# Word-boundary match so "sport" does not fire inside "transport"/"passport"
# and a domain like "transport.gov" is not mistaken for a sports site.
_SPORTS_HINT_RE = re.compile(
r"\b(?:" + "|".join(re.escape(h) for h in _SPORTS_HINTS) + r")\b"
) )
_LOW_VALUE_NEWS_DOMAINS = {
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
"www.yahoo.com", "msn.com", "www.msn.com",
}
_TRUSTED_NEWS_DOMAINS = {
"apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com",
"bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca",
"ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca",
"theguardian.com",
"www.theguardian.com", "euronews.com", "www.euronews.com",
"dw.com", "www.dw.com", "government.se", "www.government.se",
}
def _domain(url: str) -> str:
try:
return urlparse(url).netloc.lower()
except Exception:
return ""
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
query_lc = query.lower()
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
is_sports_query = bool(_SPORTS_HINT_RE.search(query_lc))
def title_score(title: str) -> float:
if not title:
return 0.0
title_lc = title.lower()
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
return matches / len(query_terms) if query_terms else 0.0
def snippet_score(snippet: str) -> float:
if not snippet:
return 0.0
length_factor = min(len(snippet), 200) / 200
term_hits = sum(1 for term in query_terms if term in snippet.lower())
term_factor = term_hits / len(query_terms) if query_terms else 0.0
return (length_factor + term_factor) / 2
def domain_score(url: str) -> float:
netloc = _domain(url)
if not netloc:
return 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
return 1.0
if netloc.endswith(".edu") or netloc.endswith(".gov"):
return 1.0
if netloc.endswith(".org"):
return 0.7
return 0.4
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
if not is_news_query:
return 0.0
text = f"{title} {snippet}".lower()
netloc = _domain(url)
adjustment = 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
adjustment += 1.2
if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")):
adjustment += 0.4
if netloc in _LOW_VALUE_NEWS_DOMAINS:
adjustment -= 0.8
if not is_sports_query and (_SPORTS_HINT_RE.search(text) or _SPORTS_HINT_RE.search(netloc)):
adjustment -= 1.5
# A country/news query should not rank a page whose title/snippet barely
# mentions the country above actual news pages for that country.
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
adjustment -= 1.0
return adjustment
ranked = []
for result in results:
title = result.get("title", "")
snippet = result.get("snippet", "")
url = result.get("url", "")
age = result.get("age", None)
score = (
2.0 * title_score(title)
+ 1.0 * snippet_score(snippet)
+ 1.5 * domain_score(url)
+ 1.0 * recency_score(age)
+ news_quality_adjustment(title, snippet, url)
)
ranked.append((score, result))
ranked.sort(key=lambda x: x[0], reverse=True)
return [r for _, r in ranked]

94
src/url_security.py Normal file
View File

@@ -0,0 +1,94 @@
"""URL validation helpers for server-side outbound requests."""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
_INTERNAL_HOSTNAMES = {
"localhost",
"metadata",
"metadata.google.internal",
}
_INTERNAL_SUFFIXES = (
".localhost",
".local",
".internal",
".lan",
".intranet",
)
_BLOCKED_NETWORKS = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
)
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
ips: list[ipaddress._BaseAddress] = []
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
if family in (socket.AF_INET, socket.AF_INET6):
ips.append(ipaddress.ip_address(sockaddr[0]))
return ips
def _blocked_ip(addr: ipaddress._BaseAddress) -> bool:
return (
any(addr in net for net in _BLOCKED_NETWORKS)
or addr.is_private
or addr.is_loopback
or addr.is_link_local
or addr.is_multicast
or addr.is_unspecified
or addr.is_reserved
)
def _host_resolves_publicly(hostname: str) -> bool:
host = hostname.strip().lower()
if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES):
return False
try:
return not _blocked_ip(ipaddress.ip_address(host))
except ValueError:
pass
try:
addrs = _resolve_hostname_ips(host)
except OSError:
return False
return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs)
def is_public_http_url(url: str) -> bool:
parsed = urlparse((url or "").strip())
if parsed.scheme not in ("http", "https") or not parsed.hostname:
return False
return _host_resolves_publicly(parsed.hostname)
def validate_public_http_url(url: str, *, max_length: int = 2048) -> str:
"""Validate a user/API-token supplied server-side HTTP(S) endpoint.
This is for untrusted outbound URLs, not admin-created model endpoints
that are intentionally allowed to point at private model providers. DNS
failures fail closed, and DNS checks reduce obvious private-network
targets but do not eliminate every DNS rebinding race by themselves.
"""
cleaned = (url or "").strip()
if len(cleaned) > max_length:
raise ValueError("URL is too long")
if not is_public_http_url(cleaned):
raise ValueError("URL must point to a public HTTP(S) endpoint")
return cleaned

View File

@@ -0,0 +1,401 @@
import ipaddress
import importlib.util
import sys
import types
from pathlib import Path
import pytest
@pytest.mark.parametrize("url", [
"http://127.0.0.1:8000/v1",
"http://localhost:8000/v1",
"http://10.0.0.5/v1",
"http://172.16.0.1/v1",
"http://192.168.1.2/v1",
"http://169.254.169.254/latest/meta-data/",
"http://metadata.google.internal/",
"http://[::1]:8000/v1",
"http://[fc00::1]/v1",
"http://224.0.0.1/v1",
"http://0.0.0.0/v1",
"file:///etc/passwd",
])
def test_public_url_validator_blocks_internal_targets(url):
from src.url_security import is_public_http_url
assert is_public_http_url(url) is False
def test_public_url_validator_allows_public_endpoint(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1"
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("10.0.0.5")],
)
with pytest.raises(ValueError):
url_security.validate_public_http_url("https://api.example.com/v1")
def _load_webhook_routes_for_test(monkeypatch):
# Load under a unique module name so each test gets a fresh module object
# rather than a cached one from a previous monkeypatch run.
core_pkg = types.ModuleType("core")
core_pkg.__path__ = []
core_db = types.ModuleType("core.database")
core_db.SessionLocal = object
core_db.Webhook = object
core_db.ModelEndpoint = object
core_middleware = types.ModuleType("core.middleware")
core_middleware.require_admin = lambda request: None
webhook_manager = types.ModuleType("src.webhook_manager")
webhook_manager.WebhookManager = object
webhook_manager.validate_webhook_url = lambda url: url
webhook_manager.validate_events = lambda events: events
monkeypatch.setitem(sys.modules, "core", core_pkg)
monkeypatch.setitem(sys.modules, "core.database", core_db)
monkeypatch.setitem(sys.modules, "core.middleware", core_middleware)
monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager)
module_name = "routes.webhook_routes_under_test"
spec = importlib.util.spec_from_file_location(
module_name,
Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py",
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
class _Expr:
def __init__(self, fn):
self.fn = fn
def __call__(self, row):
return self.fn(row)
def __or__(self, other):
return _Expr(lambda row: self(row) or other(row))
class _Column:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return _Expr(lambda row: getattr(row, self.name) == other)
def desc(self):
return ("desc", self.name)
class _ModelEndpoint:
is_enabled = _Column("is_enabled")
owner = _Column("owner")
created_at = _Column("created_at")
class _Endpoint:
def __init__(
self,
*,
owner,
is_enabled=True,
created_at=1,
base_url="https://api.example.com/v1",
api_key=None,
):
self.owner = owner
self.is_enabled = is_enabled
self.created_at = created_at
self.base_url = base_url
self.api_key = api_key
class _EndpointQuery:
def __init__(self, rows):
self.rows = rows
self.filters = []
self.orders = []
def filter(self, *exprs):
self.filters.extend(exprs)
return self
def order_by(self, *exprs):
self.orders.extend(exprs)
return self
def first(self):
rows = self.rows
for expr in self.filters:
rows = [row for row in rows if expr(row)]
# Apply sort keys right-to-left so the leftmost key ends up as the
# primary sort (stable-sort reversal idiom mirrors SQLAlchemy's
# multi-column ORDER BY behaviour).
for order in reversed(self.orders):
reverse = False
name = getattr(order, "name", None)
if isinstance(order, tuple) and order[0] == "desc":
reverse = True
name = order[1]
rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse)
if name != "owner":
rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse)
return rows[0] if rows else None
class _DB:
def __init__(self, rows):
self.query_obj = _EndpointQuery(rows)
self.closed = False
def query(self, model):
assert model is _ModelEndpoint
return self.query_obj
def close(self):
self.closed = True
class _ChatSession:
def __init__(self, endpoint_url, model):
self.endpoint_url = endpoint_url
self.model = model
self.headers = {}
self.history = []
def add_message(self, message):
self.history.append(message)
class _SessionManager:
def __init__(self):
self.created = []
self.save_calls = 0
def create_session(self, *, session_id, name, endpoint_url, model, owner):
session = _ChatSession(endpoint_url, model)
self.created.append({
"session_id": session_id,
"name": name,
"endpoint_url": endpoint_url,
"model": model,
"owner": owner,
"session": session,
})
return session
def save_sessions(self):
self.save_calls += 1
class _Request:
def __init__(self, *, owner="alice"):
self.state = types.SimpleNamespace(
api_token=True,
api_token_scopes=["chat"],
api_token_owner=owner,
)
class _WebhookManager:
async def fire(self, event, payload):
return None
def _install_sync_chat_stubs(monkeypatch):
# FastAPI checks for python_multipart at import time when Form is used;
# stub it so the optional dependency is not required in the test environment.
python_multipart = types.ModuleType("python_multipart")
python_multipart.__version__ = "0.0.13"
core_models = types.ModuleType("core.models")
class _ChatMessage:
def __init__(self, role, content):
self.role = role
self.content = content
async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
return "mocked response"
endpoint_resolver = types.ModuleType("src.endpoint_resolver")
endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/")
endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions"
endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models"
endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"}
llm_core = types.ModuleType("src.llm_core")
llm_core.llm_call_async = _llm_call_async
core_models.ChatMessage = _ChatMessage
monkeypatch.setitem(sys.modules, "python_multipart", python_multipart)
monkeypatch.setitem(sys.modules, "core.models", core_models)
monkeypatch.setitem(sys.modules, "src.llm_core", llm_core)
monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver)
def _sync_chat_endpoint(webhook_routes, session_manager):
router = webhook_routes.setup_webhook_routes(
_WebhookManager(),
auth_manager=None,
session_manager=session_manager,
)
for route in router.routes:
if route.path == "/api/v1/chat":
return route.endpoint
raise AssertionError("sync chat route not found")
@pytest.mark.parametrize("base_url", [
"http://127.0.0.1:11434/v1",
"http://localhost:11434/v1",
"http://10.0.0.5/v1",
"http://169.254.169.254/latest/meta-data/",
])
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url=base_url,
model="test-model",
provider=None,
session=None,
)
with pytest.raises(webhook_routes.HTTPException) as exc:
await sync_chat(_Request(), body)
assert exc.value.status_code == 400
assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint"
assert session_manager.created == []
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url="https://api.example.com/v1",
model="test-model",
provider=None,
session=None,
)
response = await sync_chat(_Request(), body)
assert response["response"] == "mocked response"
assert response["model"] == "test-model"
assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions"
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", is_enabled=False, created_at=0),
_Endpoint(owner="bob", created_at=0),
_Endpoint(owner=None, created_at=1),
_Endpoint(owner="alice", created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice")
assert selected.owner == "alice"
assert selected.is_enabled is True
assert selected.created_at == 2
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", created_at=0),
_Endpoint(owner=None, is_enabled=False, created_at=1),
_Endpoint(owner=None, created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None)
assert selected.owner is None
assert selected.is_enabled is True
assert selected.created_at == 2
@pytest.mark.asyncio
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
local_endpoint = _Endpoint(
owner=None,
base_url="http://localhost:11434/v1",
api_key="configured-key",
)
db = _DB([local_endpoint])
calls = []
def _session_local():
return db
def _validate_public_http_url(url, *, max_length=2048):
calls.append(url)
raise AssertionError("configured fallback endpoint should not be publicly validated")
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local)
monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
model="local-model",
api_key=None,
base_url=None,
provider=None,
session=None,
)
response = await sync_chat(_Request(owner=None), body)
assert response["response"] == "mocked response"
assert response["model"] == "local-model"
assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
assert calls == []

View File

@@ -0,0 +1,147 @@
"""Guards the standalone GPU compose files against drift.
Stack-management UIs (Portainer, Coolify, Dockhand, ...) often accept only a
single compose file and do not honor COMPOSE_FILE or multiple ``-f`` overlays,
so the repo ships standalone ``docker-compose.gpu-*.yml`` files that inline the
GPU overlay. The base ``docker-compose.yml`` plus ``docker/gpu.*.yml`` overlays
remain the source of truth; these tests assert each standalone file equals the
base compose with only the matching overlay merged into the ``odysseus``
service. No Docker / docker compose is required — everything is pure YAML.
"""
import copy
from pathlib import Path
import pytest
import yaml
ROOT = Path(__file__).resolve().parents[1]
BASE = ROOT / "docker-compose.yml"
NVIDIA_OVERLAY = ROOT / "docker" / "gpu.nvidia.yml"
AMD_OVERLAY = ROOT / "docker" / "gpu.amd.yml"
NVIDIA_STANDALONE = ROOT / "docker-compose.gpu-nvidia.yml"
AMD_STANDALONE = ROOT / "docker-compose.gpu-amd.yml"
SERVICE = "odysseus"
def _load(path: Path) -> dict:
return yaml.safe_load(path.read_text(encoding="utf-8"))
def _deep_merge(base: dict, overlay: dict) -> dict:
"""Mirror docker compose overlay semantics for the keys these files use.
Mappings merge recursively; list-valued service fields are concatenated
(compose appends override sequences such as ``environment`` rather than
replacing them); scalars are overwritten. The overlays here only append to
``environment`` and add otherwise-absent keys (``deploy``, ``devices``,
``group_add``), so this keeps the expected merge explicit without invoking
docker compose.
"""
result = copy.deepcopy(base)
for key, value in overlay.items():
if isinstance(value, dict) and isinstance(result.get(key), dict):
result[key] = _deep_merge(result[key], value)
elif isinstance(value, list) and isinstance(result.get(key), list):
result[key] = copy.deepcopy(result[key]) + copy.deepcopy(value)
else:
result[key] = copy.deepcopy(value)
return result
def _merge_overlay_into_base(base: dict, overlay: dict) -> dict:
"""Build the expected standalone config: base + overlay on odysseus only."""
expected = copy.deepcopy(base)
overlay_service = overlay["services"][SERVICE]
expected["services"][SERVICE] = _deep_merge(
expected["services"][SERVICE], overlay_service
)
return expected
@pytest.fixture(scope="module")
def base():
return _load(BASE)
# --- Equivalence: standalone == base + overlay -----------------------------
def test_nvidia_standalone_equals_base_plus_overlay(base):
overlay = _load(NVIDIA_OVERLAY)
standalone = _load(NVIDIA_STANDALONE)
assert standalone == _merge_overlay_into_base(base, overlay)
def test_amd_standalone_equals_base_plus_overlay(base):
overlay = _load(AMD_OVERLAY)
standalone = _load(AMD_STANDALONE)
assert standalone == _merge_overlay_into_base(base, overlay)
# --- Non-odysseus services and volumes untouched ---------------------------
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_non_odysseus_services_match_base(base, standalone_path):
standalone = _load(standalone_path)
for name, definition in base["services"].items():
if name == SERVICE:
continue
assert standalone["services"][name] == definition
assert set(standalone["services"]) == set(base["services"])
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_top_level_volumes_match_base(base, standalone_path):
standalone = _load(standalone_path)
assert standalone.get("volumes") == base.get("volumes")
# --- odysseus = base service + only the overlay additions ------------------
def test_nvidia_odysseus_adds_only_overlay(base):
standalone = _load(NVIDIA_STANDALONE)
svc = standalone["services"][SERVICE]
base_svc = base["services"][SERVICE]
# Base environment preserved, plus exactly the two NVIDIA variables.
assert "NVIDIA_VISIBLE_DEVICES=all" in svc["environment"]
assert "NVIDIA_DRIVER_CAPABILITIES=compute,utility" in svc["environment"]
added_env = set(svc["environment"]) - set(base_svc["environment"])
assert added_env == {
"NVIDIA_VISIBLE_DEVICES=all",
"NVIDIA_DRIVER_CAPABILITIES=compute,utility",
}
# deploy block is new and matches the overlay's GPU reservation exactly.
assert "deploy" not in base_svc
devices = svc["deploy"]["resources"]["reservations"]["devices"]
assert devices == [
{"driver": "nvidia", "count": "all", "capabilities": ["gpu"]}
]
# No AMD-only keys leaked in.
assert "devices" not in svc
assert "group_add" not in svc
def test_amd_odysseus_adds_only_overlay(base):
standalone = _load(AMD_STANDALONE)
svc = standalone["services"][SERVICE]
base_svc = base["services"][SERVICE]
# Environment is unchanged from base for AMD.
assert svc["environment"] == base_svc["environment"]
# devices and group_add are new and match the overlay exactly.
assert "devices" not in base_svc
assert "group_add" not in base_svc
assert svc["devices"] == ["/dev/kfd", "/dev/dri"]
assert svc["group_add"] == ["video", "${RENDER_GID:-render}"]
# No NVIDIA-only keys leaked in.
assert "deploy" not in svc

View File

@@ -66,3 +66,37 @@ def test_normal_model_payload_keeps_temperature(monkeypatch):
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2) payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2)
assert payload["temperature"] == 0.2 assert payload["temperature"] == 0.2
assert payload["max_tokens"] == 5 assert payload["max_tokens"] == 5
def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
# OpenAI/local providers may validly use temperatures above 1.0; the clamp
# is Anthropic-only and must not touch this path.
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 1.2)
assert payload["temperature"] == 1.2
def _anthropic_payload(temperature):
return llm_core._build_anthropic_payload(
"claude-3-5-sonnet",
[{"role": "user", "content": "Hi"}],
temperature,
max_tokens=5,
)
def test_anthropic_payload_clamps_above_one():
# Anthropic rejects temperature > 1.0 (e.g. the Nietzsche preset's 1.2).
assert _anthropic_payload(1.2)["temperature"] == 1.0
def test_anthropic_payload_keeps_in_range():
assert _anthropic_payload(0.7)["temperature"] == 0.7
def test_anthropic_payload_clamps_negative():
assert _anthropic_payload(-0.5)["temperature"] == 0.0
def test_anthropic_payload_none_temperature_does_not_crash():
payload = _anthropic_payload(None)
assert payload["temperature"] is None

View File

@@ -1,6 +1,9 @@
"""Tests for model route helper functions — pure logic, no server needed.""" """Tests for model route helper functions — pure logic, no server needed."""
import asyncio
import json
import sys import sys
import types import types
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import httpx import httpx
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
from routes.model_routes import ( from routes.model_routes import (
_match_provider_curated, _match_provider_curated,
_curate_models, _curate_models,
_visible_models,
_normalize_model_ids,
_is_chat_model, _is_chat_model,
_classify_endpoint, _classify_endpoint,
_probe_endpoint, _probe_endpoint,
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail) monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
assert model_routes._docker_host_gateway_reachable() is False assert model_routes._docker_host_gateway_reachable() is False
# ── pinned model IDs: normalization helper ──
class TestNormalizeModelIds:
def test_list_passthrough_trims_and_dedupes(self):
assert _normalize_model_ids([" a ", "a", "b", ""]) == ["a", "b"]
def test_json_string_list(self):
assert _normalize_model_ids('["x", "y", "x"]') == ["x", "y"]
def test_comma_and_newline_string(self):
assert _normalize_model_ids("a, b\n c ,a") == ["a", "b", "c"]
def test_none_and_empty(self):
assert _normalize_model_ids(None) == []
assert _normalize_model_ids("") == []
assert _normalize_model_ids(" ") == []
def test_non_string_values_ignored(self):
assert _normalize_model_ids([1, "ok", None, {"a": 1}]) == ["ok"]
# ── pinned model IDs: _visible_models merge ──
class TestVisibleModelsPinned:
def test_includes_pinned_not_in_cached(self):
visible = _visible_models(["a"], None, ["deploy-1"])
assert visible == ["a", "deploy-1"]
def test_cached_plus_pinned_dedup_preserves_order(self):
visible = _visible_models(["a", "b"], None, ["b", "c"])
assert visible == ["a", "b", "c"]
def test_hidden_can_hide_a_pinned_model(self):
visible = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
assert visible == ["a"]
def test_accepts_json_string_inputs(self):
visible = _visible_models('["a"]', '["a"]', '["b"]')
assert visible == ["b"]
# ── pinned model IDs: route behaviour ──
# Building the router exercises FastAPI's Form() routes, which require
# python-multipart. The test env ships without it, so register a minimal stub
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
if "python_multipart" not in sys.modules:
try:
import python_multipart # noqa: F401
except ImportError:
_mp_stub = types.ModuleType("python_multipart")
_mp_stub.__version__ = "0.0.13"
sys.modules["python_multipart"] = _mp_stub
class _PinnedFakeQuery:
def __init__(self, rows):
self.rows = list(rows)
def filter(self, *conditions):
return self
def order_by(self, *args):
return self
def first(self):
return self.rows[0] if self.rows else None
def all(self):
return list(self.rows)
class _PinnedFakeDb:
def __init__(self, rows):
self.rows = rows
self.added = []
self.committed = 0
def query(self, model):
return _PinnedFakeQuery(self.rows)
def add(self, row):
self.added.append(row)
def commit(self):
self.committed += 1
def close(self):
pass
class _FakeCol:
"""Column stand-in: every comparison/operator just returns itself so the
dedupe query expressions evaluate without a real SQLAlchemy column."""
__hash__ = None
def __eq__(self, other):
return self
def is_(self, other):
return self
def __or__(self, other):
return self
def desc(self):
return self
class _RecordingEndpoint:
"""ModelEndpoint stand-in that stores constructor kwargs as attributes.
Class-level fake columns let it double as the query class in the dedupe
lookup; instance attributes (set in __init__) shadow them per-row.
"""
id = _FakeCol()
base_url = _FakeCol()
owner = _FakeCol()
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class _PinnedFakeRequest:
def __init__(self, body=None, headers=None):
self._body = body if body is not None else {}
self.headers = headers or {}
async def json(self):
return self._body
def _get_route(path, method):
from routes.model_routes import setup_model_routes
router = setup_model_routes(model_discovery=None)
for route in router.routes:
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
return route.endpoint
raise AssertionError(f"{method} {path} not found")
def _make_endpoint(**kwargs):
base = dict(
id="ep1",
name="EP",
base_url="http://localhost:9999/v1",
api_key=None,
is_enabled=True,
hidden_models=None,
cached_models=None,
pinned_models=None,
model_type="llm",
supports_tools=None,
)
base.update(kwargs)
return SimpleNamespace(**base)
def test_patch_models_saves_pinned_models(monkeypatch):
ep = _make_endpoint()
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]})
result = asyncio.run(endpoint("ep1", request))
assert json.loads(ep.pinned_models) == ["deploy-1", "deploy-2"]
assert result["pinned_count"] == 2
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
ep = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1"]})
asyncio.run(endpoint("ep1", request))
assert json.loads(ep.hidden_models) == ["hide-me"]
assert json.loads(ep.pinned_models) == ["deploy-1"]
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
result = endpoint("ep1", _PinnedFakeRequest())
ids = [row["id"] for row in result]
assert ids == ["deploy-1"]
assert result[0]["is_pinned"] is True
def test_reprobe_preserves_pinned_models(monkeypatch):
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
monkeypatch.setattr(
model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
)
endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
response = endpoint("ep1", _PinnedFakeRequest())
async def _drain():
async for _ in response.body_iterator:
pass
asyncio.run(_drain())
# Probe rewrites cached/hidden but must never touch admin-pinned IDs.
assert json.loads(ep.pinned_models) == ["deploy-1"]
assert json.loads(ep.cached_models) == ["m1"]
def test_visible_models_handles_malformed_strings():
# Non-JSON cached/pinned strings are treated as comma/newline lists and
# never raise; a malformed hidden string is normalized too.
result = _visible_models("a,b", "b", "{bad json")
assert isinstance(result, list)
assert result == ["a", "{bad json"]
assert _visible_models("", None, "") == []
assert _visible_models("only-cached", None, None) == ["only-cached"]
def _create_form_kwargs(**overrides):
"""Defaults for every Form() param create_model_endpoint reads directly.
Calling the route as a plain function bypasses FastAPI form parsing, so the
Form() sentinels must be replaced with real strings.
"""
kwargs = dict(
name="",
api_key="",
skip_probe="true", # avoid any network probe in unit tests
require_models="false",
model_type="llm",
supports_tools="",
pinned_models="",
container_local="false",
shared="true",
)
kwargs.update(overrides)
return kwargs
def _patch_create_deps(monkeypatch, db):
import src.auth_helpers as auth_helpers
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint)
monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b)
monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b)
monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"})
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
db = _PinnedFakeDb([]) # no existing row → fresh create path
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2"),
)
assert result["pinned_models"] == ["deploy-1", "deploy-2"]
assert result["models"] == ["deploy-1", "deploy-2"]
assert result["online"] is True
# Persisted onto the created row.
assert len(db.added) == 1
assert json.loads(db.added[0].pinned_models) == ["deploy-1", "deploy-2"]
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
existing = _make_endpoint(
cached_models=json.dumps(["m1"]),
hidden_models=None,
pinned_models=json.dumps(["old-pin"]),
)
db = _PinnedFakeDb([existing])
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(pinned_models="new-pin"),
)
assert result["existing"] is True
# Incoming pin merged onto the existing pins (no clobber, order preserved).
assert json.loads(existing.pinned_models) == ["old-pin", "new-pin"]
assert result["pinned_models"] == ["old-pin", "new-pin"]
# models = cached + pinned - hidden, visible merged list.
assert result["models"] == ["m1", "old-pin", "new-pin"]
# No new row created on the dedupe path.
assert db.added == []
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
existing = _make_endpoint(
cached_models=json.dumps(["m1"]),
pinned_models=json.dumps(["keep-me"]),
)
db = _PinnedFakeDb([existing])
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(), # pinned_models defaults to ""
)
assert json.loads(existing.pinned_models) == ["keep-me"]
assert result["pinned_models"] == ["keep-me"]
assert db.committed == 0 # nothing to persist

View File

@@ -247,10 +247,14 @@ class _Column:
def __eq__(self, value): def __eq__(self, value):
return _Predicate(lambda row: getattr(row, self.name) == value) return _Predicate(lambda row: getattr(row, self.name) == value)
def desc(self):
return self
class _ModelEndpoint: class _ModelEndpoint:
is_enabled = _Column("is_enabled") is_enabled = _Column("is_enabled")
owner = _Column("owner") owner = _Column("owner")
created_at = _Column("created_at")
class _Query: class _Query:
@@ -261,6 +265,9 @@ class _Query:
self._rows = [r for r in self._rows if all(p(r) for p in predicates)] self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
return self return self
def order_by(self, *exprs):
return self
def first(self): def first(self):
return self._rows[0] if self._rows else None return self._rows[0] if self._rows else None
@@ -280,8 +287,10 @@ def _ep(name, owner, *, is_enabled=True):
def _select(rows, owner): def _select(rows, owner):
wh_mod = _import_webhook_helper() wh_mod = _import_webhook_helper()
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint # _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
return wh_mod._first_enabled_endpoint(_DB(rows), owner) # (not a local import), so we patch the module attribute directly.
wh_mod.ModelEndpoint = _ModelEndpoint
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
def test_sync_chat_fallback_never_picks_another_owners_endpoint(): def test_sync_chat_fallback_never_picks_another_owners_endpoint():
@@ -310,9 +319,15 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
assert ep is not None and ep.name == "shared" assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop(): def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
# An unresolvable/empty token owner keeps the original single-user behaviour # When no token owner is known, only null-owner (shared) endpoints are
# (owner_filter no-op): first enabled row, whatever it is. # visible — private endpoints of any user must not be returned.
rows = [_ep("first", "bob"), _ep("second", "alice")] rows = [_ep("bob-private", "bob"), _ep("shared", None)]
ep = _select(rows, None) ep = _select(rows, None)
assert ep is not None and ep.name == "first" assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
# No shared rows → fail closed rather than returning another user's endpoint.
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
assert _select(rows, None) is None

View File

@@ -1,4 +1,4 @@
from src.search.ranking import rank_search_results from services.search.ranking import rank_search_results
def test_news_queries_prefer_news_sources_over_sports_and_social_results(): def test_news_queries_prefer_news_sources_over_sports_and_social_results():

View File

@@ -8,7 +8,8 @@ module-level, time-injectable function.
from datetime import datetime, timezone from datetime import datetime, timezone
from src.search.ranking import recency_score, _utcnow_naive import services.search.ranking as live_ranking
from services.search.ranking import recency_score, _utcnow_naive, rank_search_results
def test_fresh_result_scores_one(): def test_fresh_result_scores_one():
@@ -37,3 +38,37 @@ def test_default_now_is_naive_utc():
assert now.tzinfo is None assert now.tzinfo is None
reference = datetime.now(timezone.utc).replace(tzinfo=None) reference = datetime.now(timezone.utc).replace(tzinfo=None)
assert abs((now - reference).total_seconds()) < 5 assert abs((now - reference).total_seconds()) < 5
def test_supported_timestamp_formats_parse():
# All three formats the current implementation supports resolve to the same
# ~4-day-old age, so each scores a full 1.0.
now = datetime(2026, 1, 5, 12, 0, 0)
assert recency_score("2026-01-01", now=now) == 1.0
assert recency_score("2026-01-01T08:30:00", now=now) == 1.0
assert recency_score("2026-01-01 08:30:00", now=now) == 1.0
def test_shim_reexports_live_objects():
# src.search.ranking is a compatibility shim; it must expose the *same*
# objects as the live services module so the two cannot diverge.
import src.search.ranking as shim
assert shim.recency_score is live_ranking.recency_score
assert shim.rank_search_results is live_ranking.rank_search_results
assert shim._utcnow_naive is live_ranking._utcnow_naive
def test_live_rank_path_prefers_newer_result(monkeypatch):
# Pin "now" so age scoring is deterministic. The two results are identical
# apart from age, isolating recency as the only differentiator.
monkeypatch.setattr(live_ranking, "_utcnow_naive", lambda: datetime(2026, 1, 31))
results = [
{"title": "Report", "url": "https://example.org/a", "snippet": "x", "age": "2026-01-01"},
{"title": "Report", "url": "https://example.org/b", "snippet": "x", "age": "2026-01-29"},
]
ranked = rank_search_results("report", results)
assert ranked[0]["url"] == "https://example.org/b"
assert ranked[1]["url"] == "https://example.org/a"