# ml-knowledge-platform/knowledge_platform/tools/builtin/bash.py

"""BashTool — execute shell commands with safety guardrails.
Runs commands via asyncio subprocess with timeout enforcement,
interactive-command rejection, and structured output capture.
"""
from __future__ import annotations
import asyncio
import logging
import re
from typing import Any, ClassVar
from ..base import Tool, ToolParameter, ToolResult
logger = logging.getLogger(__name__)
# Programs that need interactive terminal input; a command whose base name
# appears in this set is always rejected.
INTERACTIVE_COMMANDS: frozenset[str] = frozenset(
    (
        # Privilege escalation and password handling
        "sudo",
        "su",
        "passwd",
        # Text editors
        "vim",
        "vi",
        "nvim",
        "nano",
        "emacs",
        "pico",
        # Remote sessions
        "ssh",
        "telnet",
        "ftp",
        # Full-screen process monitors
        "top",
        "htop",
        # Pagers and manual viewer
        "less",
        "more",
        "man",
        # Network lookup
        "nslookup",
    )
)
# Regex patterns for destructive commands.
# Each pattern is pre-compiled and meant to be searched against the full
# command string (not individual pipeline segments).
DESTRUCTIVE_PATTERNS: tuple[re.Pattern[str], ...] = (
    # Recursive rm aimed at "/" or "/*".
    #   * flags may appear in any order or be split up:
    #     -rf, -fr, -r -f, --recursive --force, -rf --
    #   * the target may be followed by more of the command line
    #     ("rm -rf / && ...", "rm -rf / --no-preserve-root")
    #   * deeper paths such as "/tmp/x" or "./" are NOT matched
    re.compile(
        r"\brm\s+(?:-\S+\s+)*"                    # leading flags, if any
        r"-(?:-recursive\b|[a-zA-Z]*[rR])\S*\s+"  # a flag that enables recursion
        r"(?:-\S+\s+)*"                           # trailing flags, if any
        r"/(?:[*\s;&|]|$)"                        # bare "/" or "/*" as the target
    ),
    re.compile(r"\bpkill\s+node\b"),
    re.compile(r"\bkillall\s+node\b"),
    re.compile(r"\bshutdown\b"),
    re.compile(r"\breboot\b"),
    re.compile(r"\binit\s+[06]\b"),
    re.compile(r"\bmkfs\b"),
    re.compile(r"\bdd\s+if=/dev/zero\b"),
    re.compile(r":\(\)\s*\{.*\}"),  # fork bomb
)
# Timeout limits for command execution, expressed in milliseconds.
MAX_TIMEOUT_MS = 10 * 60 * 1000  # upper bound: 10 minutes
DEFAULT_TIMEOUT_MS = 2 * 60 * 1000  # used when the caller gives no timeout: 2 minutes
def _extract_base_command(command: str) -> str:
"""Extract the first token of a command string, handling pipes and chains."""
stripped = command.strip()
# Handle env vars prefix like VAR=val cmd
for token in stripped.split():
if "=" not in token:
return token.split("/")[-1] # basename
return stripped.split()[0] if stripped else ""
def validate_command(command: str) -> str | None:
"""Validate a command string for safety.
Returns an error message if the command is rejected, or None if safe.
"""
stripped = command.strip()
if not stripped:
return "Empty command"
# Check for destructive patterns (regex match)
lowered = stripped.lower()
for pattern in DESTRUCTIVE_PATTERNS:
if pattern.search(lowered):
return f"Destructive command rejected: matched pattern '{pattern.pattern}'"
# Check each command in a pipeline / chain for interactive commands
# Split on pipe, semicolon, &&, ||
segments = re.split(r"\||\;|&&|\|\|", stripped)
for segment in segments:
base = _extract_base_command(segment)
if base in INTERACTIVE_COMMANDS:
return f"Interactive command rejected: '{base}' requires terminal input"
return None
class BashTool(Tool):
    """Execute bash commands with safety guardrails.

    Commands are first checked by ``validate_command`` (interactive and
    destructive commands are rejected), then run through a shell subprocess
    with a timeout. The working directory is ``PLATFORM_ROOT`` when that
    directory exists; otherwise the current process's working directory is
    inherited.
    """

    name: ClassVar[str] = "bash"
    description: ClassVar[str] = (
        "Execute a bash command and return its output. "
        "Commands run in the platform root directory with timeout enforcement. "
        "Interactive commands (vim, sudo, etc.) are rejected."
    )
    parameters: ClassVar[list[ToolParameter]] = [
        ToolParameter(
            name="command",
            type="string",
            description="The bash command to execute",
        ),
        ToolParameter(
            name="description",
            type="string",
            description="Human-readable description of what the command does",
        ),
        ToolParameter(
            name="timeout",
            type="integer",
            description="Timeout in milliseconds (default 120000, max 600000)",
            required=False,
            default=DEFAULT_TIMEOUT_MS,
        ),
    ]

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Validate and run a shell command, returning its captured output.

        Keyword Args:
            command: Shell command line to run (required).
            description: Human-readable summary of the command (required);
                it is logged and echoed back in the result.
            timeout: Timeout in milliseconds. Missing or ``None`` falls back
                to ``DEFAULT_TIMEOUT_MS``; the value is clamped to the range
                1000..``MAX_TIMEOUT_MS``.

        Returns:
            A ``ToolResult`` built by ``ToolResult.fail`` when validation
            rejects the command or the process cannot be started, by
            ``ToolResult.timed_out`` when the timeout expires, with status
            ``ToolResultStatus.ERROR`` when the exit code is non-zero, and by
            ``ToolResult.success`` otherwise.

        Raises:
            KeyError: If ``command`` or ``description`` is not supplied.
        """
        command: str = kwargs["command"]
        description: str = kwargs["description"]

        # An explicit ``timeout=None`` means "use the default", same as
        # omitting the argument; without this, min(None, ...) would raise.
        requested_timeout = kwargs.get("timeout")
        if requested_timeout is None:
            requested_timeout = DEFAULT_TIMEOUT_MS
        # Clamp timeout to [1 second, MAX_TIMEOUT_MS].
        timeout_ms: int = max(1000, min(requested_timeout, MAX_TIMEOUT_MS))
        timeout_seconds = timeout_ms / 1000.0

        # Safety validation
        error = validate_command(command)
        if error is not None:
            logger.warning("Command rejected: %s (%s)", command, error)
            return ToolResult.fail(error, command=command, description=description)

        logger.info("Executing: %s (%s)", command, description)

        # Resolve working directory (imported lazily, as in the original code).
        from . import PLATFORM_ROOT

        cwd = PLATFORM_ROOT if PLATFORM_ROOT.is_dir() else None

        try:
            process = await asyncio.create_subprocess_shell(
                command,
                # No input can be supplied through this tool, so do not let
                # the child read from (or block on) the parent's stdin.
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=cwd,
            )
            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(
                    process.communicate(),
                    timeout=timeout_seconds,
                )
            except asyncio.TimeoutError:
                try:
                    process.kill()
                except ProcessLookupError:
                    # The process exited on its own between the timeout
                    # firing and the kill; nothing left to terminate.
                    pass
                await process.wait()
                logger.warning("Command timed out after %.1fs: %s", timeout_seconds, command)
                return ToolResult.timed_out(
                    timeout_seconds,
                    command=command,
                    description=description,
                )
        except OSError as exc:
            return ToolResult.fail(
                f"Failed to start process: {exc}",
                command=command,
                description=description,
            )

        # Undecodable bytes are replaced rather than raising.
        stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
        stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
        # A missing return code is reported as 0.
        exit_code = process.returncode or 0

        # Combine output: stdout first, then stderr, skipping empty streams.
        combined_output = "\n".join(part for part in (stdout, stderr) if part).rstrip("\n")

        # Truncate very large outputs to avoid blowing up context
        max_output_chars = 100_000
        truncated = len(combined_output) > max_output_chars
        if truncated:
            combined_output = combined_output[:max_output_chars]

        logger.info("Command exited %d: %s", exit_code, command)

        if exit_code != 0:
            from ..base import ToolResultStatus

            return ToolResult(
                status=ToolResultStatus.ERROR,
                output=combined_output,
                error=f"Command exited with code {exit_code}",
                metadata={
                    "command": command,
                    "description": description,
                    "exit_code": exit_code,
                    "truncated": truncated,
                },
            )

        return ToolResult.success(
            combined_output,
            command=command,
            description=description,
            exit_code=exit_code,
            truncated=truncated,
        )