Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions src/strands_evals/simulation/tool_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from strands.models.model import Model
from strands.tools.decorator import DecoratedFunctionTool, FunctionToolMetadata

from strands_evals.types.simulation.hook_events import PostCallHookEvent, PreCallHookEvent
from strands_evals.types.simulation.tool import DefaultToolResponse, RegisteredTool

from .prompt_templates.tool_response_generation import TOOL_RESPONSE_PROMPT_TEMPLATE
Expand Down Expand Up @@ -166,6 +167,8 @@ def __init__(
state_registry: StateRegistry | None = None,
model: Model | str | None = None,
max_tool_call_cache_size: int = 20,
pre_call_hook: Callable | None = None,
post_call_hook: Callable | None = None,
):
"""
Initialize a ToolSimulator instance.
Expand All @@ -178,10 +181,21 @@ def __init__(
Only used when creating a new StateRegistry (ignored if state_registry
is provided). Older calls are automatically evicted when limit is exceeded.
Default is 20.
pre_call_hook: Optional callable invoked before the LLM generates a tool response.
Receives a PreCallHookEvent with tool_name, parameters, state_key,
and previous_calls. If it returns a non-None dict, that dict is used
as the tool response (short-circuiting the LLM call) and cached in
the state registry. If it returns None, normal LLM simulation proceeds.
post_call_hook: Optional callable invoked after the LLM generates a tool response
but before it is cached. Receives a PostCallHookEvent with tool_name,
parameters, state_key, and response. Must return a (possibly modified)
response dict.
"""
self.model = model
self.state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size)
self._registered_tools: dict[str, RegisteredTool] = {}
self._pre_call_hook = pre_call_hook
self._post_call_hook = post_call_hook

def _create_tool_wrapper(self, registered_tool: RegisteredTool):
"""
Expand Down Expand Up @@ -245,7 +259,35 @@ def _parse_simulated_response(self, result: AgentResult) -> dict[str, Any]:
return response_data

def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, state_key: str) -> dict[str, Any]:
"""Simulate a tool invocation and return the response."""
"""Simulate a tool invocation and return the response.
Comment thread
kaghatim marked this conversation as resolved.

If a pre_call_hook is configured and returns a non-None dict, that dict is used
as the tool response (short-circuiting the LLM call). The response is still cached.

If a post_call_hook is configured, it receives the LLM-generated response before
caching and may modify it.
"""
parameters = json.loads(parameters_string)
current_state = self.state_registry.get_state(state_key)

# Pre-call hook: may short-circuit the LLM call
if self._pre_call_hook is not None:
event = PreCallHookEvent(
tool_name=registered_tool.name,
parameters=parameters,
state_key=state_key,
previous_calls=current_state.get("previous_calls", []),
)
hook_response = self._pre_call_hook(event)
if hook_response is not None:
Comment thread
kaghatim marked this conversation as resolved.
if not isinstance(hook_response, dict):
raise TypeError(f"pre_call_hook must return a dict or None, got {type(hook_response).__name__}")
self.state_registry.cache_tool_call(
registered_tool.name, state_key, hook_response, parameters=parameters
)
return hook_response

# Normal LLM simulation
# Get input schema from Strands tool decorator
input_schema_dict = registered_tool.function.tool_spec.get("inputSchema", {}).get("json", {})
input_schema = json.dumps(input_schema_dict, indent=2)
Expand All @@ -254,8 +296,6 @@ def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, st
output_schema = registered_tool.output_schema.model_json_schema()
output_schema_string = json.dumps(output_schema, indent=2)

current_state = self.state_registry.get_state(state_key)

prompt = TOOL_RESPONSE_PROMPT_TEMPLATE.format(
tool_name=registered_tool.name,
input_schema=input_schema,
Expand All @@ -268,9 +308,19 @@ def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, st

response_data = self._parse_simulated_response(result)

self.state_registry.cache_tool_call(
registered_tool.name, state_key, response_data, parameters=json.loads(parameters_string)
)
# Post-call hook: may modify the response before caching
if self._post_call_hook is not None:
event = PostCallHookEvent(
tool_name=registered_tool.name,
parameters=parameters,
state_key=state_key,
response=response_data,
)
response_data = self._post_call_hook(event)
if not isinstance(response_data, dict):
raise TypeError(f"post_call_hook must return a dict, got {type(response_data).__name__}")

self.state_registry.cache_tool_call(registered_tool.name, state_key, response_data, parameters=parameters)
return response_data

def tool(
Expand Down
38 changes: 38 additions & 0 deletions src/strands_evals/types/simulation/hook_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dataclasses import dataclass, field
from typing import Any


@dataclass
class PreCallHookEvent:
    """
    Event passed to ``pre_call_hook`` before the LLM generates a tool response.

    If the hook returns a non-None dict, that dict is used as the tool response
    (short-circuiting the LLM call) and is cached in the state registry; if the
    hook returns None, normal LLM simulation proceeds.

    Attributes:
        tool_name: Name of the tool being called.
        parameters: Parsed (JSON-decoded) parameters for the tool call.
        state_key: Key for the state (tool_name or share_state_id).
        previous_calls: List of previous tool call records from the state
            registry; defaults to an empty list.
    """

    tool_name: str
    parameters: dict[str, Any]
    state_key: str
    previous_calls: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class PostCallHookEvent:
    """
    Event passed to ``post_call_hook`` after the LLM generates a tool response.

    The hook receives this event after generation but before the response is
    cached, and must return a (possibly modified) response dict.

    Attributes:
        tool_name: Name of the tool that was called.
        parameters: Parsed (JSON-decoded) parameters for the tool call.
        state_key: Key for the state (tool_name or share_state_id).
        response: The LLM-generated response dict, which the hook may modify;
            defaults to an empty dict.
    """

    tool_name: str
    parameters: dict[str, Any]
    state_key: str
    response: dict[str, Any] = field(default_factory=dict)
Loading
Loading