diff --git a/routes/embedding_routes.py b/routes/embedding_routes.py index 3b6b090..7ae8005 100644 --- a/routes/embedding_routes.py +++ b/routes/embedding_routes.py @@ -242,6 +242,18 @@ def setup_embedding_routes(): if not url: raise HTTPException(400, "URL is required") + # SSRF hardening: validate the user-supplied URL before any outbound + # request. Local-first means loopback/LAN endpoints are allowed by + # default; non-HTTP(S) schemes and the cloud metadata range are always + # rejected. Set EMBEDDING_BLOCK_PRIVATE_IPS=true for full lockdown. + from src.url_safety import check_outbound_url + ok, reason = check_outbound_url( + url, + block_private=os.getenv("EMBEDDING_BLOCK_PRIVATE_IPS", "false").lower() == "true", + ) + if not ok: + raise HTTPException(400, f"Rejected endpoint URL: {reason}") + # Quick health check try: import httpx diff --git a/src/url_safety.py b/src/url_safety.py new file mode 100644 index 0000000..ec7c8f8 --- /dev/null +++ b/src/url_safety.py @@ -0,0 +1,88 @@ +"""Outbound URL safety checks (SSRF hardening). + +Run before the server makes a request to a *user-supplied* URL — e.g. the custom +embedding endpoint set via ``POST /api/embeddings/endpoint``, which then triggers +an outbound ``httpx`` call. + +Odysseus is local-first: pointing the embedding endpoint at a loopback or LAN +address (a local vLLM / llama.cpp / Ollama server) is a normal, intended setup. +So this guard does **not** blanket-block private addresses by default — that would +break the primary use case. What it *always* rejects: + + - a non-HTTP(S) scheme (``file://``, ``gopher://``, ``ftp://`` …), and + - the link-local range (``169.254.0.0/16`` / ``fe80::/10``), i.e. the cloud + instance-metadata SSRF credential-exfil vector — nobody serves embeddings + there — plus multicast / reserved / unspecified addresses. + +For exposed multi-tenant deployments, set ``EMBEDDING_BLOCK_PRIVATE_IPS=true`` to +additionally reject all private and loopback targets (full SSRF lockdown). 
import ipaddress
import socket
from typing import Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse

ALLOWED_SCHEMES = ("http", "https")

# Cloud instance-metadata endpoints that sit outside the link-local ranges, so
# the link-local test alone does not catch them.
_METADATA_ADDRESSES = frozenset(
    (
        ipaddress.ip_address("fd00:ec2::254"),  # AWS IMDS over IPv6
        ipaddress.ip_address("100.100.100.200"),  # Alibaba Cloud ECS metadata
    )
)


def _default_resolver(host: str) -> List[str]:
    """Return the distinct IP strings ``socket.getaddrinfo`` reports for *host*.

    ``getaddrinfo`` yields one entry per socket type, so the same address
    usually appears several times; duplicates are dropped while keeping the
    first-seen order. Errors raised by ``getaddrinfo`` propagate to the caller.
    """
    addresses = (info[4][0] for info in socket.getaddrinfo(host, None))
    return list(dict.fromkeys(addresses))


def _classify(
    ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address],
    *,
    block_private: bool,
) -> Optional[str]:
    """Return a rejection reason for an IP, or None if it is allowed.

    Args:
        ip: The resolved address to judge.
        block_private: When True, additionally reject private, loopback and any
            other address that is not globally routable.

    Returns:
        A human-readable reason string when the address must be rejected,
        otherwise ``None``.
    """
    # IPv4-mapped IPv6 (e.g. ::ffff:169.254.169.254) — judge the embedded v4.
    if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped
    if ip.is_link_local:
        return f"link-local address blocked (SSRF metadata risk): {ip}"
    # Metadata endpoints that are not link-local are rejected in every mode.
    if ip in _METADATA_ADDRESSES:
        return f"cloud metadata address blocked (SSRF metadata risk): {ip}"
    if ip.is_multicast or ip.is_reserved or ip.is_unspecified:
        return f"disallowed address: {ip}"
    if block_private:
        if ip.is_private or ip.is_loopback:
            return f"private/loopback address blocked: {ip}"
        # is_private is False for the shared address space (100.64.0.0/10);
        # strict mode only lets globally routable targets through.
        if not ip.is_global:
            return f"non-public address blocked: {ip}"
    return None


