Revert "Merge branch 'main' of github.com:pewdiepie-archdaemon/odysseus"
This reverts commit 8161c1253d, reversing changes made to 8c2705b42a.
This commit is contained in:
@@ -1,401 +0,0 @@
|
||||
import ipaddress
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", [
    # IPv4 loopback, by address and by name.
    "http://127.0.0.1:8000/v1",
    "http://localhost:8000/v1",
    # Private IPv4 ranges.
    "http://10.0.0.5/v1",
    "http://172.16.0.1/v1",
    "http://192.168.1.2/v1",
    # Link-local metadata address and a metadata hostname.
    "http://169.254.169.254/latest/meta-data/",
    "http://metadata.google.internal/",
    # IPv6 loopback and unique-local.
    "http://[::1]:8000/v1",
    "http://[fc00::1]/v1",
    # Multicast and the unspecified address.
    "http://224.0.0.1/v1",
    "http://0.0.0.0/v1",
    # Non-HTTP scheme.
    "file:///etc/passwd",
])
def test_public_url_validator_blocks_internal_targets(url):
    """Every listed URL must be reported as non-public."""
    from src import url_security

    verdict = url_security.is_public_http_url(url)

    assert verdict is False
|
||||
|
||||
|
||||
def test_public_url_validator_allows_public_endpoint(monkeypatch):
    """A URL whose host resolves to a public address is returned unchanged."""
    from src import url_security

    public_url = "https://api.example.com/v1"

    def _resolve_to_public(host):
        # Resolution is patched so the test never performs a DNS lookup.
        return [ipaddress.ip_address("93.184.216.34")]

    monkeypatch.setattr(url_security, "_resolve_hostname_ips", _resolve_to_public)

    assert url_security.validate_public_http_url(public_url) == public_url
|
||||
|
||||
|
||||
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
    """A hostname that resolves to a private address must raise ``ValueError``."""
    from src import url_security

    def _resolve_to_private(host):
        # Patched resolver: the public-looking name maps to a 10.x address.
        return [ipaddress.ip_address("10.0.0.5")]

    monkeypatch.setattr(url_security, "_resolve_hostname_ips", _resolve_to_private)

    with pytest.raises(ValueError):
        url_security.validate_public_http_url("https://api.example.com/v1")
|
||||
|
||||
|
||||
def _load_webhook_routes_for_test(monkeypatch):
    """Execute ``routes/webhook_routes.py`` as a fresh module with stubbed deps.

    Stub modules for ``core``, ``core.database``, ``core.middleware`` and
    ``src.webhook_manager`` are placed in ``sys.modules`` through
    ``monkeypatch.setitem``. The routes file is then executed under the name
    ``routes.webhook_routes_under_test``; this helper does not add that module
    to ``sys.modules``, so each call returns a new module object rather than
    a cached one from a previous monkeypatch run.
    """

    def _register_stub(name, **attrs):
        stub = types.ModuleType(name)
        for attr, value in attrs.items():
            setattr(stub, attr, value)
        monkeypatch.setitem(sys.modules, name, stub)

    # An empty __path__ marks the stub "core" as a package.
    _register_stub("core", __path__=[])
    _register_stub(
        "core.database",
        SessionLocal=object,
        Webhook=object,
        ModelEndpoint=object,
    )
    _register_stub("core.middleware", require_admin=lambda request: None)
    _register_stub(
        "src.webhook_manager",
        WebhookManager=object,
        validate_webhook_url=lambda url: url,
        validate_events=lambda events: events,
    )

    routes_file = Path(__file__).resolve().parents[1] / "routes" / "webhook_routes.py"
    spec = importlib.util.spec_from_file_location(
        "routes.webhook_routes_under_test",
        routes_file,
    )
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
||||
|
||||
|
||||
class _Expr:
    """Callable row predicate that can be combined with ``|``."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, row):
        """Evaluate the wrapped callable against ``row``."""
        return self.fn(row)

    def __or__(self, other):
        """Return a new ``_Expr`` yielding ``self(row) or other(row)``."""

        def _either(row):
            return self(row) or other(row)

        return _Expr(_either)
|
||||
|
||||
|
||||
class _Column:
    """Named column stand-in: ``==`` builds a row predicate, ``desc`` a sort key."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        """Return an ``_Expr`` testing whether the row's attribute equals ``other``."""

        def _matches(row):
            return getattr(row, self.name) == other

        return _Expr(_matches)

    def desc(self):
        """Return the ``("desc", name)`` marker for this column."""
        return "desc", self.name
|
||||
|
||||
|
||||
class _ModelEndpoint:
    """Query-class stand-in exposing the columns used by the fake query objects."""

    # Used in sort keys.
    created_at = _Column("created_at")
    # Used in filter expressions.
    owner = _Column("owner")
    is_enabled = _Column("is_enabled")
|
||||
|
||||
|
||||
class _Endpoint:
    """Plain row object with the endpoint attributes the tests read.

    All constructor arguments are keyword-only; only ``owner`` is required.
    """

    def __init__(
        self,
        *,
        owner,
        is_enabled=True,
        created_at=1,
        base_url="https://api.example.com/v1",
        api_key=None,
    ):
        self.owner, self.is_enabled, self.created_at = owner, is_enabled, created_at
        self.base_url, self.api_key = base_url, api_key
