Reapply "Merge branch 'main' of github.com:pewdiepie-archdaemon/odysseus"
This reverts commit cc8fe2f6e3.
This commit is contained in:
14
README.md
14
README.md
@@ -189,6 +189,20 @@ RENDER_GID=989
|
||||
|
||||
For NVIDIA/AMD GPU support, also read the comments in the selected overlay file: docker/gpu.nvidia.yml or docker/gpu.amd.yml.
|
||||
|
||||
**Stack-management UIs (Portainer, Coolify, Dockhand, etc.).** These tools
|
||||
often accept only a single Compose file and do not reliably honor `COMPOSE_FILE`
|
||||
or multiple `-f` overlays. CLI users should keep using the `COMPOSE_FILE`
|
||||
overlay workflow above. For stack UIs, point the stack at one of the standalone
|
||||
files instead, which bundle the base stack plus the GPU settings:
|
||||
|
||||
- `docker-compose.gpu-nvidia.yml` — still requires the NVIDIA Container Toolkit
|
||||
on the host.
|
||||
- `docker-compose.gpu-amd.yml` — still requires host ROCm/kfd/DRI setup, the
|
||||
`video`/`render` group membership, and `RENDER_GID` when needed.
|
||||
|
||||
The base `docker-compose.yml` plus the `docker/gpu.*.yml` overlays remain the
|
||||
source of truth; the standalone files mirror them for single-file deployments.
|
||||
|
||||
Verify after enabling either overlay:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base):
|
||||
is_enabled = Column(Boolean, default=True)
|
||||
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
|
||||
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
|
||||
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
|
||||
model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
|
||||
# Whether models on this endpoint accept OpenAI-style function
|
||||
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
|
||||
@@ -856,6 +857,24 @@ def _migrate_add_cached_models_column():
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
||||
|
||||
def _migrate_add_pinned_models_column():
|
||||
"""Add pinned_models column to model_endpoints if it doesn't exist."""
|
||||
import sqlite3
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
if columns and "pinned_models" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
|
||||
|
||||
def _migrate_add_notes_sort_order():
|
||||
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
||||
import sqlite3
|
||||
@@ -1511,6 +1530,7 @@ def init_db():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_add_hidden_models_column()
|
||||
_migrate_add_cached_models_column()
|
||||
_migrate_add_pinned_models_column()
|
||||
_migrate_add_notes_sort_order()
|
||||
_migrate_add_model_type_column()
|
||||
_migrate_add_model_endpoint_owner_column()
|
||||
|
||||
164
docker-compose.gpu-amd.yml
Normal file
164
docker-compose.gpu-amd.yml
Normal file
@@ -0,0 +1,164 @@
|
||||
# Standalone AMD ROCm GPU Compose file for stack-management UIs (Portainer,
|
||||
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
|
||||
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
|
||||
#
|
||||
# This is equivalent to: docker-compose.yml + docker/gpu.amd.yml.
|
||||
# The base docker-compose.yml plus the docker/gpu.amd.yml overlay remain the
|
||||
# source of truth — CLI users should keep using the COMPOSE_FILE overlay
|
||||
# workflow. Keep this file in sync with both when either changes.
|
||||
#
|
||||
# Requires ROCm drivers on the host (kfd + DRI devices) and the host user
|
||||
# running Docker in the `video` and `render` groups. Set RENDER_GID to your
|
||||
# host's numeric render group id when needed. See docker/gpu.amd.yml for details.
|
||||
services:
|
||||
odysseus:
|
||||
build: .
|
||||
ports:
|
||||
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
|
||||
volumes:
|
||||
- ./data:/app/data:z
|
||||
- ./logs:/app/logs:z
|
||||
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
|
||||
# add the shown public key to each remote server's authorized_keys.
|
||||
- ./data/ssh:/app/.ssh:z
|
||||
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
|
||||
# container, so persist its HuggingFace cache under ./data/huggingface.
|
||||
- ./data/huggingface:/app/.cache/huggingface:z
|
||||
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
|
||||
# land under /app/.local for the odysseus user. Persist them so a
|
||||
# container recreate does not silently remove installed serve engines.
|
||||
- ./data/local:/app/.local:z
|
||||
extra_hosts:
|
||||
# Lets the container reach local services on the Docker host, including
|
||||
# Ollama at http://host.docker.internal:11434.
|
||||
- "host.docker.internal:host-gateway"
|
||||
environment:
|
||||
- LLM_HOST=${LLM_HOST:-localhost}
|
||||
- LLM_HOSTS=${LLM_HOSTS:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
|
||||
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
|
||||
- HF_TOKEN=${HF_TOKEN:-}
|
||||
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
|
||||
- SEARXNG_INSTANCE=http://searxng:8080
|
||||
- CHROMADB_HOST=chromadb
|
||||
- CHROMADB_PORT=8000
|
||||
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
|
||||
- AUTH_ENABLED=${AUTH_ENABLED:-true}
|
||||
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
|
||||
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
|
||||
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
|
||||
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
|
||||
- SERPER_API_KEY=${SERPER_API_KEY:-}
|
||||
# PUID / PGID — the user/group the container drops to before
|
||||
# running uvicorn (entrypoint also chowns /app/data + /app/logs
|
||||
# to match, so bind-mounted files stay editable from the host).
|
||||
# 1000 is the default first user on most Linux installs. If your
|
||||
# host user has a different id, override here or via .env, e.g.:
|
||||
# PUID=1001
|
||||
# PGID=1001
|
||||
# Find yours with: id -u / id -g
|
||||
- PUID=${PUID:-1000}
|
||||
- PGID=${PGID:-1000}
|
||||
depends_on:
|
||||
searxng:
|
||||
condition: service_healthy
|
||||
chromadb:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
# AMD ROCm overlay (from docker/gpu.amd.yml).
|
||||
devices:
|
||||
- /dev/kfd
|
||||
- /dev/dri
|
||||
group_add:
|
||||
- video
|
||||
- ${RENDER_GID:-render}
|
||||
|
||||
chromadb:
|
||||
image: docker.io/chromadb/chroma:latest
|
||||
ports:
|
||||
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
|
||||
volumes:
|
||||
- chromadb-data:/chroma/chroma
|
||||
environment:
|
||||
- ANONYMIZED_TELEMETRY=FALSE
|
||||
restart: unless-stopped
|
||||
|
||||
searxng:
|
||||
# Pinned, not :latest — odysseus waits on searxng's healthcheck
|
||||
# (depends_on: condition: service_healthy), so a broken upstream `latest`
|
||||
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
|
||||
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
|
||||
# Bump this deliberately after verifying a newer tag boots clean.
|
||||
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
|
||||
entrypoint:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
set -eu
|
||||
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
|
||||
secret="$${SEARXNG_SECRET:-}"
|
||||
if [ -z "$$secret" ]; then
|
||||
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
|
||||
fi
|
||||
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
|
||||
fi
|
||||
exec /usr/local/searxng/entrypoint.sh
|
||||
ports:
|
||||
- "127.0.0.1:8080:8080"
|
||||
volumes:
|
||||
- searxng-data:/etc/searxng
|
||||
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
|
||||
environment:
|
||||
- SEARXNG_BASE_URL=http://localhost:8080/
|
||||
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
|
||||
# The official searxng image runs as the non-root `searxng` user, but its
|
||||
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
|
||||
# su-exec, and (with our wrapper above) write settings.yml into the named
|
||||
# volume. Without these capabilities the wrapper aborts at the redirection
|
||||
# with EACCES and the container fails its healthcheck with permission
|
||||
# errors during setup. Mirrors the cap set recommended by the upstream
|
||||
# searxng-docker compose file. See issue #721.
|
||||
cap_drop:
|
||||
- ALL
|
||||
cap_add:
|
||||
- CHOWN
|
||||
- SETGID
|
||||
- SETUID
|
||||
- DAC_OVERRIDE
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
|
||||
interval: 5s
|
||||
timeout: 6s
|
||||
retries: 20
|
||||
start_period: 10s
|
||||
restart: unless-stopped
|
||||
|
||||
ntfy:
|
||||
image: docker.io/binwiederhier/ntfy
|
||||
command: serve
|
||||
ports:
|
||||
- "${NTFY_BIND:-127.0.0.1}:8091:80"
|
||||
volumes:
|
||||
- ntfy-cache:/var/cache/ntfy
|
||||
environment:
|
||||
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
searxng-data:
|
||||
chromadb-data:
|
||||
ntfy-cache:
|
||||
167
docker-compose.gpu-nvidia.yml
Normal file
167
docker-compose.gpu-nvidia.yml
Normal file
@@ -0,0 +1,167 @@
|
||||
# Standalone NVIDIA GPU Compose file for stack-management UIs (Portainer,
|
||||
# Coolify, Dockhand, etc.) that accept only a single Compose file and do not
|
||||
# reliably honor COMPOSE_FILE or multiple `-f` overlays.
|
||||
#
|
||||
# This is equivalent to: docker-compose.yml + docker/gpu.nvidia.yml.
|
||||
# The base docker-compose.yml plus the docker/gpu.nvidia.yml overlay remain
|
||||
# the source of truth — CLI users should keep using the COMPOSE_FILE overlay
|
||||
# workflow. Keep this file in sync with both when either changes.
|
||||
#
|
||||
# Requires the NVIDIA Container Toolkit on the host. See docker/gpu.nvidia.yml
|
||||
# for setup details.
|
||||
services:
|
||||
odysseus:
|
||||
build: .
|
||||
ports:
|
||||
- "${APP_BIND:-127.0.0.1}:${APP_PORT:-7000}:7000"
|
||||
volumes:
|
||||
- ./data:/app/data:z
|
||||
- ./logs:/app/logs:z
|
||||
# Cookbook remote-server SSH identity. Odysseus can generate a key here;
|
||||
# add the shown public key to each remote server's authorized_keys.
|
||||
- ./data/ssh:/app/.ssh:z
|
||||
# Cookbook local model cache. Inside Docker, "Local" means the Odysseus
|
||||
# container, so persist its HuggingFace cache under ./data/huggingface.
|
||||
- ./data/huggingface:/app/.cache/huggingface:z
|
||||
# Cookbook-installed Python CLIs/packages (vLLM, llama-cpp-python, etc.)
|
||||
# land under /app/.local for the odysseus user. Persist them so a
|
||||
# container recreate does not silently remove installed serve engines.
|
||||
- ./data/local:/app/.local:z
|
||||
extra_hosts:
|
||||
# Lets the container reach local services on the Docker host, including
|
||||
# Ollama at http://host.docker.internal:11434.
|
||||
- "host.docker.internal:host-gateway"
|
||||
environment:
|
||||
- LLM_HOST=${LLM_HOST:-localhost}
|
||||
- LLM_HOSTS=${LLM_HOSTS:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
|
||||
- RESEARCH_LLM_ENDPOINT=${RESEARCH_LLM_ENDPOINT:-}
|
||||
- HF_TOKEN=${HF_TOKEN:-}
|
||||
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
|
||||
- SEARXNG_INSTANCE=http://searxng:8080
|
||||
- CHROMADB_HOST=chromadb
|
||||
- CHROMADB_PORT=8000
|
||||
- DATABASE_URL=${DATABASE_URL:-sqlite:///./data/app.db}
|
||||
- AUTH_ENABLED=${AUTH_ENABLED:-true}
|
||||
- LOCALHOST_BYPASS=${LOCALHOST_BYPASS:-false}
|
||||
- ODYSSEUS_ADMIN_USER=${ODYSSEUS_ADMIN_USER:-admin}
|
||||
- ODYSSEUS_ADMIN_PASSWORD=${ODYSSEUS_ADMIN_PASSWORD:-}
|
||||
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost,http://127.0.0.1}
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
|
||||
- SERPER_API_KEY=${SERPER_API_KEY:-}
|
||||
# PUID / PGID — the user/group the container drops to before
|
||||
# running uvicorn (entrypoint also chowns /app/data + /app/logs
|
||||
# to match, so bind-mounted files stay editable from the host).
|
||||
# 1000 is the default first user on most Linux installs. If your
|
||||
# host user has a different id, override here or via .env, e.g.:
|
||||
# PUID=1001
|
||||
# PGID=1001
|
||||
# Find yours with: id -u / id -g
|
||||
- PUID=${PUID:-1000}
|
||||
- PGID=${PGID:-1000}
|
||||
# NVIDIA overlay (from docker/gpu.nvidia.yml).
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
depends_on:
|
||||
searxng:
|
||||
condition: service_healthy
|
||||
chromadb:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
# NVIDIA overlay (from docker/gpu.nvidia.yml).
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
|
||||
chromadb:
|
||||
image: docker.io/chromadb/chroma:latest
|
||||
ports:
|
||||
- "${CHROMADB_BIND:-127.0.0.1}:8100:8000"
|
||||
volumes:
|
||||
- chromadb-data:/chroma/chroma
|
||||
environment:
|
||||
- ANONYMIZED_TELEMETRY=FALSE
|
||||
restart: unless-stopped
|
||||
|
||||
searxng:
|
||||
# Pinned, not :latest — odysseus waits on searxng's healthcheck
|
||||
# (depends_on: condition: service_healthy), so a broken upstream `latest`
|
||||
# tag blocks the whole app from starting. 2026.6.2 crashes on boot with
|
||||
# `KeyError: 'default_doi_resolver'`, failing the healthcheck (issue #1414).
|
||||
# Bump this deliberately after verifying a newer tag boots clean.
|
||||
image: docker.io/searxng/searxng:2026.5.31-7159b8aed
|
||||
entrypoint:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
set -eu
|
||||
if [ ! -s /etc/searxng/settings.yml ] || grep -q 'odysseus-local-searxng-json-2026-05-30\|__SEARXNG_SECRET__' /etc/searxng/settings.yml; then
|
||||
secret="$${SEARXNG_SECRET:-}"
|
||||
if [ -z "$$secret" ]; then
|
||||
secret="$$(python -c 'import secrets; print(secrets.token_urlsafe(48))')"
|
||||
fi
|
||||
sed "s|__SEARXNG_SECRET__|$$secret|g" /tmp/searxng-settings.yml.template > /etc/searxng/settings.yml
|
||||
fi
|
||||
exec /usr/local/searxng/entrypoint.sh
|
||||
ports:
|
||||
- "127.0.0.1:8080:8080"
|
||||
volumes:
|
||||
- searxng-data:/etc/searxng
|
||||
- ./config/searxng/settings.yml:/tmp/searxng-settings.yml.template:ro,z
|
||||
environment:
|
||||
- SEARXNG_BASE_URL=http://localhost:8080/
|
||||
- SEARXNG_SECRET=${SEARXNG_SECRET:-}
|
||||
# The official searxng image runs as the non-root `searxng` user, but its
|
||||
# entrypoint still needs to chown /etc/searxng on first boot, drop privs via
|
||||
# su-exec, and (with our wrapper above) write settings.yml into the named
|
||||
# volume. Without these capabilities the wrapper aborts at the redirection
|
||||
# with EACCES and the container fails its healthcheck with permission
|
||||
# errors during setup. Mirrors the cap set recommended by the upstream
|
||||
# searxng-docker compose file. See issue #721.
|
||||
cap_drop:
|
||||
- ALL
|
||||
cap_add:
|
||||
- CHOWN
|
||||
- SETGID
|
||||
- SETUID
|
||||
- DAC_OVERRIDE
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "python -c \"import urllib.request; urllib.request.urlopen('http://localhost:8080/', timeout=5).read(1)\""]
|
||||
interval: 5s
|
||||
timeout: 6s
|
||||
retries: 20
|
||||
start_period: 10s
|
||||
restart: unless-stopped
|
||||
|
||||
ntfy:
|
||||
image: docker.io/binwiederhier/ntfy
|
||||
command: serve
|
||||
ports:
|
||||
- "${NTFY_BIND:-127.0.0.1}:8091:80"
|
||||
volumes:
|
||||
- ntfy-cache:/var/cache/ntfy
|
||||
environment:
|
||||
- NTFY_BASE_URL=${NTFY_BASE_URL:-http://localhost:8091}
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
searxng-data:
|
||||
chromadb-data:
|
||||
ntfy-cache:
|
||||
@@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
|
||||
return "No models found for that provider/key."
|
||||
|
||||
|
||||
def _visible_models(cached_models, hidden_models):
|
||||
"""Filter cached model IDs by hidden_models. Returns list of visible IDs."""
|
||||
all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or [])
|
||||
def _normalize_model_ids(value):
|
||||
"""Coerce a model-ID input into a clean, ordered list of strings.
|
||||
|
||||
Accepts a list, a JSON-encoded list string, or a comma/newline separated
|
||||
string (handy for form or backend API input). Trims whitespace, drops
|
||||
empty and non-string values, and de-duplicates preserving first-seen order.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
items = value
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except Exception:
|
||||
parsed = None
|
||||
items = parsed if isinstance(parsed, list) else re.split(r"[,\n]", text)
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
out, seen = [], set()
|
||||
for item in items:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
s = item.strip()
|
||||
if not s or s in seen:
|
||||
continue
|
||||
seen.add(s)
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _merge_model_ids(*lists):
|
||||
"""Concatenate model-ID lists, de-duplicating and preserving order."""
|
||||
out, seen = [], set()
|
||||
for ids in lists:
|
||||
for m in (ids or []):
|
||||
if not isinstance(m, str) or m in seen:
|
||||
continue
|
||||
seen.add(m)
|
||||
out.append(m)
|
||||
return out
|
||||
|
||||
|
||||
def _visible_models(cached_models, hidden_models, pinned_models=None):
|
||||
"""Merge cached + pinned model IDs, then filter out hidden ones.
|
||||
|
||||
Pinned IDs are admin-entered and may not appear in cached_models (e.g.
|
||||
cloud deployment IDs the provider does not list in /v1/models). Returns an
|
||||
ordered, de-duplicated list of visible IDs.
|
||||
"""
|
||||
# Normalize each input so JSON strings, lists, comma/newline strings, and
|
||||
# malformed strings are all handled without raising.
|
||||
merged = _merge_model_ids(
|
||||
_normalize_model_ids(cached_models),
|
||||
_normalize_model_ids(pinned_models),
|
||||
)
|
||||
if not hidden_models:
|
||||
return all_models
|
||||
hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or []))
|
||||
return [m for m in all_models if m not in hidden]
|
||||
return merged
|
||||
hidden = set(_normalize_model_ids(hidden_models))
|
||||
return [m for m in merged if m not in hidden]
|
||||
|
||||
|
||||
def setup_model_routes(model_discovery):
|
||||
@@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery):
|
||||
hidden = set(json.loads(r.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
visible = [m for m in all_models if m not in hidden]
|
||||
status = "online" if all_models else "offline"
|
||||
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
|
||||
visible = _visible_models(all_models, r.hidden_models, pinned)
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
# admin-pinned IDs that a probe would never surface.
|
||||
status = "online" if (all_models or pinned) else "offline"
|
||||
ping = None
|
||||
if not all_models and r.is_enabled:
|
||||
if not all_models and not pinned and r.is_enabled:
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
@@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery):
|
||||
"has_key": bool(r.api_key),
|
||||
"is_enabled": r.is_enabled,
|
||||
"models": visible,
|
||||
"pinned_models": pinned,
|
||||
"hidden_count": len(hidden),
|
||||
"online": status != "offline",
|
||||
"status": status,
|
||||
@@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery):
|
||||
require_models: str = Form("false"),
|
||||
model_type: str = Form("llm"),
|
||||
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
||||
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
|
||||
container_local: str = Form("false"),
|
||||
# Default `shared=true` → endpoints are visible to all users (the
|
||||
# app's historical behaviour). Admins can pass `shared=false` to
|
||||
@@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery):
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
# Persist any incoming pinned IDs onto the existing row. An
|
||||
# empty/omitted form field must not wipe previously pinned IDs.
|
||||
_incoming_pinned = _normalize_model_ids(pinned_models)
|
||||
if _incoming_pinned:
|
||||
_merged_pinned = _merge_model_ids(
|
||||
_normalize_model_ids(getattr(existing, "pinned_models", None)),
|
||||
_incoming_pinned,
|
||||
)
|
||||
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
|
||||
_db_dedup.commit()
|
||||
_invalidate_models_cache()
|
||||
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
|
||||
return {
|
||||
"id": existing.id,
|
||||
"name": existing.name,
|
||||
"base_url": existing.base_url,
|
||||
"models": json.loads(existing.cached_models) if existing.cached_models else [],
|
||||
"models": _visible_models(
|
||||
getattr(existing, "cached_models", None),
|
||||
getattr(existing, "hidden_models", None),
|
||||
existing.pinned_models,
|
||||
),
|
||||
"pinned_models": _existing_pinned,
|
||||
"online": True,
|
||||
"status": "online",
|
||||
"existing": True,
|
||||
@@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery):
|
||||
try:
|
||||
_st_raw = (supports_tools or "").strip().lower()
|
||||
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
|
||||
_pinned = _normalize_model_ids(pinned_models)
|
||||
# Stamp owner so the picker only shows this endpoint to the admin
|
||||
# who added it. Pass `shared=true` to mark it null-owner (visible
|
||||
# to all users), preserving the pre-fix "everyone sees everything"
|
||||
@@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery):
|
||||
is_enabled=True,
|
||||
model_type=model_type.strip() if model_type else "llm",
|
||||
cached_models=json.dumps(model_ids) if model_ids else None,
|
||||
pinned_models=json.dumps(_pinned) if _pinned else None,
|
||||
supports_tools=_st,
|
||||
owner=_owner_val,
|
||||
)
|
||||
@@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep_id,
|
||||
"name": name.strip(),
|
||||
"base_url": base_url,
|
||||
"models": model_ids,
|
||||
"online": bool(model_ids) or bool(ping.get("reachable")),
|
||||
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"),
|
||||
"models": _merge_model_ids(model_ids, _pinned),
|
||||
"pinned_models": _pinned,
|
||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
}
|
||||
|
||||
@@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery):
|
||||
hidden = set(json.loads(ep.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
# Try live probe, fall back to cached
|
||||
# Try live probe, fall back to cached. Pinned IDs are admin-entered
|
||||
# and persist regardless of probe results — never overwritten here.
|
||||
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
|
||||
if all_models:
|
||||
ep.cached_models = json.dumps(all_models)
|
||||
@@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery):
|
||||
all_models = json.loads(ep.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
|
||||
pinned_set = set(pinned)
|
||||
return [
|
||||
{"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden}
|
||||
for m in all_models
|
||||
{
|
||||
"id": m,
|
||||
"display": m.split("/")[-1],
|
||||
"is_hidden": m in hidden,
|
||||
"is_pinned": m in pinned_set,
|
||||
}
|
||||
for m in _merge_model_ids(all_models, pinned)
|
||||
]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.patch("/model-endpoints/{ep_id}/models")
|
||||
async def update_hidden_models(ep_id: str, request: Request):
|
||||
"""Bulk update hidden models list for an endpoint.
|
||||
"""Bulk update hidden and/or pinned model lists for an endpoint.
|
||||
|
||||
Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]}
|
||||
Expects JSON body with optional keys:
|
||||
{"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]}
|
||||
Each key is updated only when present, so callers can patch one list
|
||||
without clobbering the other.
|
||||
"""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
@@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery):
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
body = await request.json()
|
||||
hidden = body.get("hidden", [])
|
||||
if not isinstance(hidden, list):
|
||||
raise HTTPException(400, "hidden must be a list of model IDs")
|
||||
ep.hidden_models = json.dumps(hidden) if hidden else None
|
||||
if not isinstance(body, dict):
|
||||
raise HTTPException(400, "Body must be a JSON object")
|
||||
if "hidden" in body:
|
||||
hidden = body.get("hidden")
|
||||
if not isinstance(hidden, list):
|
||||
raise HTTPException(400, "hidden must be a list of model IDs")
|
||||
ep.hidden_models = json.dumps(hidden) if hidden else None
|
||||
# Accept either "pinned" or "pinned_models" for the manual IDs list.
|
||||
if "pinned_models" in body or "pinned" in body:
|
||||
pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned")))
|
||||
ep.pinned_models = json.dumps(pinned) if pinned else None
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
return {"id": ep_id, "hidden_count": len(hidden)}
|
||||
hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0
|
||||
pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0
|
||||
return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery):
|
||||
return {"endpoint_id": "", "endpoint_url": "", "model": ""}
|
||||
base = _normalize_base(ep.base_url)
|
||||
chat_url = build_chat_url(base)
|
||||
if not model and getattr(ep, "cached_models", None):
|
||||
if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)):
|
||||
try:
|
||||
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None))
|
||||
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None))
|
||||
if visible:
|
||||
model = visible[0]
|
||||
except Exception:
|
||||
@@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery):
|
||||
ep.name = body["name"].strip() or ep.name
|
||||
if "model_type" in body and isinstance(body["model_type"], str):
|
||||
ep.model_type = body["model_type"].strip() or ep.model_type
|
||||
if "pinned_models" in body:
|
||||
_pinned = _normalize_model_ids(body["pinned_models"])
|
||||
ep.pinned_models = json.dumps(_pinned) if _pinned else None
|
||||
# Rotating an API key used to require DELETE+POST, which wiped
|
||||
# endpoint_url/model from every session referencing the old base
|
||||
# URL. Allow in-place updates so the admin can change the key
|
||||
@@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery):
|
||||
"name": ep.name,
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -9,7 +9,9 @@ import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.database import SessionLocal, Webhook
|
||||
from core.database import SessionLocal, Webhook, ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.url_security import validate_public_http_url
|
||||
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000
|
||||
from core.middleware import require_admin as _require_admin
|
||||
|
||||
|
||||
def _first_enabled_endpoint(db, owner):
|
||||
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus
|
||||
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint
|
||||
is per-user (core/database.py — "when non-null, the model picker only shows
|
||||
the endpoint to that user"), and the sync-chat fallback uses the row's
|
||||
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token
|
||||
(e.g. a paired mobile device) fall back onto ANOTHER user's private
|
||||
endpoint and silently spend that owner's API key / quota — and reach
|
||||
whatever internal base_url they configured. Mirrors the owner_filter scoping
|
||||
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
|
||||
no-op (single-user / legacy mode), preserving the original behaviour.
|
||||
def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]):
|
||||
"""First enabled ModelEndpoint visible to token_owner — their own rows plus
|
||||
legacy null-owner ("shared") rows. Owner-scoped: an unscoped .first() would
|
||||
let a chat-scoped token fall back onto another user's private endpoint and
|
||||
silently spend that owner's API key/quota. Prefer owner rows before shared
|
||||
rows. Fails closed to null-owner rows only when token_owner is absent.
|
||||
Does not validate base_url — admin-configured local/LAN endpoints remain allowed.
|
||||
"""
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
return q.first()
|
||||
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
|
||||
if token_owner:
|
||||
query = owner_filter(query, ModelEndpoint, token_owner)
|
||||
return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first()
|
||||
return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711
|
||||
|
||||
|
||||
def _caller_owns_session(sess_owner, caller) -> bool:
|
||||
@@ -278,15 +276,21 @@ def setup_webhook_routes(
|
||||
api_key = body.api_key.strip()
|
||||
model = body.model or "deepseek-chat"
|
||||
|
||||
# Resolve base_url: explicit > provider name > model prefix auto-detect
|
||||
base_url = body.base_url.strip().rstrip("/") if body.base_url else None
|
||||
if not base_url:
|
||||
# Validate only token-supplied direct base_url; auto-resolved known-provider
|
||||
# URLs are not subject to extra local/LAN blocking beyond existing provider logic.
|
||||
direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None
|
||||
if direct_base_url:
|
||||
try:
|
||||
base_url = validate_public_http_url(direct_base_url)
|
||||
except ValueError as e:
|
||||
detail = str(e).replace("URL", "base_url", 1)
|
||||
raise HTTPException(400, detail)
|
||||
else:
|
||||
base_url = _resolve_base_url(model, body.provider)
|
||||
if not base_url:
|
||||
raise HTTPException(400,
|
||||
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
|
||||
"or provider ('deepseek', 'openai', 'groq', etc.)")
|
||||
|
||||
base_url = normalize_base(base_url)
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
|
||||
@@ -306,9 +310,7 @@ def setup_webhook_routes(
|
||||
if not sess:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped: only THIS token owner's endpoints + legacy
|
||||
# shared rows, never another user's private endpoint/api_key.
|
||||
ep = _first_enabled_endpoint(db, token_owner)
|
||||
ep = _select_api_chat_fallback_endpoint(db, token_owner)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -2,12 +2,49 @@
|
||||
|
||||
import re
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _utcnow_naive() -> datetime:
|
||||
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
|
||||
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
|
||||
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
|
||||
|
||||
The age is measured against UTC, not local time. The previous code used
|
||||
``datetime.now()`` (local) against UTC-style published dates, so the age was
|
||||
skewed by the host's UTC offset; it was also a latent crash once neighbouring
|
||||
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
|
||||
"""
|
||||
if not age_str:
|
||||
return 0.0
|
||||
dt = None
|
||||
for fmt in _AGE_FORMATS:
|
||||
try:
|
||||
dt = datetime.strptime(age_str, fmt)
|
||||
break
|
||||
except Exception:
|
||||
dt = None
|
||||
if not dt:
|
||||
return 0.0
|
||||
now = now or _utcnow_naive()
|
||||
days_old = (now - dt).days
|
||||
if days_old <= 7:
|
||||
return 1.0
|
||||
if days_old >= 30:
|
||||
return 0.0
|
||||
return (30 - days_old) / 23
|
||||
|
||||
|
||||
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
|
||||
_SPORTS_HINTS = {
|
||||
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
|
||||
@@ -73,24 +110,6 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
return 0.7
|
||||
return 0.4
|
||||
|
||||
def recency_score(age_str: Optional[str]) -> float:
|
||||
if not age_str:
|
||||
return 0.0
|
||||
for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
|
||||
try:
|
||||
dt = datetime.strptime(age_str, fmt)
|
||||
break
|
||||
except Exception:
|
||||
dt = None
|
||||
if not dt:
|
||||
return 0.0
|
||||
days_old = (datetime.now() - dt).days
|
||||
if days_old <= 7:
|
||||
return 1.0
|
||||
if days_old >= 30:
|
||||
return 0.0
|
||||
return (30 - days_old) / 23
|
||||
|
||||
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
|
||||
if not is_news_query:
|
||||
return 0.0
|
||||
|
||||
@@ -1,151 +1,13 @@
|
||||
"""Search result ranking based on relevance, source quality, and recency."""
|
||||
"""Compatibility re-export shim for the live ranking module.
|
||||
|
||||
import re
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
The real implementation lives in :mod:`services.search.ranking`, which is what
|
||||
the search runtime (services/search/core.py) imports. This module used to hold a
|
||||
parallel copy; it now re-exports so the two cannot drift out of sync again.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_AGE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _utcnow_naive() -> datetime:
|
||||
"""Naive UTC 'now'. Matches the naive, UTC-style published dates parsed below,
|
||||
and is safe on Python 3.14 where ``datetime.utcnow()`` is removed (#1116)."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
def recency_score(age_str: Optional[str], now: Optional[datetime] = None) -> float:
|
||||
"""Score how recent a result is: 1.0 for <=7 days old, 0.0 for >=30 days.
|
||||
|
||||
The age is measured against UTC, not local time. The previous code used
|
||||
``datetime.now()`` (local) against UTC-style published dates, so the age was
|
||||
skewed by the host's UTC offset; it was also a latent crash once neighbouring
|
||||
code moves to timezone-aware datetimes (#1116). ``now`` is injectable for tests.
|
||||
"""
|
||||
if not age_str:
|
||||
return 0.0
|
||||
dt = None
|
||||
for fmt in _AGE_FORMATS:
|
||||
try:
|
||||
dt = datetime.strptime(age_str, fmt)
|
||||
break
|
||||
except Exception:
|
||||
dt = None
|
||||
if not dt:
|
||||
return 0.0
|
||||
now = now or _utcnow_naive()
|
||||
days_old = (now - dt).days
|
||||
if days_old <= 7:
|
||||
return 1.0
|
||||
if days_old >= 30:
|
||||
return 0.0
|
||||
return (30 - days_old) / 23
|
||||
|
||||
|
||||
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
|
||||
_SPORTS_HINTS = {
|
||||
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
|
||||
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
|
||||
}
|
||||
# Word-boundary match so "sport" does not fire inside "transport"/"passport"
|
||||
# and a domain like "transport.gov" is not mistaken for a sports site.
|
||||
_SPORTS_HINT_RE = re.compile(
|
||||
r"\b(?:" + "|".join(re.escape(h) for h in _SPORTS_HINTS) + r")\b"
|
||||
from services.search.ranking import ( # noqa: F401
|
||||
_AGE_FORMATS,
|
||||
_utcnow_naive,
|
||||
rank_search_results,
|
||||
recency_score,
|
||||
)
|
||||
_LOW_VALUE_NEWS_DOMAINS = {
|
||||
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
|
||||
"www.yahoo.com", "msn.com", "www.msn.com",
|
||||
}
|
||||
_TRUSTED_NEWS_DOMAINS = {
|
||||
"apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com",
|
||||
"bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca",
|
||||
"ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca",
|
||||
"theguardian.com",
|
||||
"www.theguardian.com", "euronews.com", "www.euronews.com",
|
||||
"dw.com", "www.dw.com", "government.se", "www.government.se",
|
||||
}
|
||||
|
||||
|
||||
def _domain(url: str) -> str:
|
||||
try:
|
||||
return urlparse(url).netloc.lower()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||
query_lc = query.lower()
|
||||
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
|
||||
is_sports_query = bool(_SPORTS_HINT_RE.search(query_lc))
|
||||
|
||||
def title_score(title: str) -> float:
|
||||
if not title:
|
||||
return 0.0
|
||||
title_lc = title.lower()
|
||||
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
|
||||
return matches / len(query_terms) if query_terms else 0.0
|
||||
|
||||
def snippet_score(snippet: str) -> float:
|
||||
if not snippet:
|
||||
return 0.0
|
||||
length_factor = min(len(snippet), 200) / 200
|
||||
term_hits = sum(1 for term in query_terms if term in snippet.lower())
|
||||
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
||||
return (length_factor + term_factor) / 2
|
||||
|
||||
def domain_score(url: str) -> float:
|
||||
netloc = _domain(url)
|
||||
if not netloc:
|
||||
return 0.0
|
||||
if netloc in _TRUSTED_NEWS_DOMAINS:
|
||||
return 1.0
|
||||
if netloc.endswith(".edu") or netloc.endswith(".gov"):
|
||||
return 1.0
|
||||
if netloc.endswith(".org"):
|
||||
return 0.7
|
||||
return 0.4
|
||||
|
||||
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
|
||||
if not is_news_query:
|
||||
return 0.0
|
||||
text = f"{title} {snippet}".lower()
|
||||
netloc = _domain(url)
|
||||
adjustment = 0.0
|
||||
if netloc in _TRUSTED_NEWS_DOMAINS:
|
||||
adjustment += 1.2
|
||||
if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")):
|
||||
adjustment += 0.4
|
||||
if netloc in _LOW_VALUE_NEWS_DOMAINS:
|
||||
adjustment -= 0.8
|
||||
if not is_sports_query and (_SPORTS_HINT_RE.search(text) or _SPORTS_HINT_RE.search(netloc)):
|
||||
adjustment -= 1.5
|
||||
# A country/news query should not rank a page whose title/snippet barely
|
||||
# mentions the country above actual news pages for that country.
|
||||
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
||||
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
|
||||
adjustment -= 1.0
|
||||
return adjustment
|
||||
|
||||
ranked = []
|
||||
for result in results:
|
||||
title = result.get("title", "")
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
age = result.get("age", None)
|
||||
|
||||
score = (
|
||||
2.0 * title_score(title)
|
||||
+ 1.0 * snippet_score(snippet)
|
||||
+ 1.5 * domain_score(url)
|
||||
+ 1.0 * recency_score(age)
|
||||
+ news_quality_adjustment(title, snippet, url)
|
||||
)
|
||||
ranked.append((score, result))
|
||||
|
||||
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||
return [r for _, r in ranked]
|
||||
|
||||
94
src/url_security.py
Normal file
94
src/url_security.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""URL validation helpers for server-side outbound requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
_INTERNAL_HOSTNAMES = {
|
||||
"localhost",
|
||||
"metadata",
|
||||
"metadata.google.internal",
|
||||
}
|
||||
|
||||
_INTERNAL_SUFFIXES = (
|
||||
".localhost",
|
||||
".local",
|
||||
".internal",
|
||||
".lan",
|
||||
".intranet",
|
||||
)
|
||||
|
||||
_BLOCKED_NETWORKS = (
|
||||
ipaddress.ip_network("0.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("100.64.0.0/10"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("::/128"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
|
||||
ips: list[ipaddress._BaseAddress] = []
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
||||
if family in (socket.AF_INET, socket.AF_INET6):
|
||||
ips.append(ipaddress.ip_address(sockaddr[0]))
|
||||
return ips
|
||||
|
||||
|
||||
def _blocked_ip(addr: ipaddress._BaseAddress) -> bool:
|
||||
return (
|
||||
any(addr in net for net in _BLOCKED_NETWORKS)
|
||||
or addr.is_private
|
||||
or addr.is_loopback
|
||||
or addr.is_link_local
|
||||
or addr.is_multicast
|
||||
or addr.is_unspecified
|
||||
or addr.is_reserved
|
||||
)
|
||||
|
||||
|
||||
def _host_resolves_publicly(hostname: str) -> bool:
|
||||
host = hostname.strip().lower()
|
||||
if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES):
|
||||
return False
|
||||
try:
|
||||
return not _blocked_ip(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
addrs = _resolve_hostname_ips(host)
|
||||
except OSError:
|
||||
return False
|
||||
return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs)
|
||||
|
||||
|
||||
def is_public_http_url(url: str) -> bool:
|
||||
parsed = urlparse((url or "").strip())
|
||||
if parsed.scheme not in ("http", "https") or not parsed.hostname:
|
||||
return False
|
||||
return _host_resolves_publicly(parsed.hostname)
|
||||
|
||||
|
||||
def validate_public_http_url(url: str, *, max_length: int = 2048) -> str:
|
||||
"""Validate a user/API-token supplied server-side HTTP(S) endpoint.
|
||||
|
||||
This is for untrusted outbound URLs, not admin-created model endpoints
|
||||
that are intentionally allowed to point at private model providers. DNS
|
||||
failures fail closed, and DNS checks reduce obvious private-network
|
||||
targets but do not eliminate every DNS rebinding race by themselves.
|
||||
"""
|
||||
cleaned = (url or "").strip()
|
||||
if len(cleaned) > max_length:
|
||||
raise ValueError("URL is too long")
|
||||
if not is_public_http_url(cleaned):
|
||||
raise ValueError("URL must point to a public HTTP(S) endpoint")
|
||||
return cleaned
|
||||
401
tests/test_api_chat_security.py
Normal file
401
tests/test_api_chat_security.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import ipaddress
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://127.0.0.1:8000/v1",
|
||||
"http://localhost:8000/v1",
|
||||
"http://10.0.0.5/v1",
|
||||
"http://172.16.0.1/v1",
|
||||
"http://192.168.1.2/v1",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://metadata.google.internal/",
|
||||
"http://[::1]:8000/v1",
|
||||
"http://[fc00::1]/v1",
|
||||
"http://224.0.0.1/v1",
|
||||
"http://0.0.0.0/v1",
|
||||
"file:///etc/passwd",
|
||||
])
|
||||
def test_public_url_validator_blocks_internal_targets(url):
|
||||
from src.url_security import is_public_http_url
|
||||
|
||||
assert is_public_http_url(url) is False
|
||||
|
||||
|
||||
def test_public_url_validator_allows_public_endpoint(monkeypatch):
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("93.184.216.34")],
|
||||
)
|
||||
|
||||
assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1"
|
||||
|
||||
|
||||
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("10.0.0.5")],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
url_security.validate_public_http_url("https://api.example.com/v1")
|
||||
|
||||
|
||||
def _load_webhook_routes_for_test(monkeypatch):
|
||||
# Load under a unique module name so each test gets a fresh module object
|
||||
# rather than a cached one from a previous monkeypatch run.
|
||||
core_pkg = types.ModuleType("core")
|
||||
core_pkg.__path__ = []
|
||||
core_db = types.ModuleType("core.database")
|
||||
core_db.SessionLocal = object
|
||||
core_db.Webhook = object
|
||||
core_db.ModelEndpoint = object
|
||||
core_middleware = types.ModuleType("core.middleware")
|
||||
core_middleware.require_admin = lambda request: None
|
||||
webhook_manager = types.ModuleType("src.webhook_manager")
|
||||
webhook_manager.WebhookManager = object
|
||||
webhook_manager.validate_webhook_url = lambda url: url
|
||||
webhook_manager.validate_events = lambda events: events
|
||||
|
||||
monkeypatch.setitem(sys.modules, "core", core_pkg)
|
||||
monkeypatch.setitem(sys.modules, "core.database", core_db)
|
||||
monkeypatch.setitem(sys.modules, "core.middleware", core_middleware)
|
||||
monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager)
|
||||
|
||||
module_name = "routes.webhook_routes_under_test"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py",
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
class _Expr:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, row):
|
||||
return self.fn(row)
|
||||
|
||||
def __or__(self, other):
|
||||
return _Expr(lambda row: self(row) or other(row))
|
||||
|
||||
|
||||
class _Column:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return _Expr(lambda row: getattr(row, self.name) == other)
|
||||
|
||||
def desc(self):
|
||||
return ("desc", self.name)
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
owner = _Column("owner")
|
||||
created_at = _Column("created_at")
|
||||
|
||||
|
||||
class _Endpoint:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
owner,
|
||||
is_enabled=True,
|
||||
created_at=1,
|
||||
base_url="https://api.example.com/v1",
|
||||
api_key=None,
|
||||
):
|
||||
self.owner = owner
|
||||
self.is_enabled = is_enabled
|
||||
self.created_at = created_at
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
|
||||
class _EndpointQuery:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.filters = []
|
||||
self.orders = []
|
||||
|
||||
def filter(self, *exprs):
|
||||
self.filters.extend(exprs)
|
||||
return self
|
||||
|
||||
def order_by(self, *exprs):
|
||||
self.orders.extend(exprs)
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
rows = self.rows
|
||||
for expr in self.filters:
|
||||
rows = [row for row in rows if expr(row)]
|
||||
# Apply sort keys right-to-left so the leftmost key ends up as the
|
||||
# primary sort (stable-sort reversal idiom mirrors SQLAlchemy's
|
||||
# multi-column ORDER BY behaviour).
|
||||
for order in reversed(self.orders):
|
||||
reverse = False
|
||||
name = getattr(order, "name", None)
|
||||
if isinstance(order, tuple) and order[0] == "desc":
|
||||
reverse = True
|
||||
name = order[1]
|
||||
rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse)
|
||||
if name != "owner":
|
||||
rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse)
|
||||
return rows[0] if rows else None
|
||||
|
||||
|
||||
class _DB:
|
||||
def __init__(self, rows):
|
||||
self.query_obj = _EndpointQuery(rows)
|
||||
self.closed = False
|
||||
|
||||
def query(self, model):
|
||||
assert model is _ModelEndpoint
|
||||
return self.query_obj
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _ChatSession:
|
||||
def __init__(self, endpoint_url, model):
|
||||
self.endpoint_url = endpoint_url
|
||||
self.model = model
|
||||
self.headers = {}
|
||||
self.history = []
|
||||
|
||||
def add_message(self, message):
|
||||
self.history.append(message)
|
||||
|
||||
|
||||
class _SessionManager:
|
||||
def __init__(self):
|
||||
self.created = []
|
||||
self.save_calls = 0
|
||||
|
||||
def create_session(self, *, session_id, name, endpoint_url, model, owner):
|
||||
session = _ChatSession(endpoint_url, model)
|
||||
self.created.append({
|
||||
"session_id": session_id,
|
||||
"name": name,
|
||||
"endpoint_url": endpoint_url,
|
||||
"model": model,
|
||||
"owner": owner,
|
||||
"session": session,
|
||||
})
|
||||
return session
|
||||
|
||||
def save_sessions(self):
|
||||
self.save_calls += 1
|
||||
|
||||
|
||||
class _Request:
|
||||
def __init__(self, *, owner="alice"):
|
||||
self.state = types.SimpleNamespace(
|
||||
api_token=True,
|
||||
api_token_scopes=["chat"],
|
||||
api_token_owner=owner,
|
||||
)
|
||||
|
||||
|
||||
class _WebhookManager:
|
||||
async def fire(self, event, payload):
|
||||
return None
|
||||
|
||||
|
||||
def _install_sync_chat_stubs(monkeypatch):
|
||||
# FastAPI checks for python_multipart at import time when Form is used;
|
||||
# stub it so the optional dependency is not required in the test environment.
|
||||
python_multipart = types.ModuleType("python_multipart")
|
||||
python_multipart.__version__ = "0.0.13"
|
||||
core_models = types.ModuleType("core.models")
|
||||
|
||||
class _ChatMessage:
|
||||
def __init__(self, role, content):
|
||||
self.role = role
|
||||
self.content = content
|
||||
|
||||
async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
|
||||
return "mocked response"
|
||||
|
||||
endpoint_resolver = types.ModuleType("src.endpoint_resolver")
|
||||
endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/")
|
||||
endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions"
|
||||
endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models"
|
||||
endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
llm_core = types.ModuleType("src.llm_core")
|
||||
llm_core.llm_call_async = _llm_call_async
|
||||
core_models.ChatMessage = _ChatMessage
|
||||
|
||||
monkeypatch.setitem(sys.modules, "python_multipart", python_multipart)
|
||||
monkeypatch.setitem(sys.modules, "core.models", core_models)
|
||||
monkeypatch.setitem(sys.modules, "src.llm_core", llm_core)
|
||||
monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver)
|
||||
|
||||
|
||||
def _sync_chat_endpoint(webhook_routes, session_manager):
|
||||
router = webhook_routes.setup_webhook_routes(
|
||||
_WebhookManager(),
|
||||
auth_manager=None,
|
||||
session_manager=session_manager,
|
||||
)
|
||||
for route in router.routes:
|
||||
if route.path == "/api/v1/chat":
|
||||
return route.endpoint
|
||||
raise AssertionError("sync chat route not found")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
|
||||
"http://127.0.0.1:11434/v1",
|
||||
"http://localhost:11434/v1",
|
||||
"http://10.0.0.5/v1",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
api_key="test-key",
|
||||
base_url=base_url,
|
||||
model="test-model",
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
with pytest.raises(webhook_routes.HTTPException) as exc:
|
||||
await sync_chat(_Request(), body)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint"
|
||||
assert session_manager.created == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("93.184.216.34")],
|
||||
)
|
||||
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
api_key="test-key",
|
||||
base_url="https://api.example.com/v1",
|
||||
model="test-model",
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
response = await sync_chat(_Request(), body)
|
||||
|
||||
assert response["response"] == "mocked response"
|
||||
assert response["model"] == "test-model"
|
||||
assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions"
|
||||
|
||||
|
||||
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
rows = [
|
||||
_Endpoint(owner="alice", is_enabled=False, created_at=0),
|
||||
_Endpoint(owner="bob", created_at=0),
|
||||
_Endpoint(owner=None, created_at=1),
|
||||
_Endpoint(owner="alice", created_at=2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
|
||||
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice")
|
||||
|
||||
assert selected.owner == "alice"
|
||||
assert selected.is_enabled is True
|
||||
assert selected.created_at == 2
|
||||
|
||||
|
||||
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
rows = [
|
||||
_Endpoint(owner="alice", created_at=0),
|
||||
_Endpoint(owner=None, is_enabled=False, created_at=1),
|
||||
_Endpoint(owner=None, created_at=2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
|
||||
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None)
|
||||
|
||||
assert selected.owner is None
|
||||
assert selected.is_enabled is True
|
||||
assert selected.created_at == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
local_endpoint = _Endpoint(
|
||||
owner=None,
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="configured-key",
|
||||
)
|
||||
db = _DB([local_endpoint])
|
||||
calls = []
|
||||
|
||||
def _session_local():
|
||||
return db
|
||||
|
||||
def _validate_public_http_url(url, *, max_length=2048):
|
||||
calls.append(url)
|
||||
raise AssertionError("configured fallback endpoint should not be publicly validated")
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local)
|
||||
monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url)
|
||||
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
model="local-model",
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
response = await sync_chat(_Request(owner=None), body)
|
||||
|
||||
assert response["response"] == "mocked response"
|
||||
assert response["model"] == "local-model"
|
||||
assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
|
||||
assert calls == []
|
||||
147
tests/test_gpu_compose_standalone.py
Normal file
147
tests/test_gpu_compose_standalone.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Guards the standalone GPU compose files against drift.
|
||||
|
||||
Stack-management UIs (Portainer, Coolify, Dockhand, ...) often accept only a
|
||||
single compose file and do not honor COMPOSE_FILE or multiple ``-f`` overlays,
|
||||
so the repo ships standalone ``docker-compose.gpu-*.yml`` files that inline the
|
||||
GPU overlay. The base ``docker-compose.yml`` plus ``docker/gpu.*.yml`` overlays
|
||||
remain the source of truth; these tests assert each standalone file equals the
|
||||
base compose with only the matching overlay merged into the ``odysseus``
|
||||
service. No Docker / docker compose is required — everything is pure YAML.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
BASE = ROOT / "docker-compose.yml"
|
||||
NVIDIA_OVERLAY = ROOT / "docker" / "gpu.nvidia.yml"
|
||||
AMD_OVERLAY = ROOT / "docker" / "gpu.amd.yml"
|
||||
NVIDIA_STANDALONE = ROOT / "docker-compose.gpu-nvidia.yml"
|
||||
AMD_STANDALONE = ROOT / "docker-compose.gpu-amd.yml"
|
||||
|
||||
SERVICE = "odysseus"
|
||||
|
||||
|
||||
def _load(path: Path) -> dict:
|
||||
return yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _deep_merge(base: dict, overlay: dict) -> dict:
|
||||
"""Mirror docker compose overlay semantics for the keys these files use.
|
||||
|
||||
Mappings merge recursively; list-valued service fields are concatenated
|
||||
(compose appends override sequences such as ``environment`` rather than
|
||||
replacing them); scalars are overwritten. The overlays here only append to
|
||||
``environment`` and add otherwise-absent keys (``deploy``, ``devices``,
|
||||
``group_add``), so this keeps the expected merge explicit without invoking
|
||||
docker compose.
|
||||
"""
|
||||
result = copy.deepcopy(base)
|
||||
for key, value in overlay.items():
|
||||
if isinstance(value, dict) and isinstance(result.get(key), dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
elif isinstance(value, list) and isinstance(result.get(key), list):
|
||||
result[key] = copy.deepcopy(result[key]) + copy.deepcopy(value)
|
||||
else:
|
||||
result[key] = copy.deepcopy(value)
|
||||
return result
|
||||
|
||||
|
||||
def _merge_overlay_into_base(base: dict, overlay: dict) -> dict:
|
||||
"""Build the expected standalone config: base + overlay on odysseus only."""
|
||||
expected = copy.deepcopy(base)
|
||||
overlay_service = overlay["services"][SERVICE]
|
||||
expected["services"][SERVICE] = _deep_merge(
|
||||
expected["services"][SERVICE], overlay_service
|
||||
)
|
||||
return expected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def base():
|
||||
return _load(BASE)
|
||||
|
||||
|
||||
# --- Equivalence: standalone == base + overlay -----------------------------
|
||||
|
||||
|
||||
def test_nvidia_standalone_equals_base_plus_overlay(base):
|
||||
overlay = _load(NVIDIA_OVERLAY)
|
||||
standalone = _load(NVIDIA_STANDALONE)
|
||||
assert standalone == _merge_overlay_into_base(base, overlay)
|
||||
|
||||
|
||||
def test_amd_standalone_equals_base_plus_overlay(base):
|
||||
overlay = _load(AMD_OVERLAY)
|
||||
standalone = _load(AMD_STANDALONE)
|
||||
assert standalone == _merge_overlay_into_base(base, overlay)
|
||||
|
||||
|
||||
# --- Non-odysseus services and volumes untouched ---------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
|
||||
def test_non_odysseus_services_match_base(base, standalone_path):
|
||||
standalone = _load(standalone_path)
|
||||
for name, definition in base["services"].items():
|
||||
if name == SERVICE:
|
||||
continue
|
||||
assert standalone["services"][name] == definition
|
||||
assert set(standalone["services"]) == set(base["services"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
|
||||
def test_top_level_volumes_match_base(base, standalone_path):
|
||||
standalone = _load(standalone_path)
|
||||
assert standalone.get("volumes") == base.get("volumes")
|
||||
|
||||
|
||||
# --- odysseus = base service + only the overlay additions ------------------
|
||||
|
||||
|
||||
def test_nvidia_odysseus_adds_only_overlay(base):
|
||||
standalone = _load(NVIDIA_STANDALONE)
|
||||
svc = standalone["services"][SERVICE]
|
||||
base_svc = base["services"][SERVICE]
|
||||
|
||||
# Base environment preserved, plus exactly the two NVIDIA variables.
|
||||
assert "NVIDIA_VISIBLE_DEVICES=all" in svc["environment"]
|
||||
assert "NVIDIA_DRIVER_CAPABILITIES=compute,utility" in svc["environment"]
|
||||
added_env = set(svc["environment"]) - set(base_svc["environment"])
|
||||
assert added_env == {
|
||||
"NVIDIA_VISIBLE_DEVICES=all",
|
||||
"NVIDIA_DRIVER_CAPABILITIES=compute,utility",
|
||||
}
|
||||
|
||||
# deploy block is new and matches the overlay's GPU reservation exactly.
|
||||
assert "deploy" not in base_svc
|
||||
devices = svc["deploy"]["resources"]["reservations"]["devices"]
|
||||
assert devices == [
|
||||
{"driver": "nvidia", "count": "all", "capabilities": ["gpu"]}
|
||||
]
|
||||
|
||||
# No AMD-only keys leaked in.
|
||||
assert "devices" not in svc
|
||||
assert "group_add" not in svc
|
||||
|
||||
|
||||
def test_amd_odysseus_adds_only_overlay(base):
|
||||
standalone = _load(AMD_STANDALONE)
|
||||
svc = standalone["services"][SERVICE]
|
||||
base_svc = base["services"][SERVICE]
|
||||
|
||||
# Environment is unchanged from base for AMD.
|
||||
assert svc["environment"] == base_svc["environment"]
|
||||
|
||||
# devices and group_add are new and match the overlay exactly.
|
||||
assert "devices" not in base_svc
|
||||
assert "group_add" not in base_svc
|
||||
assert svc["devices"] == ["/dev/kfd", "/dev/dri"]
|
||||
assert svc["group_add"] == ["video", "${RENDER_GID:-render}"]
|
||||
|
||||
# No NVIDIA-only keys leaked in.
|
||||
assert "deploy" not in svc
|
||||
@@ -66,3 +66,37 @@ def test_normal_model_payload_keeps_temperature(monkeypatch):
|
||||
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2)
|
||||
assert payload["temperature"] == 0.2
|
||||
assert payload["max_tokens"] == 5
|
||||
|
||||
|
||||
def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
|
||||
# OpenAI/local providers may validly use temperatures above 1.0; the clamp
|
||||
# is Anthropic-only and must not touch this path.
|
||||
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 1.2)
|
||||
assert payload["temperature"] == 1.2
|
||||
|
||||
|
||||
def _anthropic_payload(temperature):
|
||||
return llm_core._build_anthropic_payload(
|
||||
"claude-3-5-sonnet",
|
||||
[{"role": "user", "content": "Hi"}],
|
||||
temperature,
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_payload_clamps_above_one():
|
||||
# Anthropic rejects temperature > 1.0 (e.g. the Nietzsche preset's 1.2).
|
||||
assert _anthropic_payload(1.2)["temperature"] == 1.0
|
||||
|
||||
|
||||
def test_anthropic_payload_keeps_in_range():
|
||||
assert _anthropic_payload(0.7)["temperature"] == 0.7
|
||||
|
||||
|
||||
def test_anthropic_payload_clamps_negative():
|
||||
assert _anthropic_payload(-0.5)["temperature"] == 0.0
|
||||
|
||||
|
||||
def test_anthropic_payload_none_temperature_does_not_crash():
|
||||
payload = _anthropic_payload(None)
|
||||
assert payload["temperature"] is None
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Tests for model route helper functions — pure logic, no server needed."""
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
_visible_models,
|
||||
_normalize_model_ids,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_probe_endpoint,
|
||||
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
|
||||
|
||||
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
|
||||
assert model_routes._docker_host_gateway_reachable() is False
|
||||
|
||||
|
||||
# ── pinned model IDs: normalization helper ──
|
||||
|
||||
|
||||
class TestNormalizeModelIds:
|
||||
def test_list_passthrough_trims_and_dedupes(self):
|
||||
assert _normalize_model_ids([" a ", "a", "b", ""]) == ["a", "b"]
|
||||
|
||||
def test_json_string_list(self):
|
||||
assert _normalize_model_ids('["x", "y", "x"]') == ["x", "y"]
|
||||
|
||||
def test_comma_and_newline_string(self):
|
||||
assert _normalize_model_ids("a, b\n c ,a") == ["a", "b", "c"]
|
||||
|
||||
def test_none_and_empty(self):
|
||||
assert _normalize_model_ids(None) == []
|
||||
assert _normalize_model_ids("") == []
|
||||
assert _normalize_model_ids(" ") == []
|
||||
|
||||
def test_non_string_values_ignored(self):
|
||||
assert _normalize_model_ids([1, "ok", None, {"a": 1}]) == ["ok"]
|
||||
|
||||
|
||||
# ── pinned model IDs: _visible_models merge ──
|
||||
|
||||
|
||||
class TestVisibleModelsPinned:
|
||||
def test_includes_pinned_not_in_cached(self):
|
||||
visible = _visible_models(["a"], None, ["deploy-1"])
|
||||
assert visible == ["a", "deploy-1"]
|
||||
|
||||
def test_cached_plus_pinned_dedup_preserves_order(self):
|
||||
visible = _visible_models(["a", "b"], None, ["b", "c"])
|
||||
assert visible == ["a", "b", "c"]
|
||||
|
||||
def test_hidden_can_hide_a_pinned_model(self):
|
||||
visible = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
|
||||
assert visible == ["a"]
|
||||
|
||||
def test_accepts_json_string_inputs(self):
|
||||
visible = _visible_models('["a"]', '["a"]', '["b"]')
|
||||
assert visible == ["b"]
|
||||
|
||||
|
||||
# ── pinned model IDs: route behaviour ──
|
||||
|
||||
# Building the router exercises FastAPI's Form() routes, which require
|
||||
# python-multipart. The test env ships without it, so register a minimal stub
|
||||
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
|
||||
if "python_multipart" not in sys.modules:
|
||||
try:
|
||||
import python_multipart # noqa: F401
|
||||
except ImportError:
|
||||
_mp_stub = types.ModuleType("python_multipart")
|
||||
_mp_stub.__version__ = "0.0.13"
|
||||
sys.modules["python_multipart"] = _mp_stub
|
||||
|
||||
|
||||
class _PinnedFakeQuery:
|
||||
def __init__(self, rows):
|
||||
self.rows = list(rows)
|
||||
|
||||
def filter(self, *conditions):
|
||||
return self
|
||||
|
||||
def order_by(self, *args):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self.rows[0] if self.rows else None
|
||||
|
||||
def all(self):
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _PinnedFakeDb:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.added = []
|
||||
self.committed = 0
|
||||
|
||||
def query(self, model):
|
||||
return _PinnedFakeQuery(self.rows)
|
||||
|
||||
def add(self, row):
|
||||
self.added.append(row)
|
||||
|
||||
def commit(self):
|
||||
self.committed += 1
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class _FakeCol:
|
||||
"""Column stand-in: every comparison/operator just returns itself so the
|
||||
dedupe query expressions evaluate without a real SQLAlchemy column."""
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __eq__(self, other):
|
||||
return self
|
||||
|
||||
def is_(self, other):
|
||||
return self
|
||||
|
||||
def __or__(self, other):
|
||||
return self
|
||||
|
||||
def desc(self):
|
||||
return self
|
||||
|
||||
|
||||
class _RecordingEndpoint:
|
||||
"""ModelEndpoint stand-in that stores constructor kwargs as attributes.
|
||||
|
||||
Class-level fake columns let it double as the query class in the dedupe
|
||||
lookup; instance attributes (set in __init__) shadow them per-row.
|
||||
"""
|
||||
|
||||
id = _FakeCol()
|
||||
base_url = _FakeCol()
|
||||
owner = _FakeCol()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _PinnedFakeRequest:
|
||||
def __init__(self, body=None, headers=None):
|
||||
self._body = body if body is not None else {}
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self):
|
||||
return self._body
|
||||
|
||||
|
||||
def _get_route(path, method):
|
||||
from routes.model_routes import setup_model_routes
|
||||
router = setup_model_routes(model_discovery=None)
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError(f"{method} {path} not found")
|
||||
|
||||
|
||||
def _make_endpoint(**kwargs):
|
||||
base = dict(
|
||||
id="ep1",
|
||||
name="EP",
|
||||
base_url="http://localhost:9999/v1",
|
||||
api_key=None,
|
||||
is_enabled=True,
|
||||
hidden_models=None,
|
||||
cached_models=None,
|
||||
pinned_models=None,
|
||||
model_type="llm",
|
||||
supports_tools=None,
|
||||
)
|
||||
base.update(kwargs)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def test_patch_models_saves_pinned_models(monkeypatch):
|
||||
ep = _make_endpoint()
|
||||
db = _PinnedFakeDb([ep])
|
||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
|
||||
|
||||
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]})
|
||||
result = asyncio.run(endpoint("ep1", request))
|
||||
|
||||
assert json.loads(ep.pinned_models) == ["deploy-1", "deploy-2"]
|
||||
assert result["pinned_count"] == 2
|
||||
|
||||
|
||||
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
|
||||
ep = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
|
||||
db = _PinnedFakeDb([ep])
|
||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
|
||||
|
||||
request = _PinnedFakeRequest(body={"pinned_models": ["deploy-1"]})
|
||||
asyncio.run(endpoint("ep1", request))
|
||||
|
||||
assert json.loads(ep.hidden_models) == ["hide-me"]
|
||||
assert json.loads(ep.pinned_models) == ["deploy-1"]
|
||||
|
||||
|
||||
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
|
||||
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
|
||||
db = _PinnedFakeDb([ep])
|
||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
|
||||
endpoint = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
|
||||
|
||||
result = endpoint("ep1", _PinnedFakeRequest())
|
||||
|
||||
ids = [row["id"] for row in result]
|
||||
assert ids == ["deploy-1"]
|
||||
assert result[0]["is_pinned"] is True
|
||||
|
||||
|
||||
def test_reprobe_preserves_pinned_models(monkeypatch):
|
||||
ep = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
|
||||
db = _PinnedFakeDb([ep])
|
||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
|
||||
monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
|
||||
monkeypatch.setattr(
|
||||
model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
|
||||
)
|
||||
endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
|
||||
|
||||
response = endpoint("ep1", _PinnedFakeRequest())
|
||||
|
||||
async def _drain():
|
||||
async for _ in response.body_iterator:
|
||||
pass
|
||||
|
||||
asyncio.run(_drain())
|
||||
|
||||
# Probe rewrites cached/hidden but must never touch admin-pinned IDs.
|
||||
assert json.loads(ep.pinned_models) == ["deploy-1"]
|
||||
assert json.loads(ep.cached_models) == ["m1"]
|
||||
|
||||
|
||||
def test_visible_models_handles_malformed_strings():
|
||||
# Non-JSON cached/pinned strings are treated as comma/newline lists and
|
||||
# never raise; a malformed hidden string is normalized too.
|
||||
result = _visible_models("a,b", "b", "{bad json")
|
||||
assert isinstance(result, list)
|
||||
assert result == ["a", "{bad json"]
|
||||
assert _visible_models("", None, "") == []
|
||||
assert _visible_models("only-cached", None, None) == ["only-cached"]
|
||||
|
||||
|
||||
def _create_form_kwargs(**overrides):
|
||||
"""Defaults for every Form() param create_model_endpoint reads directly.
|
||||
|
||||
Calling the route as a plain function bypasses FastAPI form parsing, so the
|
||||
Form() sentinels must be replaced with real strings.
|
||||
"""
|
||||
kwargs = dict(
|
||||
name="",
|
||||
api_key="",
|
||||
skip_probe="true", # avoid any network probe in unit tests
|
||||
require_models="false",
|
||||
model_type="llm",
|
||||
supports_tools="",
|
||||
pinned_models="",
|
||||
container_local="false",
|
||||
shared="true",
|
||||
)
|
||||
kwargs.update(overrides)
|
||||
return kwargs
|
||||
|
||||
|
||||
def _patch_create_deps(monkeypatch, db):
|
||||
import src.auth_helpers as auth_helpers
|
||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||
monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint)
|
||||
monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b)
|
||||
monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b)
|
||||
monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"})
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
|
||||
monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
|
||||
|
||||
|
||||
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
|
||||
db = _PinnedFakeDb([]) # no existing row → fresh create path
|
||||
_patch_create_deps(monkeypatch, db)
|
||||
create = _get_route("/api/model-endpoints", "POST")
|
||||
|
||||
result = create(
|
||||
_PinnedFakeRequest(),
|
||||
base_url="http://host:1234/v1",
|
||||
**_create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2"),
|
||||
)
|
||||
|
||||
assert result["pinned_models"] == ["deploy-1", "deploy-2"]
|
||||
assert result["models"] == ["deploy-1", "deploy-2"]
|
||||
assert result["online"] is True
|
||||
# Persisted onto the created row.
|
||||
assert len(db.added) == 1
|
||||
assert json.loads(db.added[0].pinned_models) == ["deploy-1", "deploy-2"]
|
||||
|
||||
|
||||
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
|
||||
existing = _make_endpoint(
|
||||
cached_models=json.dumps(["m1"]),
|
||||
hidden_models=None,
|
||||
pinned_models=json.dumps(["old-pin"]),
|
||||
)
|
||||
db = _PinnedFakeDb([existing])
|
||||
_patch_create_deps(monkeypatch, db)
|
||||
create = _get_route("/api/model-endpoints", "POST")
|
||||
|
||||
result = create(
|
||||
_PinnedFakeRequest(),
|
||||
base_url="http://host:1234/v1",
|
||||
**_create_form_kwargs(pinned_models="new-pin"),
|
||||
)
|
||||
|
||||
assert result["existing"] is True
|
||||
# Incoming pin merged onto the existing pins (no clobber, order preserved).
|
||||
assert json.loads(existing.pinned_models) == ["old-pin", "new-pin"]
|
||||
assert result["pinned_models"] == ["old-pin", "new-pin"]
|
||||
# models = cached + pinned - hidden, visible merged list.
|
||||
assert result["models"] == ["m1", "old-pin", "new-pin"]
|
||||
# No new row created on the dedupe path.
|
||||
assert db.added == []
|
||||
|
||||
|
||||
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
|
||||
existing = _make_endpoint(
|
||||
cached_models=json.dumps(["m1"]),
|
||||
pinned_models=json.dumps(["keep-me"]),
|
||||
)
|
||||
db = _PinnedFakeDb([existing])
|
||||
_patch_create_deps(monkeypatch, db)
|
||||
create = _get_route("/api/model-endpoints", "POST")
|
||||
|
||||
result = create(
|
||||
_PinnedFakeRequest(),
|
||||
base_url="http://host:1234/v1",
|
||||
**_create_form_kwargs(), # pinned_models defaults to ""
|
||||
)
|
||||
|
||||
assert json.loads(existing.pinned_models) == ["keep-me"]
|
||||
assert result["pinned_models"] == ["keep-me"]
|
||||
assert db.committed == 0 # nothing to persist
|
||||
|
||||
@@ -247,10 +247,14 @@ class _Column:
|
||||
def __eq__(self, value):
|
||||
return _Predicate(lambda row: getattr(row, self.name) == value)
|
||||
|
||||
def desc(self):
|
||||
return self
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
owner = _Column("owner")
|
||||
created_at = _Column("created_at")
|
||||
|
||||
|
||||
class _Query:
|
||||
@@ -261,6 +265,9 @@ class _Query:
|
||||
self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
|
||||
return self
|
||||
|
||||
def order_by(self, *exprs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
@@ -280,8 +287,10 @@ def _ep(name, owner, *, is_enabled=True):
|
||||
|
||||
def _select(rows, owner):
|
||||
wh_mod = _import_webhook_helper()
|
||||
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._first_enabled_endpoint(_DB(rows), owner)
|
||||
# _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
|
||||
# (not a local import), so we patch the module attribute directly.
|
||||
wh_mod.ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
|
||||
|
||||
|
||||
def test_sync_chat_fallback_never_picks_another_owners_endpoint():
|
||||
@@ -310,9 +319,15 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop():
|
||||
# An unresolvable/empty token owner keeps the original single-user behaviour
|
||||
# (owner_filter no-op): first enabled row, whatever it is.
|
||||
rows = [_ep("first", "bob"), _ep("second", "alice")]
|
||||
def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
|
||||
# When no token owner is known, only null-owner (shared) endpoints are
|
||||
# visible — private endpoints of any user must not be returned.
|
||||
rows = [_ep("bob-private", "bob"), _ep("shared", None)]
|
||||
ep = _select(rows, None)
|
||||
assert ep is not None and ep.name == "first"
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
|
||||
# No shared rows → fail closed rather than returning another user's endpoint.
|
||||
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
|
||||
assert _select(rows, None) is None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.search.ranking import rank_search_results
|
||||
from services.search.ranking import rank_search_results
|
||||
|
||||
|
||||
def test_news_queries_prefer_news_sources_over_sports_and_social_results():
|
||||
|
||||
@@ -8,7 +8,8 @@ module-level, time-injectable function.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.search.ranking import recency_score, _utcnow_naive
|
||||
import services.search.ranking as live_ranking
|
||||
from services.search.ranking import recency_score, _utcnow_naive, rank_search_results
|
||||
|
||||
|
||||
def test_fresh_result_scores_one():
|
||||
@@ -37,3 +38,37 @@ def test_default_now_is_naive_utc():
|
||||
assert now.tzinfo is None
|
||||
reference = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
assert abs((now - reference).total_seconds()) < 5
|
||||
|
||||
|
||||
def test_supported_timestamp_formats_parse():
|
||||
# All three formats the current implementation supports resolve to the same
|
||||
# ~4-day-old age, so each scores a full 1.0.
|
||||
now = datetime(2026, 1, 5, 12, 0, 0)
|
||||
assert recency_score("2026-01-01", now=now) == 1.0
|
||||
assert recency_score("2026-01-01T08:30:00", now=now) == 1.0
|
||||
assert recency_score("2026-01-01 08:30:00", now=now) == 1.0
|
||||
|
||||
|
||||
def test_shim_reexports_live_objects():
|
||||
# src.search.ranking is a compatibility shim; it must expose the *same*
|
||||
# objects as the live services module so the two cannot diverge.
|
||||
import src.search.ranking as shim
|
||||
|
||||
assert shim.recency_score is live_ranking.recency_score
|
||||
assert shim.rank_search_results is live_ranking.rank_search_results
|
||||
assert shim._utcnow_naive is live_ranking._utcnow_naive
|
||||
|
||||
|
||||
def test_live_rank_path_prefers_newer_result(monkeypatch):
|
||||
# Pin "now" so age scoring is deterministic. The two results are identical
|
||||
# apart from age, isolating recency as the only differentiator.
|
||||
monkeypatch.setattr(live_ranking, "_utcnow_naive", lambda: datetime(2026, 1, 31))
|
||||
results = [
|
||||
{"title": "Report", "url": "https://example.org/a", "snippet": "x", "age": "2026-01-01"},
|
||||
{"title": "Report", "url": "https://example.org/b", "snippet": "x", "age": "2026-01-29"},
|
||||
]
|
||||
|
||||
ranked = rank_search_results("report", results)
|
||||
|
||||
assert ranked[0]["url"] == "https://example.org/b"
|
||||
assert ranked[1]["url"] == "https://example.org/a"
|
||||
|
||||
Reference in New Issue
Block a user