224 lines
7 KiB
Python
224 lines
7 KiB
Python
|
|
"""BashTool — execute shell commands with safety guardrails.
|
||
|
|
|
||
|
|
Runs commands via asyncio subprocess with timeout enforcement,
|
||
|
|
interactive-command rejection, and structured output capture.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
import re
|
||
|
|
from typing import Any, ClassVar
|
||
|
|
|
||
|
|
from ..base import Tool, ToolParameter, ToolResult
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# Programs that expect a live terminal. A command whose base name appears in
# this set is always refused, because no terminal is attached to the
# subprocess that would run it.
INTERACTIVE_COMMANDS: frozenset[str] = frozenset(
    (
        # Privilege escalation and credential prompts.
        "passwd",
        "su",
        "sudo",
        # Text editors.
        "emacs",
        "nano",
        "nvim",
        "pico",
        "vi",
        "vim",
        # Pagers and manual viewer.
        "less",
        "man",
        "more",
        # Full-screen process monitors.
        "htop",
        "top",
        # Network clients that open an interactive session.
        "ftp",
        "nslookup",
        "ssh",
        "telnet",
    )
)
|
||
|
|
|
||
|
|
# Regex patterns for destructive commands.
# Each pattern is searched against the full command string. The patterns are
# written in lower case; validate_command() lower-cases the command before
# searching, which is what makes flags such as ``-R`` match.

# Option list of an ``rm`` invocation that contains a recursive flag somewhere
# (``-r``, ``-rf``, ``-fr``, ``-f -r``, ``--recursive`` ...). The lookahead
# finds the recursive flag in any position; the consuming group then eats
# every option token (including ``--``) up to the first operand.
_RM_RECURSIVE_OPTS = (
    r"(?=(?:-\S+\s+)*(?:-[a-z]*r[a-z]*|--recursive)\s)"
    r"(?:-\S+\s+)+"
)

DESTRUCTIVE_PATTERNS: tuple[re.Pattern[str], ...] = (
    # rm -rf /   (a bare "/" operand, optionally followed by more text or a
    # shell separator such as ";" or "&&")
    re.compile(r"\brm\s+" + _RM_RECURSIVE_OPTS + r"/(?=[\s;&|)]|$)"),
    # rm -rf /*
    re.compile(r"\brm\s+" + _RM_RECURSIVE_OPTS + r"/\*"),
    # NOTE(review): presumably protects a Node.js process the platform itself
    # relies on -- confirm with the platform owners.
    re.compile(r"\bpkill\s+node\b"),
    re.compile(r"\bkillall\s+node\b"),
    # Power state changes.
    re.compile(r"\bshutdown\b"),
    re.compile(r"\breboot\b"),
    re.compile(r"\binit\s+[06]\b"),
    # Filesystem creation and zero-filling via dd.
    re.compile(r"\bmkfs\b"),
    re.compile(r"\bdd\s+if=/dev/zero\b"),
    re.compile(r":\(\)\s*\{.*\}"),  # fork bomb
)
|
||
|
|
|
||
|
|
# Timeout bounds, both expressed in milliseconds.
MAX_TIMEOUT_MS = 10 * 60 * 1000  # ceiling for any requested timeout: 10 minutes
DEFAULT_TIMEOUT_MS = 2 * 60 * 1000  # applied when no timeout is supplied: 2 minutes
|
||
|
|
|
||
|
|
|
||
|
|
def _extract_base_command(command: str) -> str:
|
||
|
|
"""Extract the first token of a command string, handling pipes and chains."""
|
||
|
|
stripped = command.strip()
|
||
|
|
# Handle env vars prefix like VAR=val cmd
|
||
|
|
for token in stripped.split():
|
||
|
|
if "=" not in token:
|
||
|
|
return token.split("/")[-1] # basename
|
||
|
|
return stripped.split()[0] if stripped else ""
|
||
|
|
|
||
|
|
|
||
|
|
def validate_command(command: str) -> str | None:
    """Validate a command string for safety.

    Two checks are applied in order:

    1. The stripped, lower-cased command is searched against every regex in
       ``DESTRUCTIVE_PATTERNS``.
    2. The command is split into segments on shell separators and the base
       program of each segment is looked up in ``INTERACTIVE_COMMANDS``.

    The split is purely textual and does not understand quoting, so a
    separator inside a quoted string also starts a new segment. That can
    reject a harmless command but never lets an extra one through.

    Args:
        command: The raw command line to check.

    Returns:
        An error message if the command is rejected, or None if it passed
        both checks.
    """
    stripped = command.strip()
    if not stripped:
        return "Empty command"

    # Check for destructive patterns (regex search on the lower-cased text).
    lowered = stripped.lower()
    for pattern in DESTRUCTIVE_PATTERNS:
        if pattern.search(lowered):
            return f"Destructive command rejected: matched pattern '{pattern.pattern}'"

    # Check every command in a pipeline / chain for interactive programs.
    # Separators: "||", "&&", "|", ";", newline, and a lone "&" (background).
    # Two-character operators come first so they are consumed as a unit, and
    # the lookarounds keep ">&" / "&>" redirections (e.g. "2>&1") intact.
    segments = re.split(r"\|\||&&|[|;\n]|(?<!>)&(?!>)", stripped)
    for segment in segments:
        base = _extract_base_command(segment)
        if base in INTERACTIVE_COMMANDS:
            return f"Interactive command rejected: '{base}' requires terminal input"

    return None
|
||
|
|
|
||
|
|
|
||
|
|
class BashTool(Tool):
    """Execute bash commands with safety guardrails.

    Each command is first checked by ``validate_command`` and then run through
    the shell in a subprocess with a clamped timeout. The subprocess starts in
    ``PLATFORM_ROOT`` when that path is a directory; otherwise it inherits the
    current working directory. The working directory is only a starting point,
    not a sandbox.
    """

    name: ClassVar[str] = "bash"
    description: ClassVar[str] = (
        "Execute a bash command and return its output. "
        "Commands run in the platform root directory with timeout enforcement. "
        "Interactive commands (vim, sudo, etc.) are rejected."
    )
    parameters: ClassVar[list[ToolParameter]] = [
        ToolParameter(
            name="command",
            type="string",
            description="The bash command to execute",
        ),
        ToolParameter(
            name="description",
            type="string",
            description="Human-readable description of what the command does",
        ),
        ToolParameter(
            name="timeout",
            type="integer",
            description="Timeout in milliseconds (default 120000, max 600000)",
            required=False,
            default=DEFAULT_TIMEOUT_MS,
        ),
    ]

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Validate and run a shell command, returning its captured output.

        Keyword Args:
            command: The command line to run (required).
            description: Human-readable purpose of the command (required).
            timeout: Timeout in milliseconds. Missing or ``None`` means
                ``DEFAULT_TIMEOUT_MS``; the value is clamped to the range
                1000..``MAX_TIMEOUT_MS``.

        Returns:
            A ``ToolResult`` built as follows:

            * ``ToolResult.fail`` if validation rejects the command or an
              ``OSError`` is raised while running it.
            * ``ToolResult.timed_out`` if the command exceeds the timeout
              (the process is killed first).
            * A ``ToolResult`` with ``ToolResultStatus.ERROR`` on a non-zero
              exit code.
            * ``ToolResult.success`` otherwise.

            The output is stdout followed by stderr, cut to 100,000
            characters; the ``truncated`` metadata flag records whether a cut
            happened.

        Raises:
            KeyError: If ``command`` or ``description`` is not supplied.
        """
        command: str = kwargs["command"]
        description: str = kwargs["description"]

        # An explicit ``timeout=None`` must behave like an omitted timeout;
        # passing None into min() would raise TypeError.
        timeout_ms = kwargs.get("timeout")
        if timeout_ms is None:
            timeout_ms = DEFAULT_TIMEOUT_MS

        # Clamp timeout to [1 second, MAX_TIMEOUT_MS].
        timeout_ms = max(1000, min(timeout_ms, MAX_TIMEOUT_MS))
        timeout_seconds = timeout_ms / 1000.0

        # Safety validation
        error = validate_command(command)
        if error is not None:
            logger.warning("Command rejected: %s — %s", command, error)
            return ToolResult.fail(error, command=command, description=description)

        logger.info("Executing: %s (%s)", command, description)

        # Resolve working directory. Imported here rather than at module top,
        # as in the original code (NOTE(review): presumably to avoid an import
        # cycle with the package __init__ -- confirm).
        from . import PLATFORM_ROOT

        cwd = PLATFORM_ROOT if PLATFORM_ROOT.is_dir() else None

        try:
            process = await asyncio.create_subprocess_shell(
                command,
                # No terminal is available, so a command that reads stdin
                # sees EOF immediately instead of blocking until the timeout.
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=cwd,
            )

            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(
                    process.communicate(),
                    timeout=timeout_seconds,
                )
            except asyncio.TimeoutError:
                try:
                    process.kill()
                except ProcessLookupError:
                    # The process exited between the timeout firing and the
                    # kill. ProcessLookupError is an OSError, so without this
                    # guard the handler below would misreport the timeout as
                    # "Failed to start process".
                    pass
                await process.wait()
                logger.warning("Command timed out after %.1fs: %s", timeout_seconds, command)
                return ToolResult.timed_out(
                    timeout_seconds,
                    command=command,
                    description=description,
                )

        except OSError as exc:
            return ToolResult.fail(
                f"Failed to start process: {exc}",
                command=command,
                description=description,
            )

        # Invalid UTF-8 is replaced rather than raising, so binary output
        # cannot crash the tool.
        stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
        stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
        exit_code = process.returncode or 0

        # Combine output: stdout first, then stderr, skipping empty streams.
        output_parts: list[str] = []
        if stdout:
            output_parts.append(stdout)
        if stderr:
            output_parts.append(stderr)
        combined_output = "\n".join(output_parts).rstrip("\n")

        # Truncate very large outputs to avoid blowing up context
        max_output_chars = 100_000
        truncated = False
        if len(combined_output) > max_output_chars:
            combined_output = combined_output[:max_output_chars]
            truncated = True

        logger.info("Command exited %d: %s", exit_code, command)

        if exit_code != 0:
            from ..base import ToolResultStatus

            return ToolResult(
                status=ToolResultStatus.ERROR,
                output=combined_output,
                error=f"Command exited with code {exit_code}",
                metadata={
                    "command": command,
                    "description": description,
                    "exit_code": exit_code,
                    "truncated": truncated,
                },
            )

        return ToolResult.success(
            combined_output,
            command=command,
            description=description,
            exit_code=exit_code,
            truncated=truncated,
        )
|