|
||||
|
||||
|
||||
class _EndpointQuery:
    """In-memory stand-in for the query object the route code chains on.

    ``filter`` and ``order_by`` only record their arguments; ``first`` applies
    them to the stored rows. Filters are callables that take a row. Sort keys
    are either ``("desc", name)`` tuples (descending) or objects exposing a
    ``name`` attribute (ascending).
    """

    def __init__(self, rows):
        self.rows = rows
        self.filters = []
        self.orders = []

    def filter(self, *exprs):
        """Record row predicates and return ``self`` for chaining."""
        self.filters.extend(exprs)
        return self

    def order_by(self, *exprs):
        """Record sort keys and return ``self`` for chaining."""
        self.orders.extend(exprs)
        return self

    def first(self):
        """Return the first row after filtering and sorting, or ``None``.

        Rows whose sort attribute is ``None`` are ordered by presence first
        (``None`` sorts before any value when ascending, after when
        descending), so mixed ``None`` / value columns no longer raise
        ``TypeError`` during comparison.
        """
        rows = [
            row for row in self.rows
            if all(expr(row) for expr in self.filters)
        ]
        # Apply sort keys right-to-left so the leftmost key ends up as the
        # primary sort (stable-sort idiom for multi-column ordering).
        for order in reversed(self.orders):
            descending = isinstance(order, tuple) and order[0] == "desc"
            name = order[1] if descending else getattr(order, "name", None)

            def _key(row, _name=name):
                value = getattr(row, _name)
                present = value is not None
                # "owner" is ranked purely by presence; every other column is
                # ranked by presence, then by value. Tuples with a differing
                # first element never compare the second, which keeps None
                # from being compared against real values.
                return present if _name == "owner" else (present, value)

            rows = sorted(rows, key=_key, reverse=descending)
        return rows[0] if rows else None
|
||||
|
||||
|
||||
class _DB:
    """Session stand-in that hands out one shared ``_EndpointQuery``."""

    def __init__(self, rows):
        self.closed = False
        self.query_obj = _EndpointQuery(rows)

    def query(self, model):
        """Return the shared query object; only ``_ModelEndpoint`` is accepted."""
        assert model is _ModelEndpoint
        return self.query_obj

    def close(self):
        """Record that the session was closed."""
        self.closed = True
|
||||
|
||||
|
||||
class _ChatSession:
    """Minimal chat session: endpoint, model, headers dict and message history."""

    def __init__(self, endpoint_url, model):
        self.endpoint_url, self.model = endpoint_url, model
        self.headers, self.history = {}, []

    def add_message(self, message):
        """Append ``message`` to ``history``."""
        self.history.append(message)
|
||||
|
||||
|
||||
class _SessionManager:
    """Session manager fake that records created sessions and save calls."""

    def __init__(self):
        self.save_calls = 0
        self.created = []

    def create_session(self, *, session_id, name, endpoint_url, model, owner):
        """Create a ``_ChatSession``, log the call's arguments, and return it."""
        session = _ChatSession(endpoint_url, model)
        record = dict(
            session_id=session_id,
            name=name,
            endpoint_url=endpoint_url,
            model=model,
            owner=owner,
            session=session,
        )
        self.created.append(record)
        return session

    def save_sessions(self):
        """Count the call; nothing is persisted."""
        self.save_calls += 1
|
||||
|
||||
|
||||
class _Request:
    """Request stand-in whose ``state`` describes a chat-scoped API token."""

    def __init__(self, *, owner="alice"):
        state = types.SimpleNamespace()
        state.api_token = True
        state.api_token_scopes = ["chat"]
        state.api_token_owner = owner
        self.state = state
|
||||
|
||||
|
||||
class _WebhookManager:
    """Webhook manager fake whose ``fire`` coroutine does nothing."""

    async def fire(self, event, payload):
        """Accept an event and payload; always resolve to ``None``."""
        return None
|
||||
|
||||
|
||||
def _install_sync_chat_stubs(monkeypatch):
    """Register stub modules the sync chat route needs, via ``monkeypatch``.

    Installs ``python_multipart``, ``core.models``, ``src.llm_core`` and
    ``src.endpoint_resolver`` in ``sys.modules``. The LLM stub always answers
    ``"mocked response"``.
    """

    class _ChatMessage:
        def __init__(self, role, content):
            self.role = role
            self.content = content

    async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
        return "mocked response"

    def _normalize_base(url):
        return (url or "").strip().rstrip("/")

    def _build_chat_url(base_url):
        return f"{base_url}/chat/completions"

    def _build_models_url(base_url):
        return f"{base_url}/models"

    def _build_headers(api_key, base_url):
        return {"Authorization": f"Bearer {api_key}"}

    # FastAPI checks for python_multipart at import time when Form is used;
    # stub it so the optional dependency is not required in the test environment.
    multipart_stub = types.ModuleType("python_multipart")
    multipart_stub.__version__ = "0.0.13"

    models_stub = types.ModuleType("core.models")
    models_stub.ChatMessage = _ChatMessage

    llm_stub = types.ModuleType("src.llm_core")
    llm_stub.llm_call_async = _llm_call_async

    resolver_stub = types.ModuleType("src.endpoint_resolver")
    resolver_stub.normalize_base = _normalize_base
    resolver_stub.build_chat_url = _build_chat_url
    resolver_stub.build_models_url = _build_models_url
    resolver_stub.build_headers = _build_headers

    for stub in (multipart_stub, models_stub, llm_stub, resolver_stub):
        monkeypatch.setitem(sys.modules, stub.__name__, stub)