def check_outbound_url(
    url: str,
    *,
    block_private: bool = False,
    resolver: Optional[Callable[[str], List[str]]] = None,
) -> Tuple[bool, str]:
    """Validate a user-supplied outbound URL.

    Args:
        url: The URL the server is about to request.
        block_private: Reject private, loopback and other non-public targets
            in addition to the always-blocked ranges.
        resolver: Callable mapping a hostname to a list of IP strings. It is
            injectable so callers/tests can avoid real DNS; when omitted,
            ``_default_resolver`` is used.

    Returns:
        ``(ok, reason)``. ``ok`` is True only when the scheme is HTTP(S) and
        every address the host resolves to passes ``_classify``; ``reason`` is
        ``"ok"`` on success and a short explanation otherwise.

    Note:
        Only the addresses seen at call time are checked and none is returned,
        so this cannot by itself guard against DNS answers that change between
        this check and the actual request.
    """
    if not url or not url.strip():
        return False, "URL is required"
    try:
        parsed = urlparse(url.strip())
        host = parsed.hostname
    except ValueError as e:  # e.g. malformed bracketed IPv6 host
        return False, f"unparseable URL: {e}"

    if parsed.scheme.lower() not in ALLOWED_SCHEMES:
        return False, f"scheme must be http or https, got '{parsed.scheme or '(none)'}'"
    if not host:
        return False, "URL has no host"

    resolve = resolver or _default_resolver
    try:
        raw_ips = resolve(host)
    except Exception as e:  # resolver is pluggable; any failure means "reject"
        return False, f"host does not resolve: {e}"
    if not raw_ips:
        return False, "host does not resolve"

    for raw in raw_ips:
        try:
            ip = ipaddress.ip_address(raw.partition("%")[0])  # strip IPv6 zone id
        except ValueError:
            # Fail closed: an address that cannot be classified is never
            # silently skipped, otherwise garbage output would yield "ok".
            return False, f"resolver returned an unparseable address: {raw!r}"
        reason = _classify(ip, block_private=block_private)
        if reason:
            return False, reason
    return True, "ok"
"""Exercise the outbound URL guard in ``src/url_safety.py``.

Every case passes a stub resolver, so no real DNS lookup is performed.
"""

from src.url_safety import check_outbound_url


def _resolver(mapping):
    """Build a stub resolver that answers from *mapping* and raises OSError otherwise."""
    def resolve(host):
        if host not in mapping:
            raise OSError(f"unresolvable: {host}")
        return mapping[host]
    return resolve


PUBLIC = _resolver({"example.com": ["93.184.216.34"]})
LOOPBACK = _resolver({"localhost": ["127.0.0.1"]})
LAN = _resolver({"nas.local": ["192.168.1.50"]})
METADATA = _resolver({"evil.example": ["169.254.169.254"]})
MAPPED_METADATA = _resolver({"evil6.example": ["::ffff:169.254.169.254"]})


def test_non_http_scheme_blocked():
    rejected_urls = ("file:///etc/passwd", "ftp://x/y", "gopher://h", "redis://h:6379")
    for candidate in rejected_urls:
        allowed, why = check_outbound_url(candidate, resolver=PUBLIC)
        assert allowed is False, candidate
        assert "scheme" in why


def test_missing_host_or_empty_blocked():
    for candidate in ("", "http://"):
        verdict = check_outbound_url(candidate, resolver=PUBLIC)
        assert verdict[0] is False


def test_public_url_allowed():
    allowed, why = check_outbound_url(
        "https://example.com/v1/embeddings",
        resolver=PUBLIC,
    )
    assert allowed is True, why


def test_cloud_metadata_blocked_even_when_private_allowed():
    # block_private is left at its default here: the metadata address must be
    # refused even in the permissive local-first mode.
    allowed, why = check_outbound_url(
        "http://evil.example/latest/meta-data/",
        resolver=METADATA,
    )
    assert allowed is False
    assert "link-local" in why


def test_ipv4_mapped_metadata_blocked():
    allowed, why = check_outbound_url("http://evil6.example/", resolver=MAPPED_METADATA)
    assert allowed is False
    assert "link-local" in why


def test_loopback_and_lan_allowed_by_default_local_first():
    # A localhost or LAN embedding server is a legitimate default target.
    loopback_verdict = check_outbound_url("http://localhost:8080/v1", resolver=LOOPBACK)
    assert loopback_verdict[0] is True
    lan_verdict = check_outbound_url("http://nas.local:1234/v1", resolver=LAN)
    assert lan_verdict[0] is True


def test_strict_mode_blocks_private_and_loopback():
    allowed, why = check_outbound_url(
        "http://localhost:8080",
        block_private=True,
        resolver=LOOPBACK,
    )
    assert allowed is False
    assert "private" in why

    allowed, why = check_outbound_url(
        "http://nas.local",
        block_private=True,
        resolver=LAN,
    )
    assert allowed is False
    assert "private" in why


def test_unresolvable_host_blocked():
    allowed, why = check_outbound_url("http://does-not-resolve.invalid", resolver=PUBLIC)
    assert allowed is False
    assert "resolve" in why