Reapply "Merge branch 'main' of github.com:pewdiepie-archdaemon/odysseus"

This reverts commit cc8fe2f6e3.
This commit is contained in:
pewdiepie-archdaemon
2026-06-03 22:47:00 +09:00
parent cc8fe2f6e3
commit 6861c41580
16 changed files with 1647 additions and 225 deletions

View File

@@ -189,6 +189,20 @@ RENDER_GID=989
For NVIDIA/AMD GPU support, also read the comments in the selected overlay file: docker/gpu.nvidia.yml or docker/gpu.amd.yml.
**Stack-management UIs (Portainer, Coolify, Dockhand, etc.).** These tools
often accept only a single Compose file and do not reliably honor `COMPOSE_FILE`
or multiple `-f` overlays. CLI users should keep using the `COMPOSE_FILE`
overlay workflow above. For stack UIs, point the stack at one of the standalone
files instead, which bundle the base stack plus the GPU settings:
- `docker-compose.gpu-nvidia.yml` — still requires the NVIDIA Container Toolkit
on the host.
- `docker-compose.gpu-amd.yml` — still requires host ROCm/kfd/DRI setup, the
`video`/`render` group membership, and `RENDER_GID` when needed.
The base `docker-compose.yml` plus the `docker/gpu.*.yml` overlays remain the
source of truth; the standalone files mirror them for single-file deployments.
Verify after enabling either overlay:
```bash

View File

@@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base):
is_enabled = Column(Boolean, default=True)
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
# Whether models on this endpoint accept OpenAI-style function
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
@@ -856,6 +857,24 @@ def _migrate_add_cached_models_column():
except Exception as e:
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
def _migrate_add_pinned_models_column():
"""Add pinned_models column to model_endpoints if it doesn't exist."""
import sqlite3
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
columns = [row[1] for row in cursor.fetchall()]
if columns and "pinned_models" not in columns:
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
def _migrate_add_notes_sort_order():
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
import sqlite3
@@ -1511,6 +1530,7 @@ def init_db():
Base.metadata.create_all(bind=engine)
_migrate_add_hidden_models_column()
_migrate_add_cached_models_column()
_migrate_add_pinned_models_column()
_migrate_add_notes_sort_order()
_migrate_add_model_type_column()
_migrate_add_model_endpoint_owner_column()

164
docker-compose.gpu-amd.yml Normal file
View File

@@ -0,0 +1,164 @@
# Standalone AMD ROCm GPU Compose file for stack-management UIs (Portainer,
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
#
# This is equivalent to: docker-compose.yml + docker/gpu.amd.yml.
# The base docker-compose.yml plus the docker/gpu.amd.yml overlay remain the
# source of truth — CLI users should keep using the COMPOSE_FILE overlay
# workflow. Keep this file in sync with both when either changes.
#
# Requires ROCm drivers on the host (kfd + DRI devices) and the host user
# running Docker in the `video` and `render` groups. Set RENDER_GID to your
# host's numeric render group id when needed. See docker/gpu.amd.yml for details.
services:
odysseus:
build: .
ports:
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
volumes:
- ./data:/app/data:z
- ./logs:/app/logs:z
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
# add the shown public key to each remote server's authorized_keys.
- ./data/ssh:/app/.ssh:z
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
# container, so persist its HuggingFace cache under ./data/huggingface.
- ./data/huggingface:/app/.cache/huggingface:z
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
# land under /app/.local for the odysseus user. Persist them so a
# container recreate does not silently remove installed serve engines.
- ./data/local:/app/.local:z
extra_hosts:
# Lets the container reach local services on the Docker host, including
# Ollama at http://host.docker.internal:11434.
- "host.docker.internal:host-gateway"
environment:
- LLM_HOST=${LLM_HOST:-localhost}
- LLM_HOSTS=${LLM_HOSTS:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
- HF_TOKEN=${HF_TOKEN:-}
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
- SEARXNG_INSTANCE=http://searxng:8080
- CHROMADB_HOST=chromadb
- CHROMADB_PORT=8000
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
- AUTH_ENABLED=${AUTH_ENABLED:-true}
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
- SERPER_API_KEY=${SERPER_API_KEY:-}
# PUID / PGID — the user/group the container drops to before
# running uvicorn (entrypoint also chowns /app/data + /app/logs
# to match, so bind-mounted files stay editable from the host).
# 1000 is the default first user on most Linux installs. If your
# host user has a different id, override here or via .env, e.g.:
# PUID=1001
# PGID=1001
# Find yours with: id -u / id -g
- PUID=${PUID:-1000}
- PGID=${PGID:-1000}
depends_on:
searxng:
condition: service_healthy
chromadb:
condition: service_started
restart: unless-stopped
# AMD ROCm overlay (from docker/gpu.amd.yml).
devices:
- /dev/kfd
- /dev/dri
group_add:
- video
- ${RENDER_GID:-render}
chromadb:
image: docker.io/chromadb/chroma:latest
ports:
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
volumes:
- chromadb-data:/chroma/chroma
environment:
- ANONYMIZED_TELEMETRY=FALSE
restart: unless-stopped
searxng:
# Pinned, not :latest — odysseus waits on searxng's healthcheck
# (depends_on: condition: service_healthy), so a broken upstream `latest`
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
# Bump this deliberately after verifying a newer tag boots clean.
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
entrypoint:
- /bin/sh
- -c
- |
set -eu
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
secret="$${SEARXNG_SECRET:-}"
if [ -z "$$secret" ]; then
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
fi
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
fi
exec /usr/local/searxng/entrypoint.sh
ports:
- "127.0.0.1:8080:8080"
volumes:
- searxng-data:/etc/searxng
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
environment:
- SEARXNG_BASE_URL=http://localhost:8080/
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
# The official searxng image runs as the non-root `searxng` user, but its
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
# su-exec, and (with our wrapper above) write settings.yml into the named
# volume. Without these capabilities the wrapper aborts at the redirection
# with EACCES and the container fails its healthcheck with permission
# errors during setup. Mirrors the cap set recommended by the upstream
# searxng-docker compose file. See issue #721.
cap_drop:
- ALL
cap_add:
- CHOWN
- SETGID
- SETUID
- DAC_OVERRIDE
healthcheck:
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
interval: 5s
timeout: 6s
retries: 20
start_period: 10s
restart: unless-stopped
ntfy:
image: docker.io/binwiederhier/ntfy
command: serve
ports:
- "${NTFY_BIND:-127.0.0.1}:8091:80"
volumes:
- ntfy-cache:/var/cache/ntfy
environment:
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
restart: unless-stopped
volumes:
searxng-data:
chromadb-data:
ntfy-cache:

View File

@@ -0,0 +1,167 @@
# Standalone NVIDIA GPU Compose file for stack-management UIs (Portainer,
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
#
# This is equivalent to: docker-compose.yml + docker/gpu.nvidia.yml.
# The base docker-compose.yml plus the docker/gpu.nvidia.yml overlay remain
# the source of truth — CLI users should keep using the COMPOSE_FILE overlay
# workflow. Keep this file in sync with both when either changes.
#
# Requires the NVIDIA Container Toolkit on the host. See docker/gpu.nvidia.yml
# for setup details.
services:
odysseus:
build: .
ports:
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
volumes:
- ./data:/app/data:z
- ./logs:/app/logs:z
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
# add the shown public key to each remote server's authorized_keys.
- ./data/ssh:/app/.ssh:z
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
# container, so persist its HuggingFace cache under ./data/huggingface.
- ./data/huggingface:/app/.cache/huggingface:z
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
# land under /app/.local for the odysseus user. Persist them so a
# container recreate does not silently remove installed serve engines.
- ./data/local:/app/.local:z
extra_hosts:
# Lets the container reach local services on the Docker host, including
# Ollama at http://host.docker.internal:11434.
- "host.docker.internal:host-gateway"
environment:
- LLM_HOST=${LLM_HOST:-localhost}
- LLM_HOSTS=${LLM_HOSTS:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
- HF_TOKEN=${HF_TOKEN:-}
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
- SEARXNG_INSTANCE=http://searxng:8080
- CHROMADB_HOST=chromadb
- CHROMADB_PORT=8000
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
- AUTH_ENABLED=${AUTH_ENABLED:-true}
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
- SERPER_API_KEY=${SERPER_API_KEY:-}
# PUID / PGID — the user/group the container drops to before
# running uvicorn (entrypoint also chowns /app/data + /app/logs
# to match, so bind-mounted files stay editable from the host).
# 1000 is the default first user on most Linux installs. If your
# host user has a different id, override here or via .env, e.g.:
# PUID=1001
# PGID=1001
# Find yours with: id -u / id -g
- PUID=${PUID:-1000}
- PGID=${PGID:-1000}
# NVIDIA overlay (from docker/gpu.nvidia.yml).
- NVIDIA_VISIBLE_DEVICES=all
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
depends_on:
searxng:
condition: service_healthy
chromadb:
condition: service_started
restart: unless-stopped
# NVIDIA overlay (from docker/gpu.nvidia.yml).
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
chromadb:
image: docker.io/chromadb/chroma:latest
ports:
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
volumes:
- chromadb-data:/chroma/chroma
environment:
- ANONYMIZED_TELEMETRY=FALSE
restart: unless-stopped
searxng:
# Pinned, not :latest — odysseus waits on searxng's healthcheck
# (depends_on: condition: service_healthy), so a broken upstream `latest`
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
# Bump this deliberately after verifying a newer tag boots clean.
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
entrypoint:
- /bin/sh
- -c
- |
set -eu
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
secret="$${SEARXNG_SECRET:-}"
if [ -z "$$secret" ]; then
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
fi
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
fi
exec /usr/local/searxng/entrypoint.sh
ports:
- "127.0.0.1:8080:8080"
volumes:
- searxng-data:/etc/searxng
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
environment:
- SEARXNG_BASE_URL=http://localhost:8080/
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
# The official searxng image runs as the non-root `searxng` user, but its
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
# su-exec, and (with our wrapper above) write settings.yml into the named
# volume. Without these capabilities the wrapper aborts at the redirection
# with EACCES and the container fails its healthcheck with permission
# errors during setup. Mirrors the cap set recommended by the upstream
# searxng-docker compose file. See issue #721.
cap_drop:
- ALL
cap_add:
- CHOWN
- SETGID
- SETUID
- DAC_OVERRIDE
healthcheck:
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
interval: 5s
timeout: 6s
retries: 20
start_period: 10s
restart: unless-stopped
ntfy:
image: docker.io/binwiederhier/ntfy
command: serve
ports:
- "${NTFY_BIND:-127.0.0.1}:8091:80"
volumes:
- ntfy-cache:/var/cache/ntfy
environment:
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
restart: unless-stopped
volumes:
searxng-data:
chromadb-data:
ntfy-cache:

View File

@@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
return "No models found for that provider/key."
def _visible_models(cached_models, hidden_models):
"""Filter cached model IDs by hidden_models. Returns list of visible IDs."""
all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or [])
def _normalize_model_ids(value):
"""Coerce a model-ID input into a clean, ordered list of strings.
Accepts a list, a JSON-encoded list string, or a comma/newline separated
string (handy for form or backend API input). Trims whitespace, drops
empty and non-string values, and de-duplicates preserving first-seen order.
"""
if value is None:
return []
items = value
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
parsed = json.loads(text)
except Exception:
parsed = None
items = parsed if isinstance(parsed, list) else re.split(r"[,\n]", text)
if not isinstance(items, list):
return []
out, seen = [], set()
for item in items:
if not isinstance(item, str):
continue
s = item.strip()
if not s or s in seen:
continue
seen.add(s)
out.append(s)
return out
def _merge_model_ids(*lists):
"""Concatenate model-ID lists, de-duplicating and preserving order."""
out, seen = [], set()
for ids in lists:
for m in (ids or []):
if not isinstance(m, str) or m in seen:
continue
seen.add(m)
out.append(m)
return out
def _visible_models(cached_models, hidden_models, pinned_models=None):
"""Merge cached + pinned model IDs, then filter out hidden ones.
Pinned IDs are admin-entered and may not appear in cached_models (e.g.
cloud deployment IDs the provider does not list in /v1/models). Returns an
ordered, de-duplicated list of visible IDs.
"""
# Normalize each input so JSON strings, lists, comma/newline strings, and
# malformed strings are all handled without raising.
merged = _merge_model_ids(
_normalize_model_ids(cached_models),
_normalize_model_ids(pinned_models),
)
if not hidden_models:
return all_models
hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or []))
return [m for m in all_models if m not in hidden]
return merged
hidden = set(_normalize_model_ids(hidden_models))
return [m for m in merged if m not in hidden]
def setup_model_routes(model_discovery):
@@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(r.hidden_models))
except Exception:
pass
visible = [m for m in all_models if m not in hidden]
status = "online" if all_models else "offline"
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
visible = _visible_models(all_models, r.hidden_models, pinned)
# Endpoint counts as reachable if it has any model — including
# admin-pinned IDs that a probe would never surface.
status = "online" if (all_models or pinned) else "offline"
ping = None
if not all_models and r.is_enabled:
if not all_models and not pinned and r.is_enabled:
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"):
status = "empty"
@@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery):
"has_key": bool(r.api_key),
"is_enabled": r.is_enabled,
"models": visible,
"pinned_models": pinned,
"hidden_count": len(hidden),
"online": status != "offline",
"status": status,
@@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery):
require_models: str = Form("false"),
model_type: str = Form("llm"),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
container_local: str = Form("false"),
# Default `shared=true` → endpoints are visible to all users (the
# app's historical behaviour). Admins can pass `shared=false` to
@@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery):
.first()
)
if existing:
# Persist any incoming pinned IDs onto the existing row. An
# empty/omitted form field must not wipe previously pinned IDs.
_incoming_pinned = _normalize_model_ids(pinned_models)
if _incoming_pinned:
_merged_pinned = _merge_model_ids(
_normalize_model_ids(getattr(existing, "pinned_models", None)),
_incoming_pinned,
)
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
_db_dedup.commit()
_invalidate_models_cache()
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
return {
"id": existing.id,
"name": existing.name,
"base_url": existing.base_url,
"models": json.loads(existing.cached_models) if existing.cached_models else [],
"models": _visible_models(
getattr(existing, "cached_models", None),
getattr(existing, "hidden_models", None),
existing.pinned_models,
),
"pinned_models": _existing_pinned,
"online": True,
"status": "online",
"existing": True,
@@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery):
try:
_st_raw = (supports_tools or "").strip().lower()
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
_pinned = _normalize_model_ids(pinned_models)
# Stamp owner so the picker only shows this endpoint to the admin
# who added it. Pass `shared=true` to mark it null-owner (visible
# to all users), preserving the pre-fix "everyone sees everything"
@@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery):
is_enabled=True,
model_type=model_type.strip() if model_type else "llm",
cached_models=json.dumps(model_ids) if model_ids else None,
pinned_models=json.dumps(_pinned) if _pinned else None,
supports_tools=_st,
owner=_owner_val,
)
@@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery):
"id": ep_id,
"name": name.strip(),
"base_url": base_url,
"models": model_ids,
"online": bool(model_ids) or bool(ping.get("reachable")),
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"),
"models": _merge_model_ids(model_ids, _pinned),
"pinned_models": _pinned,
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None,
}
@@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(ep.hidden_models))
except Exception:
pass
# Try live probe, fall back to cached
# Try live probe, fall back to cached. Pinned IDs are admin-entered
# and persist regardless of probe results — never overwritten here.
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
if all_models:
ep.cached_models = json.dumps(all_models)
@@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery):
all_models = json.loads(ep.cached_models)
except Exception:
pass
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
pinned_set = set(pinned)
return [
{"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden}
for m in all_models
{
"id": m,
"display": m.split("/")[-1],
"is_hidden": m in hidden,
"is_pinned": m in pinned_set,
}
for m in _merge_model_ids(all_models, pinned)
]
finally:
db.close()
@router.patch("/model-endpoints/{ep_id}/models")
async def update_hidden_models(ep_id: str, request: Request):
"""Bulk update hidden models list for an endpoint.
"""Bulk update hidden and/or pinned model lists for an endpoint.
Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]}
Expects JSON body with optional keys:
{"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]}
Each key is updated only when present, so callers can patch one list
without clobbering the other.
"""
require_admin(request)
db = SessionLocal()
@@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery):
if not ep:
raise HTTPException(404, "Endpoint not found")
body = await request.json()
hidden = body.get("hidden", [])
if not isinstance(hidden, list):
raise HTTPException(400, "hidden must be a list of model IDs")
ep.hidden_models = json.dumps(hidden) if hidden else None
if not isinstance(body, dict):
raise HTTPException(400, "Body must be a JSON object")
if "hidden" in body:
hidden = body.get("hidden")
if not isinstance(hidden, list):
raise HTTPException(400, "hidden must be a list of model IDs")
ep.hidden_models = json.dumps(hidden) if hidden else None
# Accept either "pinned" or "pinned_models" for the manual IDs list.
if "pinned_models" in body or "pinned" in body:
pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned")))
ep.pinned_models = json.dumps(pinned) if pinned else None
db.commit()
_invalidate_models_cache()
return {"id": ep_id, "hidden_count": len(hidden)}
hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0
pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0
return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count}
finally:
db.close()
@@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery):
return {"endpoint_id": "", "endpoint_url": "", "model": ""}
base = _normalize_base(ep.base_url)
chat_url = build_chat_url(base)
if not model and getattr(ep, "cached_models", None):
if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)):
try:
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None))
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None))
if visible:
model = visible[0]
except Exception:
@@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery):
ep.name = body["name"].strip() or ep.name
if "model_type" in body and isinstance(body["model_type"], str):
ep.model_type = body["model_type"].strip() or ep.model_type
if "pinned_models" in body:
_pinned = _normalize_model_ids(body["pinned_models"])
ep.pinned_models = json.dumps(_pinned) if _pinned else None
# Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key
@@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery):
"name": ep.name,
"model_type": ep.model_type,
"base_url": ep.base_url,
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
}
finally:
db.close()

View File

@@ -9,7 +9,9 @@ import httpx
from fastapi import APIRouter, HTTPException, Request, Form
from pydantic import BaseModel, Field
from core.database import SessionLocal, Webhook
from core.database import SessionLocal, Webhook, ModelEndpoint
from src.auth_helpers import owner_filter
from src.url_security import validate_public_http_url
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
logger = logging.getLogger(__name__)
@@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000
from core.middleware import require_admin as _require_admin
def _first_enabled_endpoint(db, owner):
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint
is per-user (core/database.py — "when non-null, the model picker only shows
the endpoint to that user"), and the sync-chat fallback uses the row's
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token
(e.g. a paired mobile device) fall back onto ANOTHER user's private
endpoint and silently spend that owner's API key / quota — and reach
whatever internal base_url they configured. Mirrors the owner_filter scoping
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
no-op (single-user / legacy mode), preserving the original behaviour.
def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]):
"""First enabled ModelEndpoint visible to token_owner — their own rows plus
legacy null-owner ("shared") rows. Owner-scoped: an unscoped .first() would
let a chat-scoped token fall back onto another user's private endpoint and
silently spend that owner's API key/quota. Prefer owner rows before shared
rows. Fails closed to null-owner rows only when token_owner is absent.
Does not validate base_url — admin-configured local/LAN endpoints remain allowed.
"""
from core.database import ModelEndpoint
from src.auth_helpers import owner_filter
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
q = owner_filter(q, ModelEndpoint, owner)
return q.first()
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
if token_owner:
query = owner_filter(query, ModelEndpoint, token_owner)
return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first()
return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711
def _caller_owns_session(sess_owner, caller) -> bool:
@@ -278,15 +276,21 @@ def setup_webhook_routes(
api_key = body.api_key.strip()
model = body.model or "deepseek-chat"
# Resolve base_url: explicit > provider name > model prefix auto-detect
base_url = body.base_url.strip().rstrip("/") if body.base_url else None
if not base_url:
# Validate only token-supplied direct base_url; auto-resolved known-provider
# URLs are not subject to extra local/LAN blocking beyond existing provider logic.
direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None
if direct_base_url:
try:
base_url = validate_public_http_url(direct_base_url)
except ValueError as e:
detail = str(e).replace("URL", "base_url", 1)
raise HTTPException(400, detail)
else:
base_url = _resolve_base_url(model, body.provider)
if not base_url:
raise HTTPException(400,
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
"or provider ('deepseek', 'openai', 'groq', etc.)")
base_url = normalize_base(base_url)
endpoint_url = build_chat_url(base_url)
@@ -306,9 +310,7 @@ def setup_webhook_routes(
if not sess:
db = SessionLocal()
try:
# Owner-scoped: only THIS token owner's endpoints + legacy
# shared rows, never another user's private endpoint/api_key.
ep = _first_enabled_endpoint(db, token_owner)
ep = _select_api_chat_fallback_endpoint(db, token_owner)
finally:
db.close()

View File

@@ -2,12 +2,49 @@
import re
import logging
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
def _utcnow_naive() -> datetime:
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
The age is measured against UTC, not local time. The previous code used
``datetime.now()`` (local) against UTC-style published dates, so the age was
skewed by the host's UTC offset; it was also a latent crash once neighbouring
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
"""
if not age_str:
return 0.0
dt = None
for fmt in _AGE_FORMATS:
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
now = now or _utcnow_naive()
days_old = (now - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
_SPORTS_HINTS = {
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
@@ -73,24 +110,6 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
return 0.7
return 0.4
def recency_score(age_str: Optional[str]) -> float:
if not age_str:
return 0.0
for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
days_old = (datetime.now() - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
if not is_news_query:
return 0.0

View File

@@ -1,151 +1,13 @@
"""Search result ranking based on relevance, source quality, and recency."""
"""Compatibility re-export shim for the live ranking module.
import re
import logging
from datetime import datetime, timezone
from typing import List, Optional
from urllib.parse import urlparse
The real implementation lives in :mod:`services.search.ranking`, which is what
the search runtime (services/search/core.py) imports. This module used to hold a
parallel copy; it now re-exports so the two cannot drift out of sync again.
"""
logger = logging.getLogger(__name__)
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
def _utcnow_naive() -> datetime:
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
The age is measured against UTC, not local time. The previous code used
``datetime.now()`` (local) against UTC-style published dates, so the age was
skewed by the host's UTC offset; it was also a latent crash once neighbouring
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
"""
if not age_str:
return 0.0
dt = None
for fmt in _AGE_FORMATS:
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
now = now or _utcnow_naive()
days_old = (now - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
_SPORTS_HINTS = {
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
}
# Word-boundary match so "sport" does not fire inside "transport"/"passport"
# and a domain like "transport.gov" is not mistaken for a sports site.
_SPORTS_HINT_RE = re.compile(
r"\b(?:" + "|".join(re.escape(h) for h in _SPORTS_HINTS) + r")\b"
from services.search.ranking import ( # noqa: F401
_AGE_FORMATS,
_utcnow_naive,
rank_search_results,
recency_score,
)
_LOW_VALUE_NEWS_DOMAINS = {
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
"www.yahoo.com", "msn.com", "www.msn.com",
}
_TRUSTED_NEWS_DOMAINS = {
"apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com",
"bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca",
"ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca",
"theguardian.com",
"www.theguardian.com", "euronews.com", "www.euronews.com",
"dw.com", "www.dw.com", "government.se", "www.government.se",
}
def _domain(url: str) -> str:
try:
return urlparse(url).netloc.lower()
except Exception:
return ""
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
query_lc = query.lower()
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
is_sports_query = bool(_SPORTS_HINT_RE.search(query_lc))
def title_score(title: str) -> float:
if not title:
return 0.0
title_lc = title.lower()
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
return matches / len(query_terms) if query_terms else 0.0
def snippet_score(snippet: str) -> float:
if not snippet:
return 0.0
length_factor = min(len(snippet), 200) / 200
term_hits = sum(1 for term in query_terms if term in snippet.lower())
term_factor = term_hits / len(query_terms) if query_terms else 0.0
return (length_factor + term_factor) / 2
def domain_score(url: str) -> float:
netloc = _domain(url)
if not netloc:
return 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
return 1.0
if netloc.endswith(".edu") or netloc.endswith(".gov"):
return 1.0
if netloc.endswith(".org"):
return 0.7
return 0.4
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
if not is_news_query:
return 0.0
text = f"{title} {snippet}".lower()
netloc = _domain(url)
adjustment = 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
adjustment += 1.2
if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")):
adjustment += 0.4
if netloc in _LOW_VALUE_NEWS_DOMAINS:
adjustment -= 0.8
if not is_sports_query and (_SPORTS_HINT_RE.search(text) or _SPORTS_HINT_RE.search(netloc)):
adjustment -= 1.5
# A country/news query should not rank a page whose title/snippet barely
# mentions the country above actual news pages for that country.
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
adjustment -= 1.0
return adjustment
ranked = []
for result in results:
title = result.get("title", "")
snippet = result.get("snippet", "")
url = result.get("url", "")
age = result.get("age", None)
score = (
2.0 * title_score(title)
+ 1.0 * snippet_score(snippet)
+ 1.5 * domain_score(url)
+ 1.0 * recency_score(age)
+ news_quality_adjustment(title, snippet, url)
)
ranked.append((score, result))
ranked.sort(key=lambda x: x[0], reverse=True)
return [r for _, r in ranked]

94
src/url_security.py Normal file
View File

@@ -0,0 +1,94 @@
"""URL validation helpers for server-side outbound requests."""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
_INTERNAL_HOSTNAMES = {
"localhost",
"metadata",
"metadata.google.internal",
}
_INTERNAL_SUFFIXES = (
".localhost",
".local",
".internal",
".lan",
".intranet",
)
_BLOCKED_NETWORKS = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
)
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
ips: list[ipaddress._BaseAddress] = []
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
if family in (socket.AF_INET, socket.AF_INET6):
ips.append(ipaddress.ip_address(sockaddr[0]))
return ips
def _blocked_ip(addr: ipaddress._BaseAddress) -> bool:
return (
any(addr in net for net in _BLOCKED_NETWORKS)
or addr.is_private
or addr.is_loopback
or addr.is_link_local
or addr.is_multicast
or addr.is_unspecified
or addr.is_reserved
)
def _host_resolves_publicly(hostname: str) -> bool:
host = hostname.strip().lower()
if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES):
return False
try:
return not _blocked_ip(ipaddress.ip_address(host))
except ValueError:
pass
try:
addrs = _resolve_hostname_ips(host)
except OSError:
return False
return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs)
def is_public_http_url(url: str) -> bool:
parsed = urlparse((url or "").strip())
if parsed.scheme not in ("http", "https") or not parsed.hostname:
return False
return _host_resolves_publicly(parsed.hostname)
def validate_public_http_url(url: str, *, max_length: int = 2048) -> str:
"""Validate a user/API-token supplied server-side HTTP(S) endpoint.
This is for untrusted outbound URLs, not admin-created model endpoints
that are intentionally allowed to point at private model providers. DNS
failures fail closed, and DNS checks reduce obvious private-network
targets but do not eliminate every DNS rebinding race by themselves.
"""
cleaned = (url or "").strip()
if len(cleaned) > max_length:
raise ValueError("URL is too long")
if not is_public_http_url(cleaned):
raise ValueError("URL must point to a public HTTP(S) endpoint")
return cleaned

View File

@@ -0,0 +1,401 @@
import ipaddress
import importlib.util
import sys
import types
from pathlib import Path
import pytest
@pytest.mark.parametrize("url", [
"http://127.0.0.1:8000/v1",
"http://localhost:8000/v1",
"http://10.0.0.5/v1",
"http://172.16.0.1/v1",
"http://192.168.1.2/v1",
"http://169.254.169.254/latest/meta-data/",
"http://metadata.google.internal/",
"http://[::1]:8000/v1",
"http://[fc00::1]/v1",
"http://224.0.0.1/v1",
"http://0.0.0.0/v1",
"file:///etc/passwd",
])
def test_public_url_validator_blocks_internal_targets(url):
from src.url_security import is_public_http_url
assert is_public_http_url(url) is False
def test_public_url_validator_allows_public_endpoint(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1"
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("10.0.0.5")],
)
with pytest.raises(ValueError):
url_security.validate_public_http_url("https://api.example.com/v1")
def _load_webhook_routes_for_test(monkeypatch):
# Load under a unique module name so each test gets a fresh module object
# rather than a cached one from a previous monkeypatch run.
core_pkg = types.ModuleType("core")
core_pkg.__path__ = []
core_db = types.ModuleType("core.database")
core_db.SessionLocal = object
core_db.Webhook = object
core_db.ModelEndpoint = object
core_middleware = types.ModuleType("core.middleware")
core_middleware.require_admin = lambda request: None
webhook_manager = types.ModuleType("src.webhook_manager")
webhook_manager.WebhookManager = object
webhook_manager.validate_webhook_url = lambda url: url
webhook_manager.validate_events = lambda events: events
monkeypatch.setitem(sys.modules, "core", core_pkg)
monkeypatch.setitem(sys.modules, "core.database", core_db)
monkeypatch.setitem(sys.modules, "core.middleware", core_middleware)
monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager)
module_name = "routes.webhook_routes_under_test"
spec = importlib.util.spec_from_file_location(
module_name,
Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py",
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
class _Expr:
def __init__(self, fn):
self.fn = fn
def __call__(self, row):
return self.fn(row)
def __or__(self, other):
return _Expr(lambda row: self(row) or other(row))
class _Column:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return _Expr(lambda row: getattr(row, self.name) == other)
def desc(self):
return ("desc", self.name)
class _ModelEndpoint:
is_enabled = _Column("is_enabled")
owner = _Column("owner")
created_at = _Column("created_at")
class _Endpoint:
def __init__(
self,
*,
owner,
is_enabled=True,
created_at=1,
base_url="https://api.example.com/v1",
api_key=None,
):
self.owner = owner
self.is_enabled = is_enabled
self.created_at = created_at
self.base_url = base_url
self.api_key = api_key
class _EndpointQuery:
def __init__(self, rows):
self.rows = rows
self.filters = []
self.orders = []
def filter(self, *exprs):
self.filters.extend(exprs)
return self
def order_by(self, *exprs):
self.orders.extend(exprs)
return self
def first(self):
rows = self.rows
for expr in self.filters:
rows = [row for row in rows if expr(row)]
# Apply sort keys right-to-left so the leftmost key ends up as the
# primary sort (stable-sort reversal idiom mirrors SQLAlchemy's
# multi-column ORDER BY behaviour).
for order in reversed(self.orders):
reverse = False
name = getattr(order, "name", None)
if isinstance(order, tuple) and order[0] == "desc":
reverse = True
name = order[1]
rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse)
if name != "owner":
rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse)
return rows[0] if rows else None
class _DB:
def __init__(self, rows):
self.query_obj = _EndpointQuery(rows)
self.closed = False
def query(self, model):
assert model is _ModelEndpoint
return self.query_obj
def close(self):
self.closed = True
class _ChatSession:
def __init__(self, endpoint_url, model):
self.endpoint_url = endpoint_url
self.model = model
self.headers = {}
self.history = []
def add_message(self, message):
self.history.append(message)
class _SessionManager:
def __init__(self):
self.created = []
self.save_calls = 0
def create_session(self, *, session_id, name, endpoint_url, model, owner):
session = _ChatSession(endpoint_url, model)
self.created.append({
"session_id": session_id,
"name": name,
"endpoint_url": endpoint_url,
"model": model,
"owner": owner,
"session": session,
})
return session
def save_sessions(self):
self.save_calls += 1
class _Request:
def __init__(self, *, owner="alice"):
self.state = types.SimpleNamespace(
api_token=True,
api_token_scopes=["chat"],
api_token_owner=owner,
)
class _WebhookManager:
async def fire(self, event, payload):
return None
def _install_sync_chat_stubs(monkeypatch):
# FastAPI checks for python_multipart at import time when Form is used;
# stub it so the optional dependency is not required in the test environment.
python_multipart = types.ModuleType("python_multipart")
python_multipart.__version__ = "0.0.13"
core_models = types.ModuleType("core.models")
class _ChatMessage:
def __init__(self, role, content):
self.role = role
self.content = content
async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
return "mocked response"
endpoint_resolver = types.ModuleType("src.endpoint_resolver")
endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/")
endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions"
endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models"
endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"}
llm_core = types.ModuleType("src.llm_core")
llm_core.llm_call_async = _llm_call_async
core_models.ChatMessage = _ChatMessage
monkeypatch.setitem(sys.modules, "python_multipart", python_multipart)
monkeypatch.setitem(sys.modules, "core.models", core_models)
monkeypatch.setitem(sys.modules, "src.llm_core", llm_core)
monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver)
def _sync_chat_endpoint(webhook_routes, session_manager):
router = webhook_routes.setup_webhook_routes(
_WebhookManager(),
auth_manager=None,
session_manager=session_manager,
)
for route in router.routes:
if route.path == "/api/v1/chat":
return route.endpoint
raise AssertionError("sync chat route not found")
@pytest.mark.parametrize("base_url", [
"http://127.0.0.1:11434/v1",
"http://localhost:11434/v1",
"http://10.0.0.5/v1",
"http://169.254.169.254/latest/meta-data/",
])
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url=base_url,
model="test-model",
provider=None,
session=None,
)
with pytest.raises(webhook_routes.HTTPException) as exc:
await sync_chat(_Request(), body)
assert exc.value.status_code == 400
assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint"
assert session_manager.created == []
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url="https://api.example.com/v1",
model="test-model",
provider=None,
session=None,
)
response = await sync_chat(_Request(), body)
assert response["response"] == "mocked response"
assert response["model"] == "test-model"
assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions"
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", is_enabled=False, created_at=0),
_Endpoint(owner="bob", created_at=0),
_Endpoint(owner=None, created_at=1),
_Endpoint(owner="alice", created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice")
assert selected.owner == "alice"
assert selected.is_enabled is True
assert selected.created_at == 2
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", created_at=0),
_Endpoint(owner=None, is_enabled=False, created_at=1),
_Endpoint(owner=None, created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None)
assert selected.owner is None
assert selected.is_enabled is True
assert selected.created_at == 2
@pytest.mark.asyncio
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
local_endpoint = _Endpoint(
owner=None,
base_url="http://localhost:11434/v1",
api_key="configured-key",
)
db = _DB([local_endpoint])
calls = []
def _session_local():
return db
def _validate_public_http_url(url, *, max_length=2048):
calls.append(url)
raise AssertionError("configured fallback endpoint should not be publicly validated")
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local)
monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
model="local-model",
api_key=None,
base_url=None,
provider=None,
session=None,
)
response = await sync_chat(_Request(owner=None), body)
assert response["response"] == "mocked response"
assert response["model"] == "local-model"
assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
assert calls == []

View File

@@ -0,0 +1,147 @@
"""Guards the standalone GPU compose files against drift.
Stack-management UIs (Portainer, Coolify, Dockhand, ...) often accept only a
single compose file and do not honor COMPOSE_FILE or multiple ``-f`` overlays,
so the repo ships standalone ``docker-compose.gpu-*.yml`` files that inline the
GPU overlay. The base ``docker-compose.yml`` plus ``docker/gpu.*.yml`` overlays
remain the source of truth; these tests assert each standalone file equals the
base compose with only the matching overlay merged into the ``odysseus``
service. No Docker / docker compose is required — everything is pure YAML.
"""
import copy
from pathlib import Path
import pytest
import yaml
ROOT = Path(__file__).resolve().parents[1]
BASE = ROOT / "docker-compose.yml"
NVIDIA_OVERLAY = ROOT / "docker" / "gpu.nvidia.yml"
AMD_OVERLAY = ROOT / "docker" / "gpu.amd.yml"
NVIDIA_STANDALONE = ROOT / "docker-compose.gpu-nvidia.yml"
AMD_STANDALONE = ROOT / "docker-compose.gpu-amd.yml"
SERVICE = "odysseus"
def _load(path: Path) -> dict:
return yaml.safe_load(path.read_text(encoding="utf-8"))
def _deep_merge(base: dict, overlay: dict) -> dict:
"""Mirror docker compose overlay semantics for the keys these files use.
Mappings merge recursively; list-valued service fields are concatenated
(compose appends override sequences such as ``environment`` rather than
replacing them); scalars are overwritten. The overlays here only append to
``environment`` and add otherwise-absent keys (``deploy``, ``devices``,
``group_add``), so this keeps the expected merge explicit without invoking
docker compose.
"""
result = copy.deepcopy(base)
for key, value in overlay.items():
if isinstance(value, dict) and isinstance(result.get(key), dict):
result[key] = _deep_merge(result[key], value)
elif isinstance(value, list) and isinstance(result.get(key), list):
result[key] = copy.deepcopy(result[key]) + copy.deepcopy(value)
else:
result[key] = copy.deepcopy(value)
return result
def _merge_overlay_into_base(base: dict, overlay: dict) -> dict:
"""Build the expected standalone config: base + overlay on odysseus only."""
expected = copy.deepcopy(base)
overlay_service = overlay["services"][SERVICE]
expected["services"][SERVICE] = _deep_merge(
expected["services"][SERVICE], overlay_service
)
return expected
@pytest.fixture(scope="module")
def base():
return _load(BASE)
# --- Equivalence: standalone == base + overlay -----------------------------
def test_nvidia_standalone_equals_base_plus_overlay(base):
overlay = _load(NVIDIA_OVERLAY)
standalone = _load(NVIDIA_STANDALONE)
assert standalone == _merge_overlay_into_base(base, overlay)
def test_amd_standalone_equals_base_plus_overlay(base):
overlay = _load(AMD_OVERLAY)
standalone = _load(AMD_STANDALONE)
assert standalone == _merge_overlay_into_base(base, overlay)
# --- Non-odysseus services and volumes untouched ---------------------------
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_non_odysseus_services_match_base(base, standalone_path):
standalone = _load(standalone_path)
for name, definition in base["services"].items():
if name == SERVICE:
continue
assert standalone["services"][name] == definition
assert set(standalone["services"]) == set(base["services"])
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_top_level_volumes_match_base(base, standalone_path):
standalone = _load(standalone_path)
assert standalone.get("volumes") == base.get("volumes")
# --- odysseus = base service + only the overlay additions ------------------
def test_nvidia_odysseus_adds_only_overlay(base):
standalone = _load(NVIDIA_STANDALONE)
svc = standalone["services"][SERVICE]
base_svc = base["services"][SERVICE]
# Base environment preserved, plus exactly the two NVIDIA variables.
assert "NVIDIA_VISIBLE_DEVICES=all" in svc["environment"]
assert "NVIDIA_DRIVER_CAPABILITIES=compute,utility" in svc["environment"]
added_env = set(svc["environment"]) - set(base_svc["environment"])
assert added_env == {
"NVIDIA_VISIBLE_DEVICES=all",
"NVIDIA_DRIVER_CAPABILITIES=compute,utility",
}
# deploy block is new and matches the overlay's GPU reservation exactly.
assert "deploy" not in base_svc
devices = svc["deploy"]["resources"]["reservations"]["devices"]
assert devices == [
{"driver": "nvidia", "count": "all", "capabilities": ["gpu"]}
]
# No AMD-only keys leaked in.
assert "devices" not in svc
assert "group_add" not in svc
def test_amd_odysseus_adds_only_overlay(base):
standalone = _load(AMD_STANDALONE)
svc = standalone["services"][SERVICE]
base_svc = base["services"][SERVICE]
# Environment is unchanged from base for AMD.
assert svc["environment"] == base_svc["environment"]
# devices and group_add are new and match the overlay exactly.
assert "devices" not in base_svc
assert "group_add" not in base_svc
assert svc["devices"] == ["/dev/kfd", "/dev/dri"]
assert svc["group_add"] == ["video", "${RENDER_GID:-render}"]
# No NVIDIA-only keys leaked in.
assert "deploy" not in svc

View File

@@ -66,3 +66,37 @@ def test_normal_model_payload_keeps_temperature(monkeypatch):
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2)
assert payload["temperature"] == 0.2
assert payload["max_tokens"] == 5
def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
# OpenAI/local providers may validly use temperatures above 1.0; the clamp
# is Anthropic-only and must not touch this path.
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 1.2)
assert payload["temperature"] == 1.2
def _anthropic_payload(temperature):
return llm_core._build_anthropic_payload(
"claude-3-5-sonnet",
[{"role": "user", "content": "Hi"}],
temperature,
max_tokens=5,
)
def test_anthropic_payload_clamps_above_one():
# Anthropic rejects temperature > 1.0 (e.g. the Nietzsche preset's 1.2).
assert _anthropic_payload(1.2)["temperature"] == 1.0
def test_anthropic_payload_keeps_in_range():
assert _anthropic_payload(0.7)["temperature"] == 0.7
def test_anthropic_payload_clamps_negative():
assert _anthropic_payload(-0.5)["temperature"] == 0.0
def test_anthropic_payload_none_temperature_does_not_crash():
payload = _anthropic_payload(None)
assert payload["temperature"] is None

View File

@@ -1,6 +1,9 @@
"""Tests for model route helper functions — pure logic, no server needed."""
import asyncio
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import httpx
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
from routes.model_routes import (
_match_provider_curated,
_curate_models,
_visible_models,
_normalize_model_ids,
_is_chat_model,
_classify_endpoint,
_probe_endpoint,
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
assert model_routes._docker_host_gateway_reachable() is False
# ── pinned model IDs: normalization helper ──
class TestNormalizeModelIds:
def test_list_passthrough_trims_and_dedupes(self):
assert _normalize_model_ids([" a ", "a", "b", ""]) == ["a", "b"]
def test_json_string_list(self):
assert _normalize_model_ids('["x", "y", "x"]') == ["x", "y"]
def test_comma_and_newline_string(self):
assert _normalize_model_ids("a, b\n c ,a") == ["a", "b", "c"]
def test_none_and_empty(self):
assert _normalize_model_ids(None) == []
assert _normalize_model_ids("") == []
assert _normalize_model_ids(" ") == []
def test_non_string_values_ignored(self):
assert _normalize_model_ids([1, "ok", None, {"a": 1}]) == ["ok"]
# ── pinned model IDs: _visible_models merge ──
class TestVisibleModelsPinned:
def test_includes_pinned_not_in_cached(self):
visible = _visible_models(["a"], None, ["deploy-1"])
assert visible == ["a", "deploy-1"]
def test_cached_plus_pinned_dedup_preserves_order(self):
visible = _visible_models(["a", "b"], None, ["b", "c"])
assert visible == ["a", "b", "c"]
def test_hidden_can_hide_a_pinned_model(self):
visible = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
assert visible == ["a"]
def test_accepts_json_string_inputs(self):
visible = _visible_models('["a"]', '["a"]', '["b"]')
assert visible == ["b"]
# ── pinned model IDs: route behaviour ──
# Building the router exercises FastAPI's Form() routes, which require
# python-multipart. The test env ships without it, so register a minimal stub
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
if "python_multipart" not in sys.modules:
try:
import python_multipart # noqa: F401
except ImportError:
_mp_stub = types.ModuleType("python_multipart")
_mp_stub.__version__ = "0.0.13"
sys.modules["python_multipart"] = _mp_stub
class _PinnedFakeQuery:
def __init__(self, rows):
self.rows = list(rows)
def filter(self, *conditions):
return self
def order_by(self, *args):
return self
def first(self):
return self.rows[0] if self.rows else None
def all(self):
return list(self.rows)
class _PinnedFakeDb:
def __init__(self, rows):
self.rows = rows
self.added = []
self.committed = 0
def query(self, model):
return _PinnedFakeQuery(self.rows)
def add(self, row):
self.added.append(row)
def commit(self):
self.committed += 1
def close(self):
pass
class _FakeCol:
"""Column stand-in: every comparison/operator just returns itself so the
dedupe query expressions evaluate without a real SQLAlchemy column."""
__hash__ = None
def __eq__(self, other):
return self
def is_(self, other):
return self
def __or__(self, other):
return self
def desc(self):
return self
class _RecordingEndpoint:
"""ModelEndpoint stand-in that stores constructor kwargs as attributes.
Class-level fake columns let it double as the query class in the dedupe
lookup; instance attributes (set in __init__) shadow them per-row.
"""
id = _FakeCol()
base_url = _FakeCol()
owner = _FakeCol()
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class _PinnedFakeRequest:
def __init__(self, body=None, headers=None):
self._body = body if body is not None else {}
self.headers = headers or {}
async def json(self):
return self._body
def _get_route(path, method):
from routes.model_routes import setup_model_routes
router = setup_model_routes(model_discovery=None)
for route in router.routes:
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
return route.endpoint
raise AssertionError(f"{method} {path} not found")
def _make_endpoint(**kwargs):
base = dict(
id="ep1",
name="EP",
base_url="http://localhost:9999/v1",
api_key=None,
is_enabled=True,
hidden_models=None,
cached_models=None,
pinned_models=None,
model_type="llm",
supports_tools=None,
)
base.update(kwargs)
return SimpleNamespace(**base)
def test_patch_models_saves_pinned_models(monkeypatch):
ep = _make_endpoint()
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]})
result = asyncio.run(endpoint("ep1", request))
assert json.loads(ep.pinned_models) == ["deploy-1", "deploy-2"]
assert result["pinned_count"] == 2
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
ep = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1"]})
asyncio.run(endpoint("ep1", request))
assert json.loads(ep.hidden_models) == ["hide-me"]
assert json.loads(ep.pinned_models) == ["deploy-1"]
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
result = endpoint("ep1", _PinnedFakeRequest())
ids = [row["id"] for row in result]
assert ids == ["deploy-1"]
assert result[0]["is_pinned"] is True
def test_reprobe_preserves_pinned_models(monkeypatch):
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
db = _PinnedFakeDb([ep])
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
monkeypatch.setattr(
model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
)
endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
response = endpoint("ep1", _PinnedFakeRequest())
async def _drain():
async for _ in response.body_iterator:
pass
asyncio.run(_drain())
# Probe rewrites cached/hidden but must never touch admin-pinned IDs.
assert json.loads(ep.pinned_models) == ["deploy-1"]
assert json.loads(ep.cached_models) == ["m1"]
def test_visible_models_handles_malformed_strings():
# Non-JSON cached/pinned strings are treated as comma/newline lists and
# never raise; a malformed hidden string is normalized too.
result = _visible_models("a,b", "b", "{bad json")
assert isinstance(result, list)
assert result == ["a", "{bad json"]
assert _visible_models("", None, "") == []
assert _visible_models("only-cached", None, None) == ["only-cached"]
def _create_form_kwargs(**overrides):
"""Defaults for every Form() param create_model_endpoint reads directly.
Calling the route as a plain function bypasses FastAPI form parsing, so the
Form() sentinels must be replaced with real strings.
"""
kwargs = dict(
name="",
api_key="",
skip_probe="true", # avoid any network probe in unit tests
require_models="false",
model_type="llm",
supports_tools="",
pinned_models="",
container_local="false",
shared="true",
)
kwargs.update(overrides)
return kwargs
def _patch_create_deps(monkeypatch, db):
import src.auth_helpers as auth_helpers
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint)
monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b)
monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b)
monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"})
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
db = _PinnedFakeDb([]) # no existing row → fresh create path
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2"),
)
assert result["pinned_models"] == ["deploy-1", "deploy-2"]
assert result["models"] == ["deploy-1", "deploy-2"]
assert result["online"] is True
# Persisted onto the created row.
assert len(db.added) == 1
assert json.loads(db.added[0].pinned_models) == ["deploy-1", "deploy-2"]
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
existing = _make_endpoint(
cached_models=json.dumps(["m1"]),
hidden_models=None,
pinned_models=json.dumps(["old-pin"]),
)
db = _PinnedFakeDb([existing])
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(pinned_models="new-pin"),
)
assert result["existing"] is True
# Incoming pin merged onto the existing pins (no clobber, order preserved).
assert json.loads(existing.pinned_models) == ["old-pin", "new-pin"]
assert result["pinned_models"] == ["old-pin", "new-pin"]
# models = cached + pinned - hidden, visible merged list.
assert result["models"] == ["m1", "old-pin", "new-pin"]
# No new row created on the dedupe path.
assert db.added == []
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
existing = _make_endpoint(
cached_models=json.dumps(["m1"]),
pinned_models=json.dumps(["keep-me"]),
)
db = _PinnedFakeDb([existing])
_patch_create_deps(monkeypatch, db)
create = _get_route("/api/model-endpoints", "POST")
result = create(
_PinnedFakeRequest(),
base_url="http://host:1234/v1",
**_create_form_kwargs(), # pinned_models defaults to ""
)
assert json.loads(existing.pinned_models) == ["keep-me"]
assert result["pinned_models"] == ["keep-me"]
assert db.committed == 0 # nothing to persist

View File

@@ -247,10 +247,14 @@ class _Column:
def __eq__(self, value):
return _Predicate(lambda row: getattr(row, self.name) == value)
def desc(self):
return self
class _ModelEndpoint:
is_enabled = _Column("is_enabled")
owner = _Column("owner")
created_at = _Column("created_at")
class _Query:
@@ -261,6 +265,9 @@ class _Query:
self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
return self
def order_by(self, *exprs):
return self
def first(self):
return self._rows[0] if self._rows else None
@@ -280,8 +287,10 @@ def _ep(name, owner, *, is_enabled=True):
def _select(rows, owner):
wh_mod = _import_webhook_helper()
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint
return wh_mod._first_enabled_endpoint(_DB(rows), owner)
# _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
# (not a local import), so we patch the module attribute directly.
wh_mod.ModelEndpoint = _ModelEndpoint
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
def test_sync_chat_fallback_never_picks_another_owners_endpoint():
@@ -310,9 +319,15 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop():
# An unresolvable/empty token owner keeps the original single-user behaviour
# (owner_filter no-op): first enabled row, whatever it is.
rows = [_ep("first", "bob"), _ep("second", "alice")]
def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
# When no token owner is known, only null-owner (shared) endpoints are
# visible — private endpoints of any user must not be returned.
rows = [_ep("bob-private", "bob"), _ep("shared", None)]
ep = _select(rows, None)
assert ep is not None and ep.name == "first"
assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
# No shared rows → fail closed rather than returning another user's endpoint.
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
assert _select(rows, None) is None

View File

@@ -1,4 +1,4 @@
from src.search.ranking import rank_search_results
from services.search.ranking import rank_search_results
def test_news_queries_prefer_news_sources_over_sports_and_social_results():

View File

@@ -8,7 +8,8 @@ module-level, time-injectable function.
from datetime import datetime, timezone
from src.search.ranking import recency_score, _utcnow_naive
import services.search.ranking as live_ranking
from services.search.ranking import recency_score, _utcnow_naive, rank_search_results
def test_fresh_result_scores_one():
@@ -37,3 +38,37 @@ def test_default_now_is_naive_utc():
assert now.tzinfo is None
reference = datetime.now(timezone.utc).replace(tzinfo=None)
assert abs((now - reference).total_seconds()) < 5
def test_supported_timestamp_formats_parse():
# All three formats the current implementation supports resolve to the same
# ~4-day-old age, so each scores a full 1.0.
now = datetime(2026, 1, 5, 12, 0, 0)
assert recency_score("2026-01-01", now=now) == 1.0
assert recency_score("2026-01-01T08:30:00", now=now) == 1.0
assert recency_score("2026-01-01 08:30:00", now=now) == 1.0
def test_shim_reexports_live_objects():
# src.search.ranking is a compatibility shim; it must expose the *same*
# objects as the live services module so the two cannot diverge.
import src.search.ranking as shim
assert shim.recency_score is live_ranking.recency_score
assert shim.rank_search_results is live_ranking.rank_search_results
assert shim._utcnow_naive is live_ranking._utcnow_naive
def test_live_rank_path_prefers_newer_result(monkeypatch):
# Pin "now" so age scoring is deterministic. The two results are identical
# apart from age, isolating recency as the only differentiator.
monkeypatch.setattr(live_ranking, "_utcnow_naive", lambda: datetime(2026, 1, 31))
results = [
{"title": "Report", "url": "https://example.org/a", "snippet": "x", "age": "2026-01-01"},
{"title": "Report", "url": "https://example.org/b", "snippet": "x", "age": "2026-01-29"},
]
ranked = rank_search_results("report", results)
assert ranked[0]["url"] == "https://example.org/b"
assert ranked[1]["url"] == "https://example.org/a"