|
||||
|
||||
|
||||
def _sync_chat_endpoint(webhook_routes, session_manager):
    """Build the webhook router and return the handler for ``/api/v1/chat``.

    Raises ``AssertionError`` when no route has that path.
    """
    router = webhook_routes.setup_webhook_routes(
        _WebhookManager(),
        auth_manager=None,
        session_manager=session_manager,
    )
    try:
        return next(
            route.endpoint
            for route in router.routes
            if route.path == "/api/v1/chat"
        )
    except StopIteration:
        raise AssertionError("sync chat route not found") from None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
    "http://127.0.0.1:11434/v1",
    "http://localhost:11434/v1",
    "http://10.0.0.5/v1",
    "http://169.254.169.254/latest/meta-data/",
])
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
    """A caller-supplied local/private ``base_url`` yields HTTP 400 and no session."""
    webhook_routes = _load_webhook_routes_for_test(monkeypatch)
    _install_sync_chat_stubs(monkeypatch)
    sessions = _SessionManager()
    sync_chat = _sync_chat_endpoint(webhook_routes, sessions)

    request_body = types.SimpleNamespace(**{
        "message": "hello",
        "api_key": "test-key",
        "base_url": base_url,
        "model": "test-model",
        "provider": None,
        "session": None,
    })

    with pytest.raises(webhook_routes.HTTPException) as rejection:
        await sync_chat(_Request(), request_body)

    error = rejection.value
    assert error.status_code == 400
    assert error.detail == "base_url must point to a public HTTP(S) endpoint"
    # The request must be rejected before any session is created.
    assert sessions.created == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
    """A ``base_url`` resolving to a public address is accepted and used."""
    webhook_routes = _load_webhook_routes_for_test(monkeypatch)
    _install_sync_chat_stubs(monkeypatch)

    from src import url_security

    def _resolve_to_public(host):
        # Patched resolver so no DNS lookup happens.
        return [ipaddress.ip_address("93.184.216.34")]

    monkeypatch.setattr(url_security, "_resolve_hostname_ips", _resolve_to_public)

    sessions = _SessionManager()
    sync_chat = _sync_chat_endpoint(webhook_routes, sessions)
    request_body = types.SimpleNamespace(**{
        "message": "hello",
        "api_key": "test-key",
        "base_url": "https://api.example.com/v1",
        "model": "test-model",
        "provider": None,
        "session": None,
    })

    reply = await sync_chat(_Request(), request_body)

    assert reply["response"] == "mocked response"
    assert reply["model"] == "test-model"
    first_session = sessions.created[0]
    assert first_session["endpoint_url"] == "https://api.example.com/v1/chat/completions"
|
||||
|
||||
|
||||
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
    """With owner ``"alice"``, her enabled endpoint is selected."""
    webhook_routes = _load_webhook_routes_for_test(monkeypatch)
    monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)

    candidates = [
        _Endpoint(owner="alice", is_enabled=False, created_at=0),
        _Endpoint(owner="bob", created_at=0),
        _Endpoint(owner=None, created_at=1),
        _Endpoint(owner="alice", created_at=2),
    ]

    chosen = webhook_routes._select_api_chat_fallback_endpoint(_DB(candidates), "alice")

    assert chosen.owner == "alice"
    assert chosen.is_enabled is True
    assert chosen.created_at == 2
