Harden API-token chat endpoint selection
Validate only token-supplied direct base_url values for API-token chat requests, while keeping admin-configured endpoints available for local/LAN providers. Scope configured endpoint fallback selection to the API token owner, fail closed for unknown token owners, and preserve strict session ownership checks when resuming sessions from chat-scoped API tokens. Add focused regression coverage for direct base_url SSRF rejection, configured endpoint fallback behavior, token-owner scoping, URL validation, and null-owner session/endpoint handling.
This commit is contained in:
committed by
GitHub
parent
145f4fd2b4
commit
b1a4ed13b0
@@ -9,7 +9,9 @@ import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.database import SessionLocal, Webhook
|
||||
from core.database import SessionLocal, Webhook, ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.url_security import validate_public_http_url
|
||||
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000
|
||||
from core.middleware import require_admin as _require_admin
|
||||
|
||||
|
||||
def _first_enabled_endpoint(db, owner):
|
||||
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus
|
||||
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint
|
||||
is per-user (core/database.py — "when non-null, the model picker only shows
|
||||
the endpoint to that user"), and the sync-chat fallback uses the row's
|
||||
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token
|
||||
(e.g. a paired mobile device) fall back onto ANOTHER user's private
|
||||
endpoint and silently spend that owner's API key / quota — and reach
|
||||
whatever internal base_url they configured. Mirrors the owner_filter scoping
|
||||
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
|
||||
no-op (single-user / legacy mode), preserving the original behaviour.
|
||||
def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]):
|
||||
"""First enabled ModelEndpoint visible to token_owner — their own rows plus
|
||||
legacy null-owner ("shared") rows. Owner-scoped: an unscoped .first() would
|
||||
let a chat-scoped token fall back onto another user's private endpoint and
|
||||
silently spend that owner's API key/quota. Prefer owner rows before shared
|
||||
rows. Fails closed to null-owner rows only when token_owner is absent.
|
||||
Does not validate base_url — admin-configured local/LAN endpoints remain allowed.
|
||||
"""
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
return q.first()
|
||||
query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
|
||||
if token_owner:
|
||||
query = owner_filter(query, ModelEndpoint, token_owner)
|
||||
return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first()
|
||||
return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711
|
||||
|
||||
|
||||
def _caller_owns_session(sess_owner, caller) -> bool:
|
||||
@@ -278,15 +276,21 @@ def setup_webhook_routes(
|
||||
api_key = body.api_key.strip()
|
||||
model = body.model or "deepseek-chat"
|
||||
|
||||
# Resolve base_url: explicit > provider name > model prefix auto-detect
|
||||
base_url = body.base_url.strip().rstrip("/") if body.base_url else None
|
||||
if not base_url:
|
||||
# Validate only token-supplied direct base_url; auto-resolved known-provider
|
||||
# URLs are not subject to extra local/LAN blocking beyond existing provider logic.
|
||||
direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None
|
||||
if direct_base_url:
|
||||
try:
|
||||
base_url = validate_public_http_url(direct_base_url)
|
||||
except ValueError as e:
|
||||
detail = str(e).replace("URL", "base_url", 1)
|
||||
raise HTTPException(400, detail)
|
||||
else:
|
||||
base_url = _resolve_base_url(model, body.provider)
|
||||
if not base_url:
|
||||
raise HTTPException(400,
|
||||
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
|
||||
"or provider ('deepseek', 'openai', 'groq', etc.)")
|
||||
|
||||
base_url = normalize_base(base_url)
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
|
||||
@@ -306,9 +310,7 @@ def setup_webhook_routes(
|
||||
if not sess:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped: only THIS token owner's endpoints + legacy
|
||||
# shared rows, never another user's private endpoint/api_key.
|
||||
ep = _first_enabled_endpoint(db, token_owner)
|
||||
ep = _select_api_chat_fallback_endpoint(db, token_owner)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
94
src/url_security.py
Normal file
94
src/url_security.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""URL validation helpers for server-side outbound requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
_INTERNAL_HOSTNAMES = {
|
||||
"localhost",
|
||||
"metadata",
|
||||
"metadata.google.internal",
|
||||
}
|
||||
|
||||
_INTERNAL_SUFFIXES = (
|
||||
".localhost",
|
||||
".local",
|
||||
".internal",
|
||||
".lan",
|
||||
".intranet",
|
||||
)
|
||||
|
||||
_BLOCKED_NETWORKS = (
|
||||
ipaddress.ip_network("0.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("100.64.0.0/10"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("::/128"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
|
||||
ips: list[ipaddress._BaseAddress] = []
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
||||
if family in (socket.AF_INET, socket.AF_INET6):
|
||||
ips.append(ipaddress.ip_address(sockaddr[0]))
|
||||
return ips
|
||||
|
||||
|
||||
def _blocked_ip(addr: ipaddress._BaseAddress) -> bool:
|
||||
return (
|
||||
any(addr in net for net in _BLOCKED_NETWORKS)
|
||||
or addr.is_private
|
||||
or addr.is_loopback
|
||||
or addr.is_link_local
|
||||
or addr.is_multicast
|
||||
or addr.is_unspecified
|
||||
or addr.is_reserved
|
||||
)
|
||||
|
||||
|
||||
def _host_resolves_publicly(hostname: str) -> bool:
|
||||
host = hostname.strip().lower()
|
||||
if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES):
|
||||
return False
|
||||
try:
|
||||
return not _blocked_ip(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
addrs = _resolve_hostname_ips(host)
|
||||
except OSError:
|
||||
return False
|
||||
return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs)
|
||||
|
||||
|
||||
def is_public_http_url(url: str) -> bool:
|
||||
parsed = urlparse((url or "").strip())
|
||||
if parsed.scheme not in ("http", "https") or not parsed.hostname:
|
||||
return False
|
||||
return _host_resolves_publicly(parsed.hostname)
|
||||
|
||||
|
||||
def validate_public_http_url(url: str, *, max_length: int = 2048) -> str:
|
||||
"""Validate a user/API-token supplied server-side HTTP(S) endpoint.
|
||||
|
||||
This is for untrusted outbound URLs, not admin-created model endpoints
|
||||
that are intentionally allowed to point at private model providers. DNS
|
||||
failures fail closed, and DNS checks reduce obvious private-network
|
||||
targets but do not eliminate every DNS rebinding race by themselves.
|
||||
"""
|
||||
cleaned = (url or "").strip()
|
||||
if len(cleaned) > max_length:
|
||||
raise ValueError("URL is too long")
|
||||
if not is_public_http_url(cleaned):
|
||||
raise ValueError("URL must point to a public HTTP(S) endpoint")
|
||||
return cleaned
|
||||
401
tests/test_api_chat_security.py
Normal file
401
tests/test_api_chat_security.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import ipaddress
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://127.0.0.1:8000/v1",
|
||||
"http://localhost:8000/v1",
|
||||
"http://10.0.0.5/v1",
|
||||
"http://172.16.0.1/v1",
|
||||
"http://192.168.1.2/v1",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://metadata.google.internal/",
|
||||
"http://[::1]:8000/v1",
|
||||
"http://[fc00::1]/v1",
|
||||
"http://224.0.0.1/v1",
|
||||
"http://0.0.0.0/v1",
|
||||
"file:///etc/passwd",
|
||||
])
|
||||
def test_public_url_validator_blocks_internal_targets(url):
|
||||
from src.url_security import is_public_http_url
|
||||
|
||||
assert is_public_http_url(url) is False
|
||||
|
||||
|
||||
def test_public_url_validator_allows_public_endpoint(monkeypatch):
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("93.184.216.34")],
|
||||
)
|
||||
|
||||
assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1"
|
||||
|
||||
|
||||
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("10.0.0.5")],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
url_security.validate_public_http_url("https://api.example.com/v1")
|
||||
|
||||
|
||||
def _load_webhook_routes_for_test(monkeypatch):
|
||||
# Load under a unique module name so each test gets a fresh module object
|
||||
# rather than a cached one from a previous monkeypatch run.
|
||||
core_pkg = types.ModuleType("core")
|
||||
core_pkg.__path__ = []
|
||||
core_db = types.ModuleType("core.database")
|
||||
core_db.SessionLocal = object
|
||||
core_db.Webhook = object
|
||||
core_db.ModelEndpoint = object
|
||||
core_middleware = types.ModuleType("core.middleware")
|
||||
core_middleware.require_admin = lambda request: None
|
||||
webhook_manager = types.ModuleType("src.webhook_manager")
|
||||
webhook_manager.WebhookManager = object
|
||||
webhook_manager.validate_webhook_url = lambda url: url
|
||||
webhook_manager.validate_events = lambda events: events
|
||||
|
||||
monkeypatch.setitem(sys.modules, "core", core_pkg)
|
||||
monkeypatch.setitem(sys.modules, "core.database", core_db)
|
||||
monkeypatch.setitem(sys.modules, "core.middleware", core_middleware)
|
||||
monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager)
|
||||
|
||||
module_name = "routes.webhook_routes_under_test"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py",
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
class _Expr:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, row):
|
||||
return self.fn(row)
|
||||
|
||||
def __or__(self, other):
|
||||
return _Expr(lambda row: self(row) or other(row))
|
||||
|
||||
|
||||
class _Column:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return _Expr(lambda row: getattr(row, self.name) == other)
|
||||
|
||||
def desc(self):
|
||||
return ("desc", self.name)
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
owner = _Column("owner")
|
||||
created_at = _Column("created_at")
|
||||
|
||||
|
||||
class _Endpoint:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
owner,
|
||||
is_enabled=True,
|
||||
created_at=1,
|
||||
base_url="https://api.example.com/v1",
|
||||
api_key=None,
|
||||
):
|
||||
self.owner = owner
|
||||
self.is_enabled = is_enabled
|
||||
self.created_at = created_at
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
|
||||
class _EndpointQuery:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.filters = []
|
||||
self.orders = []
|
||||
|
||||
def filter(self, *exprs):
|
||||
self.filters.extend(exprs)
|
||||
return self
|
||||
|
||||
def order_by(self, *exprs):
|
||||
self.orders.extend(exprs)
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
rows = self.rows
|
||||
for expr in self.filters:
|
||||
rows = [row for row in rows if expr(row)]
|
||||
# Apply sort keys right-to-left so the leftmost key ends up as the
|
||||
# primary sort (stable-sort reversal idiom mirrors SQLAlchemy's
|
||||
# multi-column ORDER BY behaviour).
|
||||
for order in reversed(self.orders):
|
||||
reverse = False
|
||||
name = getattr(order, "name", None)
|
||||
if isinstance(order, tuple) and order[0] == "desc":
|
||||
reverse = True
|
||||
name = order[1]
|
||||
rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse)
|
||||
if name != "owner":
|
||||
rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse)
|
||||
return rows[0] if rows else None
|
||||
|
||||
|
||||
class _DB:
|
||||
def __init__(self, rows):
|
||||
self.query_obj = _EndpointQuery(rows)
|
||||
self.closed = False
|
||||
|
||||
def query(self, model):
|
||||
assert model is _ModelEndpoint
|
||||
return self.query_obj
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _ChatSession:
|
||||
def __init__(self, endpoint_url, model):
|
||||
self.endpoint_url = endpoint_url
|
||||
self.model = model
|
||||
self.headers = {}
|
||||
self.history = []
|
||||
|
||||
def add_message(self, message):
|
||||
self.history.append(message)
|
||||
|
||||
|
||||
class _SessionManager:
|
||||
def __init__(self):
|
||||
self.created = []
|
||||
self.save_calls = 0
|
||||
|
||||
def create_session(self, *, session_id, name, endpoint_url, model, owner):
|
||||
session = _ChatSession(endpoint_url, model)
|
||||
self.created.append({
|
||||
"session_id": session_id,
|
||||
"name": name,
|
||||
"endpoint_url": endpoint_url,
|
||||
"model": model,
|
||||
"owner": owner,
|
||||
"session": session,
|
||||
})
|
||||
return session
|
||||
|
||||
def save_sessions(self):
|
||||
self.save_calls += 1
|
||||
|
||||
|
||||
class _Request:
|
||||
def __init__(self, *, owner="alice"):
|
||||
self.state = types.SimpleNamespace(
|
||||
api_token=True,
|
||||
api_token_scopes=["chat"],
|
||||
api_token_owner=owner,
|
||||
)
|
||||
|
||||
|
||||
class _WebhookManager:
|
||||
async def fire(self, event, payload):
|
||||
return None
|
||||
|
||||
|
||||
def _install_sync_chat_stubs(monkeypatch):
|
||||
# FastAPI checks for python_multipart at import time when Form is used;
|
||||
# stub it so the optional dependency is not required in the test environment.
|
||||
python_multipart = types.ModuleType("python_multipart")
|
||||
python_multipart.__version__ = "0.0.13"
|
||||
core_models = types.ModuleType("core.models")
|
||||
|
||||
class _ChatMessage:
|
||||
def __init__(self, role, content):
|
||||
self.role = role
|
||||
self.content = content
|
||||
|
||||
async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
|
||||
return "mocked response"
|
||||
|
||||
endpoint_resolver = types.ModuleType("src.endpoint_resolver")
|
||||
endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/")
|
||||
endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions"
|
||||
endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models"
|
||||
endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
llm_core = types.ModuleType("src.llm_core")
|
||||
llm_core.llm_call_async = _llm_call_async
|
||||
core_models.ChatMessage = _ChatMessage
|
||||
|
||||
monkeypatch.setitem(sys.modules, "python_multipart", python_multipart)
|
||||
monkeypatch.setitem(sys.modules, "core.models", core_models)
|
||||
monkeypatch.setitem(sys.modules, "src.llm_core", llm_core)
|
||||
monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver)
|
||||
|
||||
|
||||
def _sync_chat_endpoint(webhook_routes, session_manager):
|
||||
router = webhook_routes.setup_webhook_routes(
|
||||
_WebhookManager(),
|
||||
auth_manager=None,
|
||||
session_manager=session_manager,
|
||||
)
|
||||
for route in router.routes:
|
||||
if route.path == "/api/v1/chat":
|
||||
return route.endpoint
|
||||
raise AssertionError("sync chat route not found")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
|
||||
"http://127.0.0.1:11434/v1",
|
||||
"http://localhost:11434/v1",
|
||||
"http://10.0.0.5/v1",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
api_key="test-key",
|
||||
base_url=base_url,
|
||||
model="test-model",
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
with pytest.raises(webhook_routes.HTTPException) as exc:
|
||||
await sync_chat(_Request(), body)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint"
|
||||
assert session_manager.created == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
|
||||
from src import url_security
|
||||
|
||||
monkeypatch.setattr(
|
||||
url_security,
|
||||
"_resolve_hostname_ips",
|
||||
lambda host: [ipaddress.ip_address("93.184.216.34")],
|
||||
)
|
||||
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
api_key="test-key",
|
||||
base_url="https://api.example.com/v1",
|
||||
model="test-model",
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
response = await sync_chat(_Request(), body)
|
||||
|
||||
assert response["response"] == "mocked response"
|
||||
assert response["model"] == "test-model"
|
||||
assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions"
|
||||
|
||||
|
||||
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
rows = [
|
||||
_Endpoint(owner="alice", is_enabled=False, created_at=0),
|
||||
_Endpoint(owner="bob", created_at=0),
|
||||
_Endpoint(owner=None, created_at=1),
|
||||
_Endpoint(owner="alice", created_at=2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
|
||||
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice")
|
||||
|
||||
assert selected.owner == "alice"
|
||||
assert selected.is_enabled is True
|
||||
assert selected.created_at == 2
|
||||
|
||||
|
||||
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
rows = [
|
||||
_Endpoint(owner="alice", created_at=0),
|
||||
_Endpoint(owner=None, is_enabled=False, created_at=1),
|
||||
_Endpoint(owner=None, created_at=2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
|
||||
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None)
|
||||
|
||||
assert selected.owner is None
|
||||
assert selected.is_enabled is True
|
||||
assert selected.created_at == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
|
||||
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
|
||||
_install_sync_chat_stubs(monkeypatch)
|
||||
local_endpoint = _Endpoint(
|
||||
owner=None,
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="configured-key",
|
||||
)
|
||||
db = _DB([local_endpoint])
|
||||
calls = []
|
||||
|
||||
def _session_local():
|
||||
return db
|
||||
|
||||
def _validate_public_http_url(url, *, max_length=2048):
|
||||
calls.append(url)
|
||||
raise AssertionError("configured fallback endpoint should not be publicly validated")
|
||||
|
||||
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
|
||||
monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local)
|
||||
monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url)
|
||||
|
||||
session_manager = _SessionManager()
|
||||
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
|
||||
body = types.SimpleNamespace(
|
||||
message="hello",
|
||||
model="local-model",
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
provider=None,
|
||||
session=None,
|
||||
)
|
||||
|
||||
response = await sync_chat(_Request(owner=None), body)
|
||||
|
||||
assert response["response"] == "mocked response"
|
||||
assert response["model"] == "local-model"
|
||||
assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
|
||||
assert calls == []
|
||||
@@ -247,10 +247,14 @@ class _Column:
|
||||
def __eq__(self, value):
|
||||
return _Predicate(lambda row: getattr(row, self.name) == value)
|
||||
|
||||
def desc(self):
|
||||
return self
|
||||
|
||||
|
||||
class _ModelEndpoint:
|
||||
is_enabled = _Column("is_enabled")
|
||||
owner = _Column("owner")
|
||||
created_at = _Column("created_at")
|
||||
|
||||
|
||||
class _Query:
|
||||
@@ -261,6 +265,9 @@ class _Query:
|
||||
self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
|
||||
return self
|
||||
|
||||
def order_by(self, *exprs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
@@ -280,8 +287,10 @@ def _ep(name, owner, *, is_enabled=True):
|
||||
|
||||
def _select(rows, owner):
|
||||
wh_mod = _import_webhook_helper()
|
||||
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._first_enabled_endpoint(_DB(rows), owner)
|
||||
# _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
|
||||
# (not a local import), so we patch the module attribute directly.
|
||||
wh_mod.ModelEndpoint = _ModelEndpoint
|
||||
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
|
||||
|
||||
|
||||
def test_sync_chat_fallback_never_picks_another_owners_endpoint():
|
||||
@@ -310,9 +319,15 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop():
|
||||
# An unresolvable/empty token owner keeps the original single-user behaviour
|
||||
# (owner_filter no-op): first enabled row, whatever it is.
|
||||
rows = [_ep("first", "bob"), _ep("second", "alice")]
|
||||
def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
|
||||
# When no token owner is known, only null-owner (shared) endpoints are
|
||||
# visible — private endpoints of any user must not be returned.
|
||||
rows = [_ep("bob-private", "bob"), _ep("shared", None)]
|
||||
ep = _select(rows, None)
|
||||
assert ep is not None and ep.name == "first"
|
||||
assert ep is not None and ep.name == "shared"
|
||||
|
||||
|
||||
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
|
||||
# No shared rows → fail closed rather than returning another user's endpoint.
|
||||
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
|
||||
assert _select(rows, None) is None
|
||||
|
||||
Reference in New Issue
Block a user