diff --git a/src/strands/vended_tools/__init__.py b/src/strands/vended_tools/__init__.py new file mode 100644 index 000000000..4778efbff --- /dev/null +++ b/src/strands/vended_tools/__init__.py @@ -0,0 +1,39 @@ +"""Vended tools for Strands agents. + +These are production-ready tools that ship with the SDK and integrate +with the :class:`~strands.sandbox.base.Sandbox` abstraction. They work +transparently whether the agent uses a local :class:`~strands.sandbox.host.HostSandbox` +or a remote sandbox implementation. + +Each tool reads its configuration from ``tool_context.agent.state`` using +a namespaced key (e.g., ``strands_shell_tool``). This means configuration +persists across tool calls and survives session serialization. + +Available tools: + +- :func:`~strands.vended_tools.shell.shell` — Execute shell commands +- :func:`~strands.vended_tools.editor.editor` — View, create, and edit files +- :func:`~strands.vended_tools.python_repl.python_repl` — Execute Python code + +Example:: + + from strands import Agent + from strands.vended_tools import shell, editor, python_repl + + agent = Agent(tools=[shell, editor, python_repl]) + + # Configure tools via agent state (persists across calls) + agent.state.set("strands_shell_tool", { + "timeout": 60, + }) +""" + +from .editor import editor +from .python_repl import python_repl +from .shell import shell + +__all__ = [ + "editor", + "python_repl", + "shell", +] diff --git a/src/strands/vended_tools/_utils.py b/src/strands/vended_tools/_utils.py new file mode 100644 index 000000000..e5e03c577 --- /dev/null +++ b/src/strands/vended_tools/_utils.py @@ -0,0 +1,25 @@ +"""Shared utilities for vended tools. + +Provides common helper functions used across all vended tools to avoid +code duplication. +""" + +from typing import Any + +from ..types.tools import ToolContext + + +def get_tool_config(tool_context: ToolContext, state_key: str) -> dict[str, Any]: + """Read tool configuration from agent state. 
+ + All vended tools store their configuration in agent state under + a namespaced key. This helper standardizes the pattern. + + Args: + tool_context: The tool context providing access to agent state. + state_key: The state key for the tool's configuration. + + Returns: + Configuration dict. Empty dict if no config is set. + """ + return tool_context.agent.state.get(state_key) or {} diff --git a/src/strands/vended_tools/editor.py b/src/strands/vended_tools/editor.py new file mode 100644 index 000000000..b24aef1e2 --- /dev/null +++ b/src/strands/vended_tools/editor.py @@ -0,0 +1,416 @@ +"""File editor tool implementation. + +Provides view, create, str_replace, insert, and undo_edit operations on files +in the agent's sandbox. The tool delegates all file I/O to the sandbox's +``read_text``, ``write_text``, and ``list_files`` methods. + +The tool shape matches Anthropic's ``text_editor`` built-in tool — 5 commands, +7 parameters. This means models trained on Anthropic's tool spec will work +well with this tool out of the box. + +Configuration keys (set via ``agent.state.set("strands_editor_tool", {...})``): + +- ``max_file_size`` (int): Maximum file size in bytes for read operations. + Default: 1048576 (1 MB). +- ``require_absolute_paths`` (bool): When True, rejects relative paths and + paths containing ``..``. When False (the default), paths are passed through + to the sandbox without filesystem-level validation — the sandbox decides + what a path means. Default: False. 
+ +Example:: + + from strands import Agent + from strands.vended_tools import editor + + agent = Agent(tools=[editor]) + agent("View the contents of /tmp/example.py") + + # Configure max file size + agent.state.set("strands_editor_tool", {"max_file_size": 2097152}) +""" + +import logging +from typing import Any, Literal + +from ..sandbox.base import Sandbox +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config + +logger = logging.getLogger(__name__) + +#: State key for editor tool configuration in agent.state +STATE_KEY = "strands_editor_tool" + +#: State key for undo history (internal) +_UNDO_STATE_KEY = "_strands_editor_undo" + +#: Default maximum file size (1 MB) +DEFAULT_MAX_FILE_SIZE = 1_048_576 + +#: Number of context lines to show around edits +SNIPPET_LINES = 4 + +#: Maximum directory listing depth +MAX_DIRECTORY_DEPTH = 2 + + +def _make_output(content: str, descriptor: str, init_line: int = 1) -> str: + """Format file content with line numbers (cat -n style). + + Args: + content: The file content to format. + descriptor: Description of what is being shown (e.g., file path). + init_line: Starting line number. + + Returns: + Formatted output with line numbers. + """ + # Expand tabs to spaces for display only + display_content = content.replace("\t", " ") + lines = display_content.split("\n") + numbered = [] + for i, line in enumerate(lines): + line_num = i + init_line + numbered.append(f"{line_num:>6} {line}") + return f"Here's the result of running `cat -n` on {descriptor}:\n" + "\n".join(numbered) + "\n" + + +def _save_undo(tool_context: ToolContext, path: str, content: str) -> None: + """Save file content for undo. + + Args: + tool_context: The tool context providing access to agent state. + path: The file path. + content: The file content before modification. 
+ """ + undo_state = tool_context.agent.state.get(_UNDO_STATE_KEY) or {} + undo_state[path] = content + tool_context.agent.state.set(_UNDO_STATE_KEY, undo_state) + + +def _get_undo(tool_context: ToolContext, path: str) -> str | None: + """Get saved undo content for a file. + + Args: + tool_context: The tool context providing access to agent state. + path: The file path. + + Returns: + The saved content, or None if no undo is available. + """ + undo_state = tool_context.agent.state.get(_UNDO_STATE_KEY) or {} + return undo_state.get(path) + + +@tool(context=True) +async def editor( + command: Literal["view", "create", "str_replace", "insert", "undo_edit"], + path: str, + file_text: str | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + view_range: list[int] | None = None, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> str: + """View, create, and edit files in the agent's sandbox. + + Commands: + + - **view**: Display file contents with line numbers, or list directory contents. + Use ``view_range`` as ``[start_line, end_line]`` (1-indexed, -1 for end of file) + to view a specific range. + - **create**: Create a new file with ``file_text`` content. Fails if file exists. + - **str_replace**: Replace ``old_str`` with ``new_str`` in the file. + ``old_str`` must match exactly once in the file (uniqueness enforced). + - **insert**: Insert ``new_str`` at ``insert_line`` (0-indexed line number). + - **undo_edit**: Revert the last edit to the file at ``path``. + + File operations go through the agent's sandbox. By default, paths are passed + through to the sandbox as-is — the sandbox decides what a path means. Set + ``require_absolute_paths: true`` in ``strands_editor_tool`` config to enforce + absolute paths and block directory traversal. + + Configuration is read from ``agent.state.get("strands_editor_tool")``: + + - ``max_file_size``: Maximum file size in bytes (default: 1 MB). 
+ - ``require_absolute_paths``: Reject relative paths and ``..`` (default: False). + + Args: + command: The operation to perform. + path: Path to the file or directory. + file_text: Content for new file (required for ``create``). + old_str: String to find and replace (required for ``str_replace``). + Must appear exactly once in the file. + new_str: Replacement string for ``str_replace``, or text to insert for ``insert``. + insert_line: Line number for insertion (0-indexed, required for ``insert``). + view_range: Line range for view as ``[start, end]``. 1-indexed. + Use -1 for end to mean end of file. + tool_context: Framework-injected tool context. + + Returns: + Result of the operation — file contents, success message, or error. + """ + config = get_tool_config(tool_context, STATE_KEY) + sandbox: Sandbox = tool_context.agent.sandbox + + # Path validation is opt-in. By default, paths are passed straight through + # to the sandbox without filesystem-level validation. This allows sandboxes + # like S3Sandbox to use relative keys (e.g., "hello.txt") as paths. + if config.get("require_absolute_paths"): + import os + + if not os.path.isabs(path): + suggested = os.path.abspath(path) + return f"Error: The path {path} is not an absolute path. Maybe you meant {suggested}?" + if ".." in path: + return "Error: Path traversal (..) is not allowed." 
+ + try: + if command == "view": + return await _handle_view(sandbox, config, path, view_range) + elif command == "create": + if file_text is None: + return "Error: Parameter `file_text` is required for command: create" + return await _handle_create(sandbox, tool_context, path, file_text) + elif command == "str_replace": + if old_str is None: + return "Error: Parameter `old_str` is required for command: str_replace" + return await _handle_str_replace(sandbox, tool_context, config, path, old_str, new_str or "") + elif command == "insert": + if insert_line is None: + return "Error: Parameter `insert_line` is required for command: insert" + if new_str is None: + return "Error: Parameter `new_str` is required for command: insert" + # Coerce insert_line to int (LLMs may send floats) + try: + insert_line = int(insert_line) + except (TypeError, ValueError): + return f"Error: `insert_line` must be an integer, got: {type(insert_line).__name__}" + return await _handle_insert(sandbox, tool_context, config, path, insert_line, new_str) + elif command == "undo_edit": + return await _handle_undo(sandbox, tool_context, path) + + return f"Error: Unknown command: {command}" # type: ignore[unreachable] + except NotImplementedError as e: + return f"Error: Sandbox does not support this operation — {e}" + except (FileNotFoundError, UnicodeDecodeError, OSError, ValueError) as e: + return f"Error: {e}" + + +async def _handle_view(sandbox: Sandbox, config: dict[str, Any], path: str, view_range: list[int] | None) -> str: + """Handle the view command.""" + # Check if path is a directory + try: + entries = await sandbox.list_files(path) + # It's a directory + if view_range: + return "Error: The `view_range` parameter is not allowed when `path` points to a directory." 
+ items = sorted(f"{e.name}/" if e.is_dir else e.name for e in entries if e.name not in (".", "..")) + return ( + f"Here's the files and directories up to 2 levels deep in {path}, " + f"excluding hidden items:\n" + "\n".join(items) + "\n" + ) + except (FileNotFoundError, OSError): + pass # Not a directory, try as file + + # Read file + max_size = config.get("max_file_size", DEFAULT_MAX_FILE_SIZE) + try: + content = await sandbox.read_text(path) + except FileNotFoundError: + return f"Error: The path {path} does not exist. Please provide a valid path." + except UnicodeDecodeError: + return f"Error: The file {path} is not a text file (cannot decode as UTF-8)." + + # Check size + if len(content.encode("utf-8")) > max_size: + return f"Error: File size exceeds maximum allowed size ({max_size} bytes)." + + if view_range is None: + return _make_output(content, path) + + # Validate and apply view range + lines = content.split("\n") + n_lines = len(lines) + + if len(view_range) != 2: + return "Error: `view_range` must be a list of two integers [start, end]." + + # Coerce view_range elements to int (LLMs may send floats like [1.0, 3.0]) + try: + start = int(view_range[0]) + end = int(view_range[1]) + except (TypeError, ValueError): + return "Error: `view_range` elements must be integers." + + if start < 1 or start > n_lines: + return ( + f"Error: Invalid `view_range`: [{start}, {end}]. First element `{start}` should be within [1, {n_lines}]." + ) + if end != -1 and end > n_lines: + return f"Error: Invalid `view_range`: [{start}, {end}]. Second element `{end}` should be <= {n_lines}." + if end != -1 and end < start: + return f"Error: Invalid `view_range`: [{start}, {end}]. Second element must be >= first element." 
+ + if end == -1: + selected = lines[start - 1:] + else: + selected = lines[start - 1:end] + + return _make_output("\n".join(selected), path, init_line=start) + + +async def _handle_create(sandbox: Sandbox, tool_context: ToolContext, path: str, file_text: str) -> str: + """Handle the create command.""" + # Check if file already exists + try: + await sandbox.read_file(path) + return f"Error: File already exists at: {path}. Cannot overwrite with `create`. Use `str_replace` to edit." + except (FileNotFoundError, OSError): + pass # File doesn't exist, good + + await sandbox.write_text(path, file_text) + return f"File created successfully at: {path}" + + +async def _handle_str_replace( + sandbox: Sandbox, + tool_context: ToolContext, + config: dict[str, Any], + path: str, + old_str: str, + new_str: str, +) -> str: + """Handle the str_replace command.""" + try: + content = await sandbox.read_text(path) + except FileNotFoundError: + return f"Error: The path {path} does not exist." + + # Match against original content — do NOT expand tabs for matching. + # Tab expansion is only used for display output (_make_output). + # This preserves tab characters in the file and prevents false matches + # between tabs and their space equivalents. + count = content.count(old_str) + + if count == 0: + return f"Error: No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." + + if count > 1: + # Find line numbers of all occurrences + lines = content.split("\n") + line_nums: list[int] = [] + for i, line in enumerate(lines): + if old_str in line: + line_nums.append(i + 1) + # Also check multi-line matches + if not line_nums: + idx = 0 + while True: + idx = content.find(old_str, idx) + if idx == -1: + break + line_num = content[:idx].count("\n") + 1 + line_nums.append(line_num) + idx += 1 + return ( + f"Error: No replacement was performed. Multiple occurrences ({count}) of old_str " + f"in lines {line_nums}. Please ensure old_str is unique." 
+ ) + + # Save undo state (original content before any modification) + _save_undo(tool_context, path, content) + + # Perform replacement on original content (preserving tabs) + new_content = content.replace(old_str, new_str, 1) + + # Write back (preserving original formatting including tabs) + await sandbox.write_text(path, new_content) + + # Generate snippet around the change (expand tabs for display only) + replace_idx = content.find(old_str) + replace_line = content[:replace_idx].count("\n") + inserted_lines = new_str.count("\n") + 1 + original_lines = old_str.count("\n") + 1 + line_diff = inserted_lines - original_lines + + new_lines = new_content.split("\n") + start = max(0, replace_line - SNIPPET_LINES) + end = min(len(new_lines), replace_line + SNIPPET_LINES + line_diff + 1) + snippet = "\n".join(new_lines[start:end]) + + return ( + f"The file {path} has been edited. " + + _make_output(snippet, f"a snippet of {path}", init_line=start + 1) + + "Review the changes and make sure they are as expected. Edit the file again if necessary." + ) + + +async def _handle_insert( + sandbox: Sandbox, + tool_context: ToolContext, + config: dict[str, Any], + path: str, + insert_line: int, + new_str: str, +) -> str: + """Handle the insert command.""" + try: + content = await sandbox.read_text(path) + except FileNotFoundError: + return f"Error: The path {path} does not exist." + + # Work with original content — do NOT expand tabs. + # Tab expansion is only for display output. + lines = content.split("\n") + n_lines = len(lines) + + if insert_line < 0 or insert_line > n_lines: + return f"Error: Invalid `insert_line`: {insert_line}. Should be within [0, {n_lines}]." 
+ + # Save undo state (original content) + _save_undo(tool_context, path, content) + + # Insert (preserve original content including tabs) + new_str_lines = new_str.split("\n") + if content == "": + new_lines = new_str_lines + else: + new_lines = lines[:insert_line] + new_str_lines + lines[insert_line:] + + new_content = "\n".join(new_lines) + await sandbox.write_text(path, new_content) + + # Generate snippet (expand tabs for display only) + start = max(0, insert_line - SNIPPET_LINES) + end = min(len(new_lines), insert_line + len(new_str_lines) + SNIPPET_LINES) + snippet = "\n".join(new_lines[start:end]) + + return ( + f"The file {path} has been edited. " + + _make_output(snippet, "a snippet of the edited file", init_line=start + 1) + + "Review the changes and make sure they are as expected. Edit the file again if necessary." + ) + + +async def _handle_undo(sandbox: Sandbox, tool_context: ToolContext, path: str) -> str: + """Handle the undo_edit command.""" + previous_content = _get_undo(tool_context, path) + if previous_content is None: + return f"Error: No edit history found for {path}." + + # Read current content for future undo + try: + current = await sandbox.read_text(path) + except FileNotFoundError: + current = "" + + # Write the previous content back + await sandbox.write_text(path, previous_content) + + # Save current as new undo (so undo is toggleable) + _save_undo(tool_context, path, current) + + return f"Successfully reverted last edit to {path}." diff --git a/src/strands/vended_tools/python_repl.py b/src/strands/vended_tools/python_repl.py new file mode 100644 index 000000000..3a86b6e83 --- /dev/null +++ b/src/strands/vended_tools/python_repl.py @@ -0,0 +1,151 @@ +"""Python REPL tool implementation with streaming support. + +Executes Python code in the agent's sandbox using +``sandbox.execute_code_streaming(code, language="python")``. 
Each chunk of +stdout/stderr is yielded as a ``ToolStreamEvent`` in real time, allowing UI +consumers to display live output from code execution. + +The tool is an **async generator**: ``StreamChunk`` objects from the sandbox +are yielded during execution, and the final yield is the formatted result +string that becomes the ``ToolResult``. + +Configuration keys (set via ``agent.state.set("strands_python_repl_tool", {...})``): + +- ``timeout`` (int): Default timeout in seconds for code execution. + Overridden by the per-call ``timeout`` parameter. Default: 30. + +Example:: + + from strands import Agent + from strands.vended_tools import python_repl + + agent = Agent(tools=[python_repl]) + agent("Calculate the first 10 Fibonacci numbers") + + # Configure timeout + agent.state.set("strands_python_repl_tool", {"timeout": 60}) +""" + +import asyncio +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from ..sandbox.base import ExecutionResult, StreamChunk +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config + +logger = logging.getLogger(__name__) + +#: State key for python_repl tool configuration in agent.state +STATE_KEY = "strands_python_repl_tool" + +#: Default timeout for code execution (seconds) +DEFAULT_TIMEOUT = 30 + + +@tool(context=True) +async def python_repl( + code: str, + timeout: int | None = None, + reset: bool = False, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> AsyncGenerator[Any, None]: + """Execute Python code in the agent's sandbox with live output streaming. + + Code is executed via the agent's sandbox using + ``sandbox.execute_code_streaming(code, language="python")``. Each chunk + of stdout/stderr is yielded as a streaming event that UI consumers can + display in real time. The final yield is the formatted result string. 
+ + Use ``reset=True`` to clear any sandbox-level state (e.g., restart the + interpreter session if the sandbox supports it). + + Configuration is read from ``agent.state.get("strands_python_repl_tool")``: + + - ``timeout``: Default timeout in seconds (overridden by per-call timeout). + + Args: + code: The Python code to execute. + timeout: Maximum execution time in seconds. Uses config default or 30s. + reset: If True, signal the sandbox to reset execution state. + tool_context: Framework-injected tool context. + + Yields: + :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as + ``ToolStreamEvent`` by the SDK), then a final string result. + """ + config = get_tool_config(tool_context, STATE_KEY) + sandbox = tool_context.agent.sandbox + + # Handle reset + if reset: + tool_context.agent.state.delete("_strands_python_repl_state") + if not code or not code.strip(): + yield "Python REPL state reset." + return + + # Resolve timeout: per-call > config > default + effective_timeout: int | None = timeout + if effective_timeout is None: + effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + + # Coerce timeout to int — JSON configs and LLMs may pass strings or floats + if effective_timeout is not None: + try: + effective_timeout = int(effective_timeout) + except (TypeError, ValueError): + effective_timeout = DEFAULT_TIMEOUT + + # Execute via sandbox streaming + result: ExecutionResult | None = None + try: + async for chunk in sandbox.execute_code_streaming( + code, + language="python", + timeout=effective_timeout, + ): + if isinstance(chunk, StreamChunk): + # Yield each chunk — the decorator wraps it as ToolStreamEvent + yield chunk + elif isinstance(chunk, ExecutionResult): + result = chunk + except asyncio.TimeoutError: + yield f"Error: Code execution timed out after {effective_timeout} seconds." + return + except NotImplementedError: + yield "Error: Sandbox does not support code execution (NoOpSandbox)." 
+ return + except OSError as e: + yield f"Error: {e}" + return + + if result is None: + yield "Error: Sandbox did not return an execution result." + return + + # Format output + output_parts = [] + if result.stdout: + output_parts.append(result.stdout) + if result.stderr: + output_parts.append(result.stderr) + + output = "\n".join(output_parts).rstrip() + + if result.exit_code != 0: + if output: + output += f"\n\nExit code: {result.exit_code}" + else: + output = f"Code execution failed with exit code: {result.exit_code}" + + # Handle output files (images, charts, etc.) + if result.output_files: + file_names = [f.name for f in result.output_files] + if output: + output += f"\n\nGenerated files: {', '.join(file_names)}" + else: + output = f"Generated files: {', '.join(file_names)}" + + yield output if output else "(no output)" diff --git a/src/strands/vended_tools/shell.py b/src/strands/vended_tools/shell.py new file mode 100644 index 000000000..a7d4efa99 --- /dev/null +++ b/src/strands/vended_tools/shell.py @@ -0,0 +1,226 @@ +"""Shell tool implementation with streaming support. + +Executes shell commands in the agent's sandbox with persistent state tracking. +The tool uses ``sandbox.execute_streaming()`` so that stdout/stderr chunks are +yielded as ``ToolStreamEvent``s in real time. This allows UI consumers to display +live output from sandbox execution. + +The tool is an **async generator**: each ``StreamChunk`` from the sandbox is +yielded directly (the SDK decorator wraps it in a ``ToolStreamEvent``), and the +final yield is the formatted result string (which becomes the ``ToolResult``). + +Configuration keys (set via ``agent.state.set("strands_shell_tool", {...})``): + +- ``timeout`` (int): Default timeout in seconds. Overridden by the per-call + ``timeout`` parameter. Default: 120. 
+ +Example:: + + from strands import Agent + from strands.vended_tools import shell + + agent = Agent(tools=[shell]) + agent("List all Python files in the current directory") + + # Configure timeout + agent.state.set("strands_shell_tool", {"timeout": 60}) +""" + +import asyncio +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from ..sandbox.base import ExecutionResult, StreamChunk +from ..tools.decorator import tool +from ..types.tools import ToolContext +from ._utils import get_tool_config + +logger = logging.getLogger(__name__) + +#: State key for shell tool configuration in agent.state +STATE_KEY = "strands_shell_tool" + +#: Default timeout for shell commands (seconds) +DEFAULT_TIMEOUT = 120 + +#: Internal marker used to separate user output from cwd tracking. +#: Must be unique enough to never appear in legitimate command output. +_CWD_MARKER = "__STRANDS_CWD__" + + +def _safe_yield_length(buffer: str, marker: str) -> int: + """Return the number of chars from the start of buffer that are safe to yield. + + "Safe" means: no suffix of the yielded portion could be a prefix of the marker. + This prevents partial marker leakage when the marker is split across chunks. + """ + # Check if the buffer contains the full marker — if so, only yield up to it + marker_pos = buffer.find(marker) + if marker_pos != -1: + return marker_pos + + # Check if the end of the buffer matches a prefix of the marker. 
+ # e.g., buffer ends with "__STRAN" which is a prefix of "__STRANDS_CWD__" + max_overlap = min(len(marker) - 1, len(buffer)) + for i in range(max_overlap, 0, -1): + if buffer.endswith(marker[:i]): + return len(buffer) - i + + # No overlap — everything is safe + return len(buffer) + + +@tool(context=True) +async def shell( + command: str, + timeout: int | None = None, + restart: bool = False, + tool_context: ToolContext = None, # type: ignore[assignment] +) -> AsyncGenerator[Any, None]: + """Execute a shell command in the agent's sandbox with live output streaming. + + The sandbox preserves working directory and environment variables across + calls when using a persistent sandbox implementation. Use ``restart=True`` + to reset the shell state. + + Commands are executed via the agent's sandbox + (``sandbox.execute_streaming()``). Each chunk of stdout/stderr is yielded + as a streaming event that UI consumers can display in real time. The final + yield is the formatted result string. + + Configuration is read from ``agent.state.get("strands_shell_tool")``: + + - ``timeout``: Default timeout in seconds (overridden by per-call timeout). + + Args: + command: The shell command to execute. + timeout: Maximum execution time in seconds. Uses config default or 120s. + restart: If True, reset shell state by clearing tracked working directory. + tool_context: Framework-injected tool context. + + Yields: + :class:`~strands.sandbox.base.StreamChunk` objects during execution (wrapped as + ``ToolStreamEvent`` by the SDK), then a final string result. + """ + config = get_tool_config(tool_context, STATE_KEY) + sandbox = tool_context.agent.sandbox + + # Handle restart + if restart: + _clear_shell_state(tool_context) + if not command or not command.strip(): + yield "Shell state reset." 
+ return + + # Resolve timeout: per-call > config > default + effective_timeout: int | None = timeout + if effective_timeout is None: + effective_timeout = config.get("timeout", DEFAULT_TIMEOUT) + + # Coerce timeout to int — JSON configs and LLMs may pass strings or floats + if effective_timeout is not None: + try: + effective_timeout = int(effective_timeout) + except (TypeError, ValueError): + effective_timeout = DEFAULT_TIMEOUT + + # Get tracked working directory from state (for session continuity) + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + cwd = shell_state.get("cwd") + + # Append cwd tracking to the command. We use a unique marker so we can + # reliably split the actual output from the cwd line. This captures the + # final working directory even after `cd` commands (which only affect + # the shell process they run in — a separate pwd call would not see them). + tracked_command = f"{command}; echo {_CWD_MARKER}; pwd" + + # Streaming with marker-aware buffering: + # We accumulate stdout in a buffer and yield only the "safe" prefix — i.e., + # bytes that cannot possibly be part of the CWD marker. This preserves + # real-time streaming for all output while ensuring no partial or full + # marker ever leaks to UI consumers. stderr is always yielded immediately. 
+ stdout_buffer = "" + result: ExecutionResult | None = None + + try: + async for chunk in sandbox.execute_streaming( + tracked_command, + timeout=effective_timeout, + cwd=cwd, + ): + if isinstance(chunk, StreamChunk): + if chunk.stream_type == "stderr": + # stderr never contains the marker — yield immediately + yield chunk + else: + # Append to buffer, yield whatever is safely past the marker + stdout_buffer += chunk.data + safe_len = _safe_yield_length(stdout_buffer, _CWD_MARKER) + if safe_len > 0: + yield StreamChunk( + data=stdout_buffer[:safe_len], stream_type="stdout" + ) + stdout_buffer = stdout_buffer[safe_len:] + elif isinstance(chunk, ExecutionResult): + result = chunk + except asyncio.TimeoutError: + yield f"Error: Command timed out after {effective_timeout} seconds." + return + except NotImplementedError: + yield "Error: Sandbox does not support command execution (NoOpSandbox)." + return + except OSError as e: + yield f"Error: {e}" + return + + if result is None: + yield "Error: Sandbox did not return an execution result." + return + + # Extract cwd from the full stdout using rsplit — this splits on the LAST + # occurrence of the marker, which is always the one we appended. This + # prevents corruption if user commands happen to output the marker string. + stdout = result.stdout or "" + if _CWD_MARKER in stdout: + parts = stdout.rsplit(_CWD_MARKER, 1) + # Actual command output is before the LAST marker + stdout = parts[0].rstrip("\n") + # The cwd is the line after the last marker + new_cwd = parts[1].strip() + if new_cwd: + shell_state["cwd"] = new_cwd + tool_context.agent.state.set("_strands_shell_state", shell_state) + + # Flush remaining buffer with marker stripped. + # The buffer holds the tail that overlapped with the marker prefix. 
+ if stdout_buffer: + cleaned = stdout_buffer.rsplit(_CWD_MARKER, 1)[0].rstrip("\n") + if cleaned: + yield StreamChunk(data=cleaned, stream_type="stdout") + + # Format final output (becomes the ToolResult) + output_parts = [] + if stdout: + output_parts.append(stdout) + if result.stderr: + output_parts.append(result.stderr) + + output = "\n".join(output_parts).rstrip() + + if result.exit_code != 0: + if output: + output += f"\n\nExit code: {result.exit_code}" + else: + output = f"Command failed with exit code: {result.exit_code}" + + yield output if output else "(no output)" + + +def _clear_shell_state(tool_context: ToolContext) -> None: + """Clear tracked shell state from agent state. + + Args: + tool_context: The tool context providing access to agent state. + """ + tool_context.agent.state.delete("_strands_shell_state") diff --git a/tests/strands/vended_tools/__init__.py b/tests/strands/vended_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_tools/conftest.py b/tests/strands/vended_tools/conftest.py new file mode 100644 index 000000000..d4da4ced6 --- /dev/null +++ b/tests/strands/vended_tools/conftest.py @@ -0,0 +1,69 @@ +"""Shared fixtures for vended tools tests.""" + +import uuid +from unittest.mock import MagicMock + +import pytest + +from strands.agent.state import AgentState +from strands.sandbox.base import StreamChunk +from strands.sandbox.host import HostSandbox +from strands.types.tools import ToolContext, ToolUse + + +@pytest.fixture +def sandbox(tmp_path): + """Create a HostSandbox for testing.""" + return HostSandbox(working_dir=str(tmp_path)) + + +@pytest.fixture +def agent_state(): + """Create a fresh AgentState.""" + return AgentState() + + +@pytest.fixture +def mock_agent(sandbox, agent_state): + """Create a mock agent with sandbox and state.""" + agent = MagicMock() + agent.sandbox = sandbox + agent.state = agent_state + return agent + + +@pytest.fixture +def tool_use(): + """Create a mock tool 
use.""" + return ToolUse( + toolUseId=str(uuid.uuid4()), + name="test_tool", + input={}, + ) + + +@pytest.fixture +def tool_context(mock_agent, tool_use): + """Create a ToolContext for testing.""" + return ToolContext( + tool_use=tool_use, + agent=mock_agent, + invocation_state={}, + ) + + +async def collect_generator(gen): + """Collect all values from an async generator. + + Returns (stream_chunks, final_result) where stream_chunks are all + StreamChunk objects yielded, and final_result is the last non-StreamChunk + value (the formatted result string). + """ + chunks = [] + final = None + async for item in gen: + if isinstance(item, StreamChunk): + chunks.append(item) + else: + final = item + return chunks, final diff --git a/tests/strands/vended_tools/test_adversarial_fixes.py b/tests/strands/vended_tools/test_adversarial_fixes.py new file mode 100644 index 000000000..aa87d3a19 --- /dev/null +++ b/tests/strands/vended_tools/test_adversarial_fixes.py @@ -0,0 +1,243 @@ +"""Adversarial tests validating the bug fixes. + +These tests verify that all 6 findings from the adversarial testing report +are now resolved. 
+""" + +import pytest + +from strands.sandbox.base import ExecutionResult, StreamChunk +from strands.vended_tools.shell import shell, _CWD_MARKER +from strands.vended_tools.editor import editor, _UNDO_STATE_KEY, STATE_KEY + +from .conftest import collect_generator + + +class TestShellCwdMarkerInjectionFixed: + """Verify Fix #1: CWD marker injection no longer corrupts state.""" + + @pytest.mark.asyncio + async def test_user_output_marker_does_not_corrupt_cwd(self, tool_context, tmp_path): + """User echoing the marker string should not affect tracked cwd.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command=f"echo '{_CWD_MARKER}' && echo '/evil/path'", + tool_context=tool_context, + ) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + tracked_cwd = shell_state.get("cwd", "") + + # CWD should NOT be /evil/path or contain the marker + assert tracked_cwd != "/evil/path" + assert _CWD_MARKER not in tracked_cwd + # It SHOULD be a real, valid single-line path + assert "\n" not in tracked_cwd + assert tracked_cwd != "" + + @pytest.mark.asyncio + async def test_multiple_marker_outputs_still_track_correctly(self, tool_context, tmp_path): + """Multiple user-output markers should not corrupt state.""" + chunks, result = await collect_generator( + shell.__wrapped__( + command=f"echo '{_CWD_MARKER}' && echo '{_CWD_MARKER}' && echo fake_path", + tool_context=tool_context, + ) + ) + + shell_state = tool_context.agent.state.get("_strands_shell_state") or {} + tracked_cwd = shell_state.get("cwd", "") + + assert tracked_cwd != "fake_path" + assert _CWD_MARKER not in tracked_cwd + assert "\n" not in tracked_cwd + assert tracked_cwd != "" + + @pytest.mark.asyncio + async def test_subsequent_calls_work_after_marker_injection(self, tool_context, tmp_path): + """After a marker injection attempt, next shell call should still work.""" + # First call: inject marker + await collect_generator( + shell.__wrapped__( + command=f"echo 
'{_CWD_MARKER}' && echo '/evil/path'", + tool_context=tool_context, + ) + ) + + # Second call: should work normally (no Permission denied) + chunks, result = await collect_generator( + shell.__wrapped__(command="echo follow_up", tool_context=tool_context) + ) + assert "follow_up" in result + assert "error" not in result.lower() + + +class TestShellStringTimeoutFixed: + """Verify Fix #3: String timeout no longer crashes.""" + + @pytest.mark.asyncio + async def test_string_timeout_coerced_to_int(self, tool_context, mock_agent): + """String timeout "30" should be coerced to int 30.""" + mock_agent.state.set("strands_shell_tool", {"timeout": "30"}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + # Should not crash — timeout is coerced + assert "test" in result + assert "error" not in result.lower() + + @pytest.mark.asyncio + async def test_float_timeout_coerced(self, tool_context, mock_agent): + """Float timeout 30.5 should be coerced to int 30.""" + mock_agent.state.set("strands_shell_tool", {"timeout": 30.5}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + assert "test" in result + + @pytest.mark.asyncio + async def test_invalid_timeout_uses_default(self, tool_context, mock_agent): + """Non-numeric timeout should fall back to default.""" + mock_agent.state.set("strands_shell_tool", {"timeout": "not_a_number"}) + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo test", tool_context=tool_context) + ) + assert "test" in result + + +class TestShellMarkerSplitFixed: + """Verify Fix #5: Marker split across chunks no longer leaks.""" + + @pytest.mark.asyncio + async def test_marker_split_across_chunks_no_leak(self, tool_context, mock_agent): + """Marker split across chunks should not leak partial marker to consumer.""" + from unittest.mock import AsyncMock + + async def fake_streaming(command, 
timeout=None, cwd=None): + yield StreamChunk(data="output\n__STRAN", stream_type="stdout") + yield StreamChunk(data="DS_CWD__\n/tmp\n", stream_type="stdout") + yield ExecutionResult( + exit_code=0, + stdout="output\n__STRANDS_CWD__\n/tmp\n", + stderr="", + ) + + mock_agent.sandbox = AsyncMock() + mock_agent.sandbox.execute_streaming = fake_streaming + + chunks, result = await collect_generator( + shell.__wrapped__(command="test", tool_context=tool_context) + ) + + # No partial marker should leak + all_chunk_data = "".join(c.data for c in chunks if c.stream_type == "stdout") + assert "__STRAN" not in all_chunk_data + assert _CWD_MARKER not in all_chunk_data + # User output should be preserved + assert "output" in all_chunk_data + + +class TestEditorTabExpansionFixed: + """Verify Fixes #2 and #4: Tab expansion no longer corrupts files or creates false matches.""" + + @pytest.mark.asyncio + async def test_tabs_preserved_after_edit(self, tool_context, sandbox, tmp_path): + """Editing a file should NOT destroy tab characters.""" + path = f"{tmp_path}/preserve_tabs.txt" + original = "def foo():\n\treturn 42\n" + await sandbox.write_text(path, original) + + # Edit something unrelated to tabs + result = await editor.__wrapped__( + command="str_replace", + path=path, + old_str="42", + new_str="99", + tool_context=tool_context, + ) + + content = await sandbox.read_text(path) + # Tab MUST still be present + assert "\t" in content, f"Tab was destroyed! 
Content: {repr(content)}" + assert content == "def foo():\n\treturn 99\n" + + @pytest.mark.asyncio + async def test_tab_and_spaces_not_conflated(self, tool_context, sandbox, tmp_path): + """Tab and 8 spaces should NOT be treated as the same thing.""" + path = f"{tmp_path}/tabs.txt" + content = "\thello\n hello\n" + await sandbox.write_text(path, content) + + # Replace the tab version — should work (unique in source) + result = await editor.__wrapped__( + command="str_replace", + path=path, + old_str="\thello", + new_str="replaced", + tool_context=tool_context, + ) + + # Should succeed — no "multiple occurrences" error + assert "edited" in result.lower(), f"Failed: {result}" + new_content = await sandbox.read_text(path) + assert "replaced" in new_content + assert " hello" in new_content # 8-space version unchanged + + @pytest.mark.asyncio + async def test_insert_preserves_tabs(self, tool_context, sandbox, tmp_path): + """Insert should not expand tabs in existing content.""" + path = f"{tmp_path}/insert_tabs.txt" + await sandbox.write_text(path, "line1\n\tindented\nline3") + + result = await editor.__wrapped__( + command="insert", + path=path, + insert_line=1, + new_str="new_line", + tool_context=tool_context, + ) + + content = await sandbox.read_text(path) + assert "\t" in content, f"Tab destroyed by insert! 
Content: {repr(content)}" + + +class TestEditorFloatViewRangeFixed: + """Verify Fix #6: Float view_range no longer crashes.""" + + @pytest.mark.asyncio + async def test_float_view_range_coerced(self, tool_context, sandbox, tmp_path): + """view_range with floats [1.0, 3.0] should work (coerced to ints).""" + path = f"{tmp_path}/float_range.txt" + await sandbox.write_text(path, "line1\nline2\nline3\nline4\nline5") + + result = await editor.__wrapped__( + command="view", + path=path, + view_range=[1.0, 3.0], + tool_context=tool_context, + ) + + # Should work, showing lines 1-3 + assert "line1" in result + assert "line2" in result + assert "line3" in result + assert "line4" not in result + + @pytest.mark.asyncio + async def test_invalid_view_range_type_gives_error(self, tool_context, sandbox, tmp_path): + """Non-numeric view_range should give a clear error.""" + path = f"{tmp_path}/bad_range.txt" + await sandbox.write_text(path, "line1\nline2\nline3") + + result = await editor.__wrapped__( + command="view", + path=path, + view_range=["a", "b"], + tool_context=tool_context, + ) + + assert "error" in result.lower() diff --git a/tests/strands/vended_tools/test_editor.py b/tests/strands/vended_tools/test_editor.py new file mode 100644 index 000000000..c3e63b0df --- /dev/null +++ b/tests/strands/vended_tools/test_editor.py @@ -0,0 +1,360 @@ +"""Tests for the editor vended tool.""" + +import pytest + +from strands.sandbox.noop import NoOpSandbox +from strands.vended_tools.editor import editor + + +class TestEditorTool: + """Tests for the editor vended tool.""" + + @pytest.mark.asyncio + async def test_view_file(self, tool_context, tmp_path): + """Test viewing a file.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + assert "line 1" in result + assert "line 2" in result + assert "line 3" in result + assert "cat -n" in 
result + + @pytest.mark.asyncio + async def test_view_with_range(self, tool_context, tmp_path): + """Test viewing a file with line range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\nline 4\nline 5\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, 4], + tool_context=tool_context, + ) + assert "line 2" in result + assert "line 4" in result + assert " 1" not in result + + @pytest.mark.asyncio + async def test_view_with_range_end_minus_one(self, tool_context, tmp_path): + """Test viewing with -1 as end of range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\nline 3\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[2, -1], + tool_context=tool_context, + ) + assert "line 2" in result + assert "line 3" in result + + @pytest.mark.asyncio + async def test_view_directory(self, tool_context, tmp_path): + """Test viewing a directory listing.""" + (tmp_path / "file1.py").write_text("pass") + (tmp_path / "file2.txt").write_text("hello") + (tmp_path / "subdir").mkdir() + + result = await editor.__wrapped__( + command="view", + path=str(tmp_path), + tool_context=tool_context, + ) + assert "file1.py" in result + assert "file2.txt" in result + assert "subdir/" in result + + @pytest.mark.asyncio + async def test_view_nonexistent(self, tool_context, tmp_path): + """Test viewing a nonexistent file.""" + result = await editor.__wrapped__( + command="view", + path=str(tmp_path / "nonexistent.txt"), + tool_context=tool_context, + ) + assert "does not exist" in result.lower() + + @pytest.mark.asyncio + async def test_view_invalid_range(self, tool_context, tmp_path): + """Test viewing with invalid range.""" + test_file = tmp_path / "test.txt" + test_file.write_text("line 1\nline 2\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + view_range=[0, 2], + tool_context=tool_context, + ) + 
assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_create_file(self, tool_context, tmp_path): + """Test creating a new file.""" + new_file = tmp_path / "new_file.py" + + result = await editor.__wrapped__( + command="create", + path=str(new_file), + file_text="print('hello')\n", + tool_context=tool_context, + ) + assert "created" in result.lower() + assert new_file.read_text() == "print('hello')\n" + + @pytest.mark.asyncio + async def test_create_existing_file(self, tool_context, tmp_path): + """Test creating a file that already exists.""" + existing = tmp_path / "existing.py" + existing.write_text("original") + + result = await editor.__wrapped__( + command="create", + path=str(existing), + file_text="new content", + tool_context=tool_context, + ) + assert "already exists" in result.lower() + assert existing.read_text() == "original" + + @pytest.mark.asyncio + async def test_create_missing_file_text(self, tool_context, tmp_path): + """Test create without file_text.""" + result = await editor.__wrapped__( + command="create", + path=str(tmp_path / "new.py"), + tool_context=tool_context, + ) + assert "file_text" in result.lower() + + @pytest.mark.asyncio + async def test_str_replace_unique(self, tool_context, tmp_path): + """Test str_replace with a unique match.""" + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="return 'hello'", + new_str="return 'world'", + tool_context=tool_context, + ) + assert "edited" in result.lower() + assert "return 'world'" in test_file.read_text() + + @pytest.mark.asyncio + async def test_str_replace_not_found(self, tool_context, tmp_path): + """Test str_replace when old_str not found.""" + test_file = tmp_path / "test.py" + test_file.write_text("def hello():\n return 'hello'\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + 
old_str="nonexistent string", + new_str="replacement", + tool_context=tool_context, + ) + assert "did not appear" in result.lower() + + @pytest.mark.asyncio + async def test_str_replace_multiple_occurrences(self, tool_context, tmp_path): + """Test str_replace rejects multiple occurrences.""" + test_file = tmp_path / "test.py" + test_file.write_text("x = 1\ny = 1\nz = 1\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="= 1", + new_str="= 2", + tool_context=tool_context, + ) + assert "multiple" in result.lower() + assert test_file.read_text() == "x = 1\ny = 1\nz = 1\n" + + @pytest.mark.asyncio + async def test_str_replace_deletion(self, tool_context, tmp_path): + """Test str_replace with empty new_str (deletion).""" + test_file = tmp_path / "test.py" + test_file.write_text("# TODO: remove this\ndef main():\n pass\n") + + result = await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="# TODO: remove this\n", + new_str="", + tool_context=tool_context, + ) + assert "edited" in result.lower() + assert "TODO" not in test_file.read_text() + + @pytest.mark.asyncio + async def test_insert(self, tool_context, tmp_path): + """Test inserting text at a line.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 1\nline 3\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=1, + new_str="line 2", + tool_context=tool_context, + ) + assert "edited" in result.lower() + content = test_file.read_text() + assert "line 1\nline 2\nline 3\n" == content + + @pytest.mark.asyncio + async def test_insert_at_beginning(self, tool_context, tmp_path): + """Test inserting at the beginning of a file.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 2\nline 3\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=0, + new_str="line 1", + tool_context=tool_context, + ) + assert "edited" in 
result.lower() + assert test_file.read_text().startswith("line 1\n") + + @pytest.mark.asyncio + async def test_insert_invalid_line(self, tool_context, tmp_path): + """Test insert with invalid line number.""" + test_file = tmp_path / "test.py" + test_file.write_text("line 1\n") + + result = await editor.__wrapped__( + command="insert", + path=str(test_file), + insert_line=999, + new_str="new line", + tool_context=tool_context, + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_undo_edit(self, tool_context, tmp_path): + """Test undo_edit reverting a str_replace.""" + test_file = tmp_path / "test.py" + test_file.write_text("original content\n") + + await editor.__wrapped__( + command="str_replace", + path=str(test_file), + old_str="original content", + new_str="modified content", + tool_context=tool_context, + ) + assert "modified content" in test_file.read_text() + + result = await editor.__wrapped__( + command="undo_edit", + path=str(test_file), + tool_context=tool_context, + ) + assert "reverted" in result.lower() + assert "original content" in test_file.read_text() + + @pytest.mark.asyncio + async def test_undo_no_history(self, tool_context, tmp_path): + """Test undo_edit when no history exists.""" + result = await editor.__wrapped__( + command="undo_edit", + path=str(tmp_path / "nonexistent.py"), + tool_context=tool_context, + ) + assert "no edit history" in result.lower() + + @pytest.mark.asyncio + async def test_relative_path_allowed_by_default(self, tool_context, tmp_path): + """Test that relative paths are passed through to sandbox by default.""" + test_file = tmp_path / "relative_test.txt" + test_file.write_text("relative content\n") + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + assert "relative content" in result + + @pytest.mark.asyncio + async def test_relative_path_rejected_when_configured(self, tool_context, mock_agent): + """Test that relative paths are 
rejected when require_absolute_paths is True.""" + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = await editor.__wrapped__( + command="view", + path="relative/path.py", + tool_context=tool_context, + ) + assert "not an absolute path" in result.lower() + + @pytest.mark.asyncio + async def test_path_traversal_allowed_by_default(self, tool_context, tmp_path): + """Test that paths with .. are passed through to sandbox by default.""" + subdir = tmp_path / "sub" + subdir.mkdir() + test_file = tmp_path / "traversal_test.txt" + test_file.write_text("traversal content\n") + + traversal_path = str(subdir / ".." / "traversal_test.txt") + result = await editor.__wrapped__( + command="view", + path=traversal_path, + tool_context=tool_context, + ) + assert "traversal content" in result + + @pytest.mark.asyncio + async def test_path_traversal_rejected_when_configured(self, tool_context, mock_agent): + """Test that path traversal is rejected when require_absolute_paths is True.""" + mock_agent.state.set("strands_editor_tool", {"require_absolute_paths": True}) + + result = await editor.__wrapped__( + command="view", + path="/tmp/../etc/passwd", + tool_context=tool_context, + ) + assert "not allowed" in result.lower() + + @pytest.mark.asyncio + async def test_noop_sandbox(self, tool_context, mock_agent, tmp_path): + """Test editor with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + + result = await editor.__wrapped__( + command="view", + path=str(tmp_path / "test.py"), + tool_context=tool_context, + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_max_file_size(self, tool_context, mock_agent, tmp_path): + """Test max file size configuration.""" + mock_agent.state.set("strands_editor_tool", {"max_file_size": 10}) + + test_file = tmp_path / "large.txt" + test_file.write_text("a" * 100) + + result = await editor.__wrapped__( + command="view", + path=str(test_file), + tool_context=tool_context, + ) + 
assert "exceeds" in result.lower() diff --git a/tests/strands/vended_tools/test_init.py b/tests/strands/vended_tools/test_init.py new file mode 100644 index 000000000..9c793b208 --- /dev/null +++ b/tests/strands/vended_tools/test_init.py @@ -0,0 +1,185 @@ +"""Tests for vended tools package imports and tool specs.""" + +import inspect + +import pytest + +from strands.sandbox.base import StreamChunk + +from .conftest import collect_generator + + +class TestVendedToolsImport: + """Test that vended tools can be imported from the package.""" + + def test_import_from_vended_tools(self): + """Test importing from strands.vended_tools.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_import_individual_tools(self): + """Test importing individual tools from flat modules.""" + from strands.vended_tools.editor import editor + from strands.vended_tools.python_repl import python_repl + from strands.vended_tools.shell import shell + + assert shell is not None + assert editor is not None + assert python_repl is not None + + def test_tools_have_tool_spec(self): + """Test that tools have proper tool specs.""" + from strands.vended_tools import editor, python_repl, shell + + assert shell.tool_name == "shell" + assert editor.tool_name == "editor" + assert python_repl.tool_name == "python_repl" + + for t in [shell, editor, python_repl]: + spec = t.tool_spec + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + + def test_shell_tool_spec_shape(self): + """Test shell tool spec matches expected shape.""" + from strands.vended_tools import shell + + spec = shell.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "timeout" in props + assert "restart" in props + assert schema.get("required") == ["command"] + + def test_editor_tool_spec_shape(self): + """Test editor 
tool spec matches expected shape.""" + from strands.vended_tools import editor + + spec = editor.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "command" in props + assert "path" in props + assert "file_text" in props + assert "old_str" in props + assert "new_str" in props + assert "insert_line" in props + assert "view_range" in props + assert set(schema.get("required", [])) == {"command", "path"} + + def test_python_repl_tool_spec_shape(self): + """Test python_repl tool spec matches expected shape.""" + from strands.vended_tools import python_repl + + spec = python_repl.tool_spec + schema = spec["inputSchema"]["json"] + props = schema.get("properties", {}) + + assert "code" in props + assert "timeout" in props + assert "reset" in props + assert schema.get("required") == ["code"] + + def test_shell_is_async_generator(self): + """Test that shell is detected as an async generator function.""" + from strands.vended_tools.shell import shell + + assert inspect.isasyncgenfunction(shell.__wrapped__) + + def test_python_repl_is_async_generator(self): + """Test that python_repl is detected as an async generator function.""" + from strands.vended_tools.python_repl import python_repl + + assert inspect.isasyncgenfunction(python_repl.__wrapped__) + + def test_editor_is_not_async_generator(self): + """Test that editor is a regular async function (not generator).""" + from strands.vended_tools.editor import editor + + assert not inspect.isasyncgenfunction(editor.__wrapped__) + assert inspect.iscoroutinefunction(editor.__wrapped__) + + +class TestStreamingIntegration: + """Test the streaming behavior of tools end-to-end.""" + + @pytest.mark.asyncio + async def test_shell_streams_before_result(self, tool_context): + """Test that shell yields chunks BEFORE the final result.""" + from strands.vended_tools.shell import shell + + all_items = [] + + async for item in shell.__wrapped__( + command="echo streaming_test", 
tool_context=tool_context + ): + all_items.append(item) + + # Should have at least 2 items: chunk(s) + final result + assert len(all_items) >= 2 + # Last item should be the string result + assert isinstance(all_items[-1], str) + # Earlier items should include StreamChunk + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + @pytest.mark.asyncio + async def test_python_repl_streams_before_result(self, tool_context): + """Test that python_repl yields chunks BEFORE the final result.""" + from strands.vended_tools.python_repl import python_repl + + all_items = [] + + async for item in python_repl.__wrapped__( + code="print('streaming_test')", tool_context=tool_context + ): + all_items.append(item) + + assert len(all_items) >= 2 + assert isinstance(all_items[-1], str) + stream_items = [i for i in all_items[:-1] if isinstance(i, StreamChunk)] + assert len(stream_items) >= 1 + + @pytest.mark.asyncio + async def test_shell_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield any StreamChunks before the error.""" + from strands.vended_tools.shell import shell + + chunks, result = await collect_generator( + shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_python_repl_error_no_streaming_on_timeout(self, tool_context): + """Test that timeout errors don't yield chunks before the error.""" + from strands.vended_tools.python_repl import python_repl + + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_shell_stream_data_matches_result(self, tool_context): + """Test that streamed chunk data matches the final result content.""" + from 
strands.vended_tools.shell import shell + + chunks, result = await collect_generator( + shell.__wrapped__(command="echo precise_output_42", tool_context=tool_context) + ) + # The streamed chunks should contain the same data as the final result + all_chunk_data = "".join(c.data for c in chunks) + assert "precise_output_42" in all_chunk_data + assert "precise_output_42" in result diff --git a/tests/strands/vended_tools/test_python_repl.py b/tests/strands/vended_tools/test_python_repl.py new file mode 100644 index 000000000..9e174bff3 --- /dev/null +++ b/tests/strands/vended_tools/test_python_repl.py @@ -0,0 +1,177 @@ +"""Tests for the python_repl vended tool.""" + +import pytest + +from strands.sandbox.base import StreamChunk +from strands.sandbox.noop import NoOpSandbox +from strands.vended_tools.python_repl import python_repl + +from .conftest import collect_generator + + +class TestPythonReplTool: + """Tests for the python_repl vended tool.""" + + @pytest.mark.asyncio + async def test_basic_code(self, tool_context): + """Test basic Python code execution returns result.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + assert "hello from python" in result + + @pytest.mark.asyncio + async def test_basic_code_streams_chunks(self, tool_context): + """Test that python_repl yields StreamChunk objects during execution.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('hello from python')", + tool_context=tool_context, + ) + ) + stdout_chunks = [c for c in chunks if c.stream_type == "stdout"] + assert len(stdout_chunks) >= 1 + assert any("hello from python" in c.data for c in stdout_chunks) + + @pytest.mark.asyncio + async def test_stderr_streams_as_stderr_chunks(self, tool_context): + """Test that stderr from Python code yields stderr StreamChunks.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + 
code="import sys; print('err_msg', file=sys.stderr)", + tool_context=tool_context, + ) + ) + stderr_chunks = [c for c in chunks if c.stream_type == "stderr"] + assert len(stderr_chunks) >= 1 + assert any("err_msg" in c.data for c in stderr_chunks) + + @pytest.mark.asyncio + async def test_code_with_math(self, tool_context): + """Test Python math execution.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print(2 + 2)", + tool_context=tool_context, + ) + ) + assert "4" in result + + @pytest.mark.asyncio + async def test_code_with_error(self, tool_context): + """Test Python code that raises an error.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="raise ValueError('test error')", + tool_context=tool_context, + ) + ) + assert "test error" in result or "ValueError" in result + + @pytest.mark.asyncio + async def test_code_with_import(self, tool_context): + """Test Python code with imports.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import json; print(json.dumps({'key': 'value'}))", + tool_context=tool_context, + ) + ) + assert "key" in result + + @pytest.mark.asyncio + async def test_timeout(self, tool_context): + """Test code execution timeout.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + timeout=1, + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_config_timeout(self, tool_context, mock_agent): + """Test timeout from config.""" + mock_agent.state.set("strands_python_repl_tool", {"timeout": 1}) + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import time; time.sleep(10)", + tool_context=tool_context, + ) + ) + assert "timed out" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_reset(self, tool_context): + """Test REPL reset.""" + 
chunks, result = await collect_generator( + python_repl.__wrapped__( + code="", + reset=True, + tool_context=tool_context, + ) + ) + assert "reset" in result.lower() + + @pytest.mark.asyncio + async def test_multiline_code(self, tool_context): + """Test multiline Python code.""" + code = """ +def fibonacci(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + +print(fibonacci(10)) +""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code=code, + tool_context=tool_context, + ) + ) + assert "55" in result + + @pytest.mark.asyncio + async def test_no_output(self, tool_context): + """Test code with no output.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="x = 42", + tool_context=tool_context, + ) + ) + assert result == "(no output)" + + @pytest.mark.asyncio + async def test_noop_sandbox(self, tool_context, mock_agent): + """Test python_repl with NoOpSandbox.""" + mock_agent.sandbox = NoOpSandbox() + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="print('test')", + tool_context=tool_context, + ) + ) + assert "error" in result.lower() + + @pytest.mark.asyncio + async def test_stream_chunk_types_are_correct(self, tool_context): + """Test that all yielded chunks are proper StreamChunk instances.""" + chunks, result = await collect_generator( + python_repl.__wrapped__( + code="import sys; print('out'); print('err', file=sys.stderr)", + tool_context=tool_context, + ) + ) + for chunk in chunks: + assert isinstance(chunk, StreamChunk) + assert hasattr(chunk, "data") + assert hasattr(chunk, "stream_type") + assert chunk.stream_type in ("stdout", "stderr") diff --git a/tests/strands/vended_tools/test_shell.py b/tests/strands/vended_tools/test_shell.py new file mode 100644 index 000000000..fd9f1161e --- /dev/null +++ b/tests/strands/vended_tools/test_shell.py @@ -0,0 +1,225 @@ +"""Tests for the shell vended tool.""" + +import pytest + +from strands.sandbox.base import 
ExecutionResult, StreamChunk
from strands.sandbox.noop import NoOpSandbox
from strands.vended_tools.shell import shell

from .conftest import collect_generator


class TestShellTool:
    """Tests for the shell vended tool.

    Each test drives ``shell.__wrapped__`` (the undecorated async generator)
    through ``collect_generator``, which returns the list of streamed
    ``StreamChunk`` objects and the final result string.
    """

    # NOTE(review): several tests previously requested the ``tmp_path``
    # fixture without using it, creating a throwaway temp directory per test;
    # the unused fixture parameters have been removed.

    @pytest.mark.asyncio
    async def test_basic_command(self, tool_context):
        """Test basic shell command execution returns result."""
        chunks, result = await collect_generator(
            shell.__wrapped__(command="echo hello", tool_context=tool_context)
        )
        assert "hello" in result

    @pytest.mark.asyncio
    async def test_basic_command_streams_chunks(self, tool_context):
        """Test that shell yields StreamChunk objects during execution."""
        chunks, result = await collect_generator(
            shell.__wrapped__(command="echo hello", tool_context=tool_context)
        )
        # Should have at least one stdout chunk
        stdout_chunks = [c for c in chunks if c.stream_type == "stdout"]
        assert len(stdout_chunks) >= 1
        assert any("hello" in c.data for c in stdout_chunks)
        # Final result should also contain the output
        assert "hello" in result

    @pytest.mark.asyncio
    async def test_stderr_streams_as_stderr_chunks(self, tool_context):
        """Test that stderr output yields StreamChunk with stream_type='stderr'."""
        # `>&2` redirects the echo to stderr.
        chunks, result = await collect_generator(
            shell.__wrapped__(command="echo error >&2", tool_context=tool_context)
        )
        stderr_chunks = [c for c in chunks if c.stream_type == "stderr"]
        assert len(stderr_chunks) >= 1
        assert any("error" in c.data for c in stderr_chunks)
        assert "error" in result

    @pytest.mark.asyncio
    async def test_mixed_stdout_stderr_streaming(self, tool_context):
        """Test command with both stdout and stderr streams both chunk types."""
        chunks, result = await collect_generator(
            shell.__wrapped__(
                command="echo out && echo err >&2",
                tool_context=tool_context,
            )
        )
        stdout_chunks = [c for c in chunks if c.stream_type == "stdout"]
        stderr_chunks = [c for c in chunks if c.stream_type == "stderr"]
        assert len(stdout_chunks) >= 1
        assert len(stderr_chunks) >= 1

    @pytest.mark.asyncio
    async def test_command_with_exit_code(self, tool_context):
        """Test command that returns non-zero exit code."""
        # The non-zero exit code should be surfaced in the result text.
        chunks, result = await collect_generator(
            shell.__wrapped__(command="exit 42", tool_context=tool_context)
        )
        assert "42" in result

    @pytest.mark.asyncio
    async def test_timeout(self, tool_context):
        """Test command timeout."""
        # timeout=1 must cut the sleep short and report a timeout/error.
        chunks, result = await collect_generator(
            shell.__wrapped__(command="sleep 10", timeout=1, tool_context=tool_context)
        )
        assert "timed out" in result.lower() or "error" in result.lower()

    @pytest.mark.asyncio
    async def test_config_timeout(self, tool_context, mock_agent):
        """Test timeout from config."""
        # Configured via agent state under the tool's namespaced key instead
        # of the per-call `timeout` argument.
        mock_agent.state.set("strands_shell_tool", {"timeout": 1})
        chunks, result = await collect_generator(
            shell.__wrapped__(command="sleep 10", tool_context=tool_context)
        )
        assert "timed out" in result.lower() or "error" in result.lower()

    @pytest.mark.asyncio
    async def test_restart(self, tool_context):
        """Test shell restart."""
        chunks, result = await collect_generator(
            shell.__wrapped__(command="", restart=True, tool_context=tool_context)
        )
        assert "reset" in result.lower()

    @pytest.mark.asyncio
    async def test_no_output_command(self, tool_context):
        """Test command with no output."""
        # `true` exits 0 silently; the tool substitutes a placeholder string.
        chunks, result = await collect_generator(
            shell.__wrapped__(command="true", tool_context=tool_context)
        )
        assert result == "(no output)"

    @pytest.mark.asyncio
    async def test_noop_sandbox(self, tool_context, mock_agent):
        """Test shell with NoOpSandbox."""
        # With a no-op sandbox nothing executes; an error must be reported.
        mock_agent.sandbox = NoOpSandbox()
        chunks, result = await collect_generator(
            shell.__wrapped__(command="echo test", tool_context=tool_context)
        )
        assert "error" in result.lower()

    @pytest.mark.asyncio
    async def test_cwd_tracking(self, tool_context, tmp_path):
        """Test that working directory is tracked across calls."""
        subdir = tmp_path / "subdir"
        subdir.mkdir()

        chunks, result = await collect_generator(
            shell.__wrapped__(command=f"cd {subdir}", tool_context=tool_context)
        )

        # Verify cwd state was tracked
        shell_state = tool_context.agent.state.get("_strands_shell_state")
        assert shell_state is not None
        assert shell_state["cwd"] == str(subdir)

        # Verify no internal markers leak into the result or streamed chunks
        assert "__STRANDS_CWD__" not in result
        for chunk in chunks:
            assert "__STRANDS_CWD__" not in chunk.data

    @pytest.mark.asyncio
    async def test_cwd_no_marker_leak_in_streaming(self, tool_context):
        """Test that no internal markers appear in streamed output."""
        chunks, result = await collect_generator(
            shell.__wrapped__(command="echo hello && cd /tmp", tool_context=tool_context)
        )
        # No internal markers should appear in any output
        all_chunk_data = "".join(c.data for c in chunks)
        assert "__STRANDS_CWD__" not in all_chunk_data
        assert "__STRANDS_CWD__" not in result

    @pytest.mark.asyncio
    async def test_multiline_output(self, tool_context):
        """Test command with multiline output."""
        chunks, result = await collect_generator(
            shell.__wrapped__(
                command="echo 'line1\nline2\nline3'",
                tool_context=tool_context,
            )
        )
        assert "line1" in result
        assert "line2" in result

    @pytest.mark.asyncio
    async def test_pipe_command(self, tool_context):
        """Test piped commands."""
        # `wc -w` counts two words in "hello world".
        chunks, result = await collect_generator(
            shell.__wrapped__(
                command="echo 'hello world' | wc -w",
                tool_context=tool_context,
            )
        )
        assert "2" in result

    @pytest.mark.asyncio
    async def test_stream_chunk_types_are_correct(self, tool_context):
        """Test that all yielded chunks are proper StreamChunk instances."""
        chunks, result = await collect_generator(
            shell.__wrapped__(
                command="echo stdout_data && echo stderr_data >&2",
                tool_context=tool_context,
            )
        )
        for chunk in chunks:
            # Every streamed item must be a StreamChunk with a data payload
            # and a stream_type restricted to the two known channels.
            assert isinstance(chunk, StreamChunk)
            assert hasattr(chunk, "data")
            assert hasattr(chunk, "stream_type")
            assert chunk.stream_type in ("stdout", "stderr")

    @pytest.mark.asyncio
    async def test_stderr_streams_immediately_during_one_chunk_behind(
        self, tool_context, mock_agent
    ):
        """Test that stderr chunks are yielded immediately, not held back.

        The one-chunk-behind approach only holds back stdout chunks. Stderr
        chunks should pass through without delay regardless of buffering.
        """
        # Use a mock sandbox that interleaves stdout and stderr chunks
        from unittest.mock import AsyncMock

        async def fake_streaming(command, timeout=None, cwd=None):
            """Simulate a sandbox that yields interleaved stdout/stderr."""
            yield StreamChunk(data="stdout1\n", stream_type="stdout")
            yield StreamChunk(data="err1\n", stream_type="stderr")
            yield StreamChunk(data="stdout2\n", stream_type="stdout")
            yield StreamChunk(data="err2\n", stream_type="stderr")
            # The final stdout chunk carries the internal cwd marker, which
            # the tool must strip before surfacing output.
            yield StreamChunk(data="stdout3\n__STRANDS_CWD__\n/tmp\n", stream_type="stdout")
            yield ExecutionResult(
                exit_code=0,
                stdout="stdout1\nstdout2\nstdout3\n__STRANDS_CWD__\n/tmp\n",
                stderr="err1\nerr2\n",
            )

        mock_agent.sandbox = AsyncMock()
        mock_agent.sandbox.execute_streaming = fake_streaming

        chunks, result = await collect_generator(
            shell.__wrapped__(command="test", tool_context=tool_context)
        )

        # Verify all chunks are present and no markers leaked
        all_data = "".join(c.data for c in chunks)
        assert "__STRANDS_CWD__" not in all_data

        # Verify stderr chunks are present
        stderr_data = "".join(c.data for c in chunks if c.stream_type == "stderr")
        assert "err1" in stderr_data
        assert "err2" in stderr_data

        # Verify stdout chunks are present (minus marker)
        stdout_data = "".join(c.data for c in chunks if c.stream_type == "stdout")
        assert "stdout1" in stdout_data
        assert "stdout2" in stdout_data
        assert "stdout3" in stdout_data