|
||||
|
||||
|
||||
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
    """With no owner, the enabled endpoint whose owner is ``None`` is selected."""
    webhook_routes = _load_webhook_routes_for_test(monkeypatch)
    monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)

    candidates = [
        _Endpoint(owner="alice", created_at=0),
        _Endpoint(owner=None, is_enabled=False, created_at=1),
        _Endpoint(owner=None, created_at=2),
    ]

    chosen = webhook_routes._select_api_chat_fallback_endpoint(_DB(candidates), None)

    assert chosen.owner is None
    assert chosen.is_enabled is True
    assert chosen.created_at == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
    """A configured localhost endpoint is used without public-URL validation."""
    webhook_routes = _load_webhook_routes_for_test(monkeypatch)
    _install_sync_chat_stubs(monkeypatch)

    configured = _Endpoint(
        owner=None,
        base_url="http://localhost:11434/v1",
        api_key="configured-key",
    )
    fake_db = _DB([configured])
    validated_urls = []

    def _reject_validation(url, *, max_length=2048):
        # Any call here means the fallback path tried to validate the URL.
        validated_urls.append(url)
        raise AssertionError("configured fallback endpoint should not be publicly validated")

    monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
    monkeypatch.setattr(webhook_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(webhook_routes, "validate_public_http_url", _reject_validation)

    sessions = _SessionManager()
    sync_chat = _sync_chat_endpoint(webhook_routes, sessions)
    request_body = types.SimpleNamespace(**{
        "message": "hello",
        "model": "local-model",
        "api_key": None,
        "base_url": None,
        "provider": None,
        "session": None,
    })

    reply = await sync_chat(_Request(owner=None), request_body)

    assert reply["response"] == "mocked response"
    assert reply["model"] == "local-model"
    first_session = sessions.created[0]
    assert first_session["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
    assert validated_urls == []
|
||||
@@ -1,147 +0,0 @@
|
||||
"""Guards the standalone GPU compose files against drift.
|
||||
|
||||
Stack-management UIs (Portainer, Coolify, Dockhand, ...) often accept only a
|
||||
single compose file and do not honor COMPOSE_FILE or multiple ``-f`` overlays,
|
||||
so the repo ships standalone ``docker-compose.gpu-*.yml`` files that inline the
|
||||
GPU overlay. The base ``docker-compose.yml`` plus ``docker/gpu.*.yml`` overlays
|
||||
remain the source of truth; these tests assert each standalone file equals the
|
||||
base compose with only the matching overlay merged into the ``odysseus``
|
||||
service. No Docker / docker compose is required — everything is pure YAML.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
# Directory one level above this file's directory; every compose path below
# is resolved relative to it.
ROOT = Path(__file__).resolve().parent.parent

# Base compose file and the per-vendor GPU overlays.
BASE = ROOT.joinpath("docker-compose.yml")
NVIDIA_OVERLAY = ROOT.joinpath("docker", "gpu.nvidia.yml")
AMD_OVERLAY = ROOT.joinpath("docker", "gpu.amd.yml")

# Standalone single-file variants compared against base + overlay.
NVIDIA_STANDALONE = ROOT.joinpath("docker-compose.gpu-nvidia.yml")
AMD_STANDALONE = ROOT.joinpath("docker-compose.gpu-amd.yml")

# Name of the only service the overlays are merged into.
SERVICE = "odysseus"
|
||||
|
||||
|
||||
def _load(path: Path) -> dict:
    """Parse the UTF-8 YAML file at ``path`` with ``yaml.safe_load``."""
    with path.open(encoding="utf-8") as handle:
        return yaml.safe_load(handle)
|
||||
|
||||
|
||||
def _deep_merge(base: dict, overlay: dict) -> dict:
    """Return a new dict: ``base`` with ``overlay`` merged in compose-style.

    For each overlay key: two mappings merge recursively, two lists are
    concatenated (base items first), and anything else is overwritten by the
    overlay value. Neither input is modified and the result shares no
    mutable objects with them.
    """
    merged = copy.deepcopy(base)
    for key, incoming in overlay.items():
        current = merged.get(key)
        if isinstance(incoming, dict) and isinstance(current, dict):
            merged[key] = _deep_merge(current, incoming)
            continue
        if isinstance(incoming, list) and isinstance(current, list):
            merged[key] = copy.deepcopy(current) + copy.deepcopy(incoming)
            continue
        # Scalars, type mismatches and keys absent from base: overlay wins.
        merged[key] = copy.deepcopy(incoming)
    return merged
|
||||
|
||||
|
||||
def _merge_overlay_into_base(base: dict, overlay: dict) -> dict:
    """Return a copy of ``base`` with the overlay's ``SERVICE`` entry merged in.

    Only ``services[SERVICE]`` is affected; the rest of ``base`` is copied
    as-is.
    """
    expected = copy.deepcopy(base)
    services = expected["services"]
    services[SERVICE] = _deep_merge(services[SERVICE], overlay["services"][SERVICE])
    return expected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def base():
    """Parsed base compose file, loaded once per test module."""
    parsed = _load(BASE)
    return parsed
|
||||
|
||||
|
||||
# --- Equivalence: standalone == base + overlay -----------------------------
|
||||
|
||||
|
||||
def test_nvidia_standalone_equals_base_plus_overlay(base):
    """The NVIDIA standalone file equals base merged with the NVIDIA overlay."""
    expected = _merge_overlay_into_base(base, _load(NVIDIA_OVERLAY))
    actual = _load(NVIDIA_STANDALONE)
    assert actual == expected
|
||||
|
||||
|
||||
def test_amd_standalone_equals_base_plus_overlay(base):
    """The AMD standalone file equals base merged with the AMD overlay."""
    expected = _merge_overlay_into_base(base, _load(AMD_OVERLAY))
    actual = _load(AMD_STANDALONE)
    assert actual == expected
|
||||
|
||||
|
||||
# --- Non-odysseus services and volumes untouched ---------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_non_odysseus_services_match_base(base, standalone_path):
    """Every service other than ``SERVICE`` is identical to base; no extras."""
    standalone_services = _load(standalone_path)["services"]
    base_services = base["services"]
    for name, definition in base_services.items():
        if name != SERVICE:
            assert standalone_services[name] == definition
    assert set(standalone_services) == set(base_services)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("standalone_path", [NVIDIA_STANDALONE, AMD_STANDALONE])
def test_top_level_volumes_match_base(base, standalone_path):
    """The top-level ``volumes`` mapping is identical to base."""
    standalone_volumes = _load(standalone_path).get("volumes")
    assert standalone_volumes == base.get("volumes")
|
||||
|
||||
|
||||
# --- odysseus = base service + only the overlay additions ------------------
|
||||
|
||||
|
||||
def test_nvidia_odysseus_adds_only_overlay(base):
    """The NVIDIA standalone service adds only the NVIDIA overlay content."""
    svc = _load(NVIDIA_STANDALONE)["services"][SERVICE]
    base_svc = base["services"][SERVICE]

    # Base environment preserved, plus exactly the two NVIDIA variables.
    nvidia_env = (
        "NVIDIA_VISIBLE_DEVICES=all",
        "NVIDIA_DRIVER_CAPABILITIES=compute,utility",
    )
    for entry in nvidia_env:
        assert entry in svc["environment"]
    extra_env = set(svc["environment"]) - set(base_svc["environment"])
    assert extra_env == set(nvidia_env)

    # deploy block is new and matches the overlay's GPU reservation exactly.
    assert "deploy" not in base_svc
    reserved = svc["deploy"]["resources"]["reservations"]["devices"]
    assert reserved == [
        {"driver": "nvidia", "count": "all", "capabilities": ["gpu"]}
    ]

    # No AMD-only keys leaked in.
    for amd_key in ("devices", "group_add"):
        assert amd_key not in svc
|
||||
|
||||
|
||||
def test_amd_odysseus_adds_only_overlay(base):
    """The AMD standalone service adds only ``devices`` and ``group_add``."""
    svc = _load(AMD_STANDALONE)["services"][SERVICE]
    base_svc = base["services"][SERVICE]

    # Environment is unchanged from base for AMD.
    assert svc["environment"] == base_svc["environment"]

    # devices and group_add are new and match the overlay exactly.
    for amd_key in ("devices", "group_add"):
        assert amd_key not in base_svc
    assert svc["devices"] == ["/dev/kfd", "/dev/dri"]
    assert svc["group_add"] == ["video", "${RENDER_GID:-render}"]

    # No NVIDIA-only keys leaked in.
    assert "deploy" not in svc
|
||||
@@ -66,37 +66,3 @@ def test_normal_model_payload_keeps_temperature(monkeypatch):
|
||||
payload = _capture_openai_payload(monkeypatch, "gpt-4o", 0.2)
|
||||
assert payload["temperature"] == 0.2
|
||||
assert payload["max_tokens"] == 5
|
||||
|
||||
|
||||
def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
    """A temperature above 1.0 is passed through unchanged for ``gpt-4o``."""
    # OpenAI/local providers may validly use temperatures above 1.0; the clamp
    # is Anthropic-only and must not touch this path.
    requested = 1.2
    payload = _capture_openai_payload(monkeypatch, "gpt-4o", requested)
    assert payload["temperature"] == requested
|
||||
|
||||
|
||||
def _anthropic_payload(temperature):
    """Build an Anthropic payload for a one-message chat at ``temperature``."""
    messages = [{"role": "user", "content": "Hi"}]
    return llm_core._build_anthropic_payload(
        "claude-3-5-sonnet",
        messages,
        temperature,
        max_tokens=5,
    )
|
||||
|
||||
|
||||
def test_anthropic_payload_clamps_above_one():
    """A requested temperature of 1.2 is clamped to 1.0."""
    # Anthropic rejects temperature > 1.0 (e.g. the Nietzsche preset's 1.2).
    payload = _anthropic_payload(1.2)
    assert payload["temperature"] == 1.0
|
||||
|
||||
|
||||
def test_anthropic_payload_keeps_in_range():
    """A temperature of 0.7 is left unchanged."""
    payload = _anthropic_payload(0.7)
    assert payload["temperature"] == 0.7
|
||||
|
||||
|
||||
def test_anthropic_payload_clamps_negative():
    """A temperature of -0.5 is clamped to 0.0."""
    payload = _anthropic_payload(-0.5)
    assert payload["temperature"] == 0.0
|
||||
|
||||
|
||||
def test_anthropic_payload_none_temperature_does_not_crash():
    """A ``None`` temperature builds a payload and stays ``None``."""
    built = _anthropic_payload(None)
    assert built["temperature"] is None
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
"""Tests for model route helper functions — pure logic, no server needed."""
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -32,8 +29,6 @@ import src.endpoint_resolver as endpoint_resolver
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
_visible_models,
|
||||
_normalize_model_ids,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_probe_endpoint,
|
||||
@@ -475,342 +470,3 @@ class TestDockerHostGatewayReachable:
|
||||
|
||||
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
|
||||
assert model_routes._docker_host_gateway_reachable() is False
|
||||
|
||||
|
||||
# ── pinned model IDs: normalization helper ──
|
||||
|
||||
|
||||
class TestNormalizeModelIds:
    """Input shapes accepted by ``_normalize_model_ids``."""

    def test_list_passthrough_trims_and_dedupes(self):
        raw = [" a ", "a", "b", ""]
        assert _normalize_model_ids(raw) == ["a", "b"]

    def test_json_string_list(self):
        raw = '["x", "y", "x"]'
        assert _normalize_model_ids(raw) == ["x", "y"]

    def test_comma_and_newline_string(self):
        raw = "a, b\n c ,a"
        assert _normalize_model_ids(raw) == ["a", "b", "c"]

    def test_none_and_empty(self):
        for blank in (None, "", "   "):
            assert _normalize_model_ids(blank) == []

    def test_non_string_values_ignored(self):
        raw = [1, "ok", None, {"a": 1}]
        assert _normalize_model_ids(raw) == ["ok"]
|
||||
|
||||
|
||||
# ── pinned model IDs: _visible_models merge ──
|
||||
|
||||
|
||||
class TestVisibleModelsPinned:
    """How ``_visible_models`` combines cached, hidden and pinned IDs."""

    def test_includes_pinned_not_in_cached(self):
        shown = _visible_models(["a"], None, ["deploy-1"])
        assert shown == ["a", "deploy-1"]

    def test_cached_plus_pinned_dedup_preserves_order(self):
        shown = _visible_models(["a", "b"], None, ["b", "c"])
        assert shown == ["a", "b", "c"]

    def test_hidden_can_hide_a_pinned_model(self):
        shown = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
        assert shown == ["a"]

    def test_accepts_json_string_inputs(self):
        shown = _visible_models('["a"]', '["a"]', '["b"]')
        assert shown == ["b"]
|
||||
|
||||
|
||||
# ── pinned model IDs: route behaviour ──
|
||||
|
||||
# setup_model_routes defines FastAPI Form() routes, and those need
# python-multipart to be importable. When the package is genuinely absent,
# install a minimal stand-in module (same approach as
# tests/test_review_regressions.py); a real installation is left untouched.
if "python_multipart" not in sys.modules:
    try:
        import python_multipart  # noqa: F401
    except ImportError:
        _mp_stub = types.ModuleType("python_multipart")
        setattr(_mp_stub, "__version__", "0.0.13")
        sys.modules["python_multipart"] = _mp_stub
|
||||
|
||||
|
||||
class _PinnedFakeQuery:
    """Query fake that ignores filters and ordering and serves fixed rows."""

    def __init__(self, rows):
        # Copy so later changes to the caller's sequence are not visible here.
        self.rows = [*rows]

    def filter(self, *conditions):
        """Ignore ``conditions``; return ``self`` for chaining."""
        return self

    def order_by(self, *args):
        """Ignore ``args``; return ``self`` for chaining."""
        return self

    def first(self):
        """Return the first stored row, or ``None`` when there are none."""
        if not self.rows:
            return None
        return self.rows[0]

    def all(self):
        """Return a new list holding every stored row."""
        return self.rows[:]
|
||||
|
||||
|
||||
class _PinnedFakeDb:
    """Session fake that records ``add`` and ``commit`` calls."""

    def __init__(self, rows):
        self.rows = rows
        self.added, self.committed = [], 0

    def query(self, model):
        """Return a fresh ``_PinnedFakeQuery`` over the rows; ``model`` is ignored."""
        return _PinnedFakeQuery(self.rows)

    def add(self, row):
        """Remember ``row`` in ``added``."""
        self.added.append(row)

    def commit(self):
        """Count the commit."""
        self.committed += 1

    def close(self):
        """Do nothing."""
        return None
|
||||
|
||||
|
||||
class _FakeCol:
    """Column stand-in whose comparison and ordering operations return itself.

    This lets the dedupe query expressions be evaluated without a real
    SQLAlchemy column.
    """

    # __eq__ is overridden, so instances are explicitly unhashable.
    __hash__ = None

    def desc(self):
        return self

    def is_(self, other):
        return self

    def __or__(self, other):
        return self

    def __eq__(self, other):
        return self
|
||||
|
||||
|
||||
class _RecordingEndpoint:
    """ModelEndpoint stand-in that keeps its constructor kwargs as attributes.

    The class-level fake columns allow the class itself to be used in the
    dedupe lookup expressions; attributes set in ``__init__`` shadow them on
    each instance.
    """

    id = _FakeCol()
    base_url = _FakeCol()
    owner = _FakeCol()

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
|
||||
|
||||
|
||||
class _PinnedFakeRequest:
    """Request fake with a fixed JSON body and a headers mapping."""

    def __init__(self, body=None, headers=None):
        if body is None:
            body = {}
        self._body = body
        self.headers = headers if headers else {}

    async def json(self):
        """Return the stored body."""
        return self._body
|
||||
|
||||
|
||||
def _get_route(path, method):
    """Return the handler registered for ``method`` at ``path``.

    The router is built with ``model_discovery=None``. Raises
    ``AssertionError`` when no route matches.
    """
    from routes.model_routes import setup_model_routes

    router = setup_model_routes(model_discovery=None)
    for candidate in router.routes:
        if getattr(candidate, "path", "") != path:
            continue
        if method in getattr(candidate, "methods", set()):
            return candidate.endpoint
    raise AssertionError(f"{method} {path} not found")
|
||||
|
||||
|
||||
def _make_endpoint(**kwargs):
    """Return a ``SimpleNamespace`` endpoint row; ``kwargs`` override defaults."""
    defaults = {
        "id": "ep1",
        "name": "EP",
        "base_url": "http://localhost:9999/v1",
        "api_key": None,
        "is_enabled": True,
        "hidden_models": None,
        "cached_models": None,
        "pinned_models": None,
        "model_type": "llm",
        "supports_tools": None,
    }
    return SimpleNamespace(**{**defaults, **kwargs})
|
||||
|
||||
|
||||
def test_patch_models_saves_pinned_models(monkeypatch):
    """PATCH stores the pinned IDs de-duplicated and reports their count."""
    stored = _make_endpoint()
    fake_db = _PinnedFakeDb([stored])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    payload = {"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}
    outcome = asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(stored.pinned_models) == ["deploy-1", "deploy-2"]
    assert outcome["pinned_count"] == 2
|
||||
|
||||
|
||||
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
    """PATCHing only ``pinned_models`` leaves ``hidden_models`` as it was."""
    stored = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
    fake_db = _PinnedFakeDb([stored])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    payload = {"pinned_models": ["deploy-1"]}
    asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(stored.hidden_models) == ["hide-me"]
    assert json.loads(stored.pinned_models) == ["deploy-1"]
|
||||
|
||||
|
||||
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
    """GET lists the pinned ID, flagged ``is_pinned``, when the probe finds nothing."""
    stored = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([stored])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])
    list_models = _get_route("/api/model-endpoints/{ep_id}/models", "GET")

    listing = list_models("ep1", _PinnedFakeRequest())

    assert [row["id"] for row in listing] == ["deploy-1"]
    assert listing[0]["is_pinned"] is True
|
||||
|
||||
|
||||
def test_reprobe_preserves_pinned_models(monkeypatch):
    """Running the probe route updates cached models but keeps pinned IDs."""
    stored = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([stored])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
    monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
    monkeypatch.setattr(
        model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
    )
    probe = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")

    streamed = probe("ep1", _PinnedFakeRequest())

    async def _consume_stream():
        # Exhaust the response body so the route's generator runs to the end.
        async for _chunk in streamed.body_iterator:
            pass

    asyncio.run(_consume_stream())

    # Probe rewrites cached/hidden but must never touch admin-pinned IDs.
    assert json.loads(stored.pinned_models) == ["deploy-1"]
    assert json.loads(stored.cached_models) == ["m1"]
|
||||
|
||||
|
||||
def test_visible_models_handles_malformed_strings():
    """Non-JSON string inputs are split as lists instead of raising."""
    # Non-JSON cached/pinned strings are treated as comma/newline lists and
    # never raise; a malformed hidden string is normalized too.
    shown = _visible_models("a,b", "b", "{bad json")
    assert isinstance(shown, list)
    assert shown == ["a", "{bad json"]

    assert _visible_models("", None, "") == []
    assert _visible_models("only-cached", None, None) == ["only-cached"]
|
||||
|
||||
|
||||
def _create_form_kwargs(**overrides):
    """Return string values for the create route's Form() parameters.

    The route is called as a plain function in these tests, which skips
    FastAPI's form parsing, so each Form() default has to be supplied as a
    real string. ``overrides`` replace or extend the defaults.
    """
    defaults = {
        "name": "",
        "api_key": "",
        # Skip probing so unit tests make no network calls.
        "skip_probe": "true",
        "require_models": "false",
        "model_type": "llm",
        "supports_tools": "",
        "pinned_models": "",
        "container_local": "false",
        "shared": "true",
    }
    return {**defaults, **overrides}
|
||||
|
||||
|
||||
def _patch_create_deps(monkeypatch, db):
    """Patch the collaborators of the create-endpoint route with test doubles.

    ``SessionLocal`` returns ``db``, admin checks pass, ``ModelEndpoint`` is
    replaced by ``_RecordingEndpoint``, URL helpers return their input, and
    the current user resolves to ``None``.
    """
    import src.auth_helpers as auth_helpers

    replacements = (
        (model_routes, "SessionLocal", lambda: db),
        (model_routes, "require_admin", lambda request: None),
        (model_routes, "ModelEndpoint", _RecordingEndpoint),
        (model_routes, "_normalize_base", lambda b: b),
        (model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b),
        (model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"}),
        (endpoint_resolver, "resolve_url", lambda u: u),
        (auth_helpers, "get_current_user", lambda req: None),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)
|
||||
|
||||
|
||||
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
    """POST with no matching row creates an endpoint carrying the pinned models.

    The submitted pin string mixes comma and newline separators and repeats
    an entry; the response and the persisted row are both expected to hold
    the de-duplicated list.
    """
    fake_db = _PinnedFakeDb([])  # no existing row → fresh create path
    _patch_create_deps(monkeypatch, fake_db)
    handler = _get_route("/api/model-endpoints", "POST")

    request = _PinnedFakeRequest()
    form = _create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2")
    response = handler(request, base_url="http://host:1234/v1", **form)

    expected_pins = ["deploy-1", "deploy-2"]
    assert response["pinned_models"] == expected_pins
    assert response["models"] == expected_pins
    assert response["online"] is True

    # Persisted onto the created row.
    assert len(fake_db.added) == 1
    created_row = fake_db.added[0]
    assert json.loads(created_row.pinned_models) == expected_pins
|
||||
|
||||
|
||||
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
    """POST against an already-known endpoint merges the incoming pin.

    The existing row starts with one cached model and one pin; posting a new
    pin for the same base URL must extend the pin list rather than replace
    it, and must not add a second row.
    """
    row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        hidden_models=None,
        pinned_models=json.dumps(["old-pin"]),
    )
    fake_db = _PinnedFakeDb([row])
    _patch_create_deps(monkeypatch, fake_db)
    handler = _get_route("/api/model-endpoints", "POST")

    request = _PinnedFakeRequest()
    form = _create_form_kwargs(pinned_models="new-pin")
    response = handler(request, base_url="http://host:1234/v1", **form)

    assert response["existing"] is True

    # Incoming pin merged onto the existing pins (no clobber, order preserved).
    merged_pins = ["old-pin", "new-pin"]
    assert json.loads(row.pinned_models) == merged_pins
    assert response["pinned_models"] == merged_pins

    # models = cached + pinned - hidden, visible merged list.
    assert response["models"] == ["m1", "old-pin", "new-pin"]

    # No new row created on the dedupe path.
    assert fake_db.added == []
|
||||
|
||||
|
||||
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
    """POST without pins against an existing endpoint leaves its pins intact.

    With ``pinned_models`` left at the helper default (empty string), the
    stored pin list must be unchanged and no commit should be recorded.
    """
    row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        pinned_models=json.dumps(["keep-me"]),
    )
    fake_db = _PinnedFakeDb([row])
    _patch_create_deps(monkeypatch, fake_db)
    handler = _get_route("/api/model-endpoints", "POST")

    request = _PinnedFakeRequest()
    form = _create_form_kwargs()  # pinned_models defaults to ""
    response = handler(request, base_url="http://host:1234/v1", **form)

    kept_pins = ["keep-me"]
    assert json.loads(row.pinned_models) == kept_pins
    assert response["pinned_models"] == kept_pins
    assert fake_db.committed == 0  # nothing to persist
|
||||
|
||||
@@ -247,14 +247,10 @@ class _Column:
|
||||
def __eq__(self, value):
|
||||
return _Predicate(lambda row: getattr(row, self.name) == value)
|
||||
|
||||
def desc(self):
|
||||
return self
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
owner = _Column("owner")
|
||||
created_at = _Column("created_at")
|
||||
|
||||
|
||||
class _Query:
|
||||
@@ -265,9 +261,6 @@ class _Query:
|
||||
self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
|
||||
return self
|
||||
|
||||
def order_by(self, *exprs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
@@ -287,10 +280,8 @@ def _ep(name, owner, *, is_enabled=True):
|
||||
|
||||
def _select(rows, owner):
|
||||
wh_mod = _import_webhook_helper()
|
||||
# _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
|
||||
# (not a local import), so we patch the module attribute directly.
|
||||
wh_mod.ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
|
||||
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._first_enabled_endpoint(_DB(rows), owner)
|
||||
|
||||
|
||||
def test_sync_chat_fallback_never_picks_another_owners_endpoint():
|
||||
@@ -319,15 +310,9 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
|
||||
# When no token owner is known, only null-owner (shared) endpoints are
|
||||
# visible — private endpoints of any user must not be returned.
|
||||
rows = [_ep("bob-private", "bob"), _ep("shared", None)]
|
||||
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop():
|
||||
# An unresolvable/empty token owner keeps the original single-user behaviour
|
||||
# (owner_filter no-op): first enabled row, whatever it is.
|
||||
rows = [_ep("first", "bob"), _ep("second", "alice")]
|
||||
ep = _select(rows, None)
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
|
||||
# No shared rows → fail closed rather than returning another user's endpoint.
|
||||
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
|
||||
assert _select(rows, None) is None
|
||||
assert ep is not None and ep.name == "first"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from services.search.ranking import rank_search_results
|
||||
from src.search.ranking import rank_search_results
|
||||
|
||||
|
||||
def test_news_queries_prefer_news_sources_over_sports_and_social_results():
|
||||
|
||||
@@ -8,8 +8,7 @@ module-level, time-injectable function.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import services.search.ranking as live_ranking
|
||||
from services.search.ranking import recency_score, _utcnow_naive, rank_search_results
|
||||
from src.search.ranking import recency_score, _utcnow_naive
|
||||
|
||||
|
||||
def test_fresh_result_scores_one():
|
||||
@@ -38,37 +37,3 @@ def test_default_now_is_naive_utc():
|
||||
assert now.tzinfo is None
|
||||
reference = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
assert abs((now - reference).total_seconds()) < 5
|
||||
|
||||
|
||||
def test_supported_timestamp_formats_parse():
    """Each supported timestamp spelling is parsed and scored identically.

    All three formats the current implementation supports resolve to the
    same ~4-day-old age relative to the pinned ``now``, so each scores a
    full 1.0.
    """
    pinned_now = datetime(2026, 1, 5, 12, 0, 0)
    stamps = (
        "2026-01-01",
        "2026-01-01T08:30:00",
        "2026-01-01 08:30:00",
    )
    for stamp in stamps:
        assert recency_score(stamp, now=pinned_now) == 1.0
|
||||
|
||||
|
||||
def test_shim_reexports_live_objects():
    """The compatibility shim must hand back the live module's own objects.

    ``src.search.ranking`` is a compatibility shim; it must expose the
    *same* objects (identity, not just equality) as the live services
    module so the two cannot diverge.
    """
    import src.search.ranking as shim

    shared_names = ("recency_score", "rank_search_results", "_utcnow_naive")
    for attr_name in shared_names:
        assert getattr(shim, attr_name) is getattr(live_ranking, attr_name)
|
||||
|
||||
|
||||
def test_live_rank_path_prefers_newer_result(monkeypatch):
    """With everything else equal, the more recent result ranks first.

    "Now" is pinned so age scoring is deterministic. The two results are
    identical apart from age, isolating recency as the only differentiator.
    """
    def _fixed_now():
        return datetime(2026, 1, 31)

    monkeypatch.setattr(live_ranking, "_utcnow_naive", _fixed_now)

    def _hit(url, age):
        # Same title/snippet for every hit; only url and age vary.
        return {"title": "Report", "url": url, "snippet": "x", "age": age}

    older = _hit("https://example.org/a", "2026-01-01")
    newer = _hit("https://example.org/b", "2026-01-29")

    ranked = rank_search_results("report", [older, newer])

    assert ranked[0]["url"] == "https://example.org/b"
    assert ranked[1]["url"] == "https://example.org/a"
|
||||
|
||||
Reference in New Issue
Block a user