Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,23 @@ def _create_task_agent(self, task: Dict) -> Agent:
filtered_tools = []
if task_tools and self.parent_agent and hasattr(self.parent_agent, "tool_registry"):
# Filter parent agent tools to only include specified tool names
# ALWAYS exclude 'workflow' tool to prevent recursion
available_tools = self.parent_agent.tool_registry.registry
for tool_name in task_tools:
if tool_name == "workflow":
logger.warning("Excluding 'workflow' tool from task agent to prevent recursion")
continue
if tool_name in available_tools:
filtered_tools.append(available_tools[tool_name])
else:
logger.warning(f"Tool '{tool_name}' not found in parent agent's tool registry")
elif self.parent_agent and hasattr(self.parent_agent, "tool_registry"):
# Inherit all tools from parent if none specified
filtered_tools = list(self.parent_agent.tool_registry.registry.values())
# Inherit all tools from parent EXCEPT the workflow tool to prevent recursion
for tool_name, tool_obj in self.parent_agent.tool_registry.registry.items():
if tool_name == "workflow":
logger.debug("Automatically excluding 'workflow' tool from task agent to prevent recursion")
continue
filtered_tools.append(tool_obj)

# Configure model
selected_model = None
Expand Down
94 changes: 94 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,100 @@ def test_get_ready_tasks_with_completed_dependencies(self, mock_parent_agent):
assert len(ready_tasks) == 1
assert ready_tasks[0]["task_id"] == "task2"

def test_create_task_agent_excludes_workflow_tool_explicit(self, mock_parent_agent):
"""Test that explicitly requesting 'workflow' tool in task tools list is excluded to prevent recursion."""
mock_parent_agent.tool_registry.registry["workflow"] = MagicMock()

with patch("strands_tools.workflow.Agent") as mock_agent_class:
mock_agent_class.return_value = MagicMock()

manager = workflow_module.WorkflowManager(mock_parent_agent)
manager.parent_agent = mock_parent_agent

task = {
"task_id": "test_task",
"description": "Test task",
"tools": ["calculator", "workflow"],
}

manager._create_task_agent(task)

call_kwargs = mock_agent_class.call_args.kwargs
# Should have only 1 tool (calculator), workflow excluded
assert len(call_kwargs["tools"]) == 1
assert call_kwargs["tools"][0] == mock_parent_agent.tool_registry.registry["calculator"]

def test_create_task_agent_excludes_workflow_tool_inherit_all(self, mock_parent_agent):
"""Test that workflow tool is automatically excluded when inheriting all tools from parent."""
mock_parent_agent.tool_registry.registry["workflow"] = MagicMock()

with patch("strands_tools.workflow.Agent") as mock_agent_class:
mock_agent_class.return_value = MagicMock()

manager = workflow_module.WorkflowManager(mock_parent_agent)
manager.parent_agent = mock_parent_agent

# No tools specified - should inherit all except workflow
task = {
"task_id": "test_task",
"description": "Test task",
}

manager._create_task_agent(task)

call_kwargs = mock_agent_class.call_args.kwargs
tool_count = len(call_kwargs["tools"])
# Parent has 8 original tools + workflow = 9 total, workflow excluded = 8
assert tool_count == 8
# Verify workflow tool is not in the list
workflow_tool = mock_parent_agent.tool_registry.registry["workflow"]
assert workflow_tool not in call_kwargs["tools"]

def test_create_task_agent_only_workflow_tool_requested(self, mock_parent_agent):
"""Test that requesting only 'workflow' tool results in empty tools list."""
mock_parent_agent.tool_registry.registry["workflow"] = MagicMock()

with patch("strands_tools.workflow.Agent") as mock_agent_class:
mock_agent_class.return_value = MagicMock()

manager = workflow_module.WorkflowManager(mock_parent_agent)
manager.parent_agent = mock_parent_agent

task = {
"task_id": "test_task",
"description": "Test task",
"tools": ["workflow"],
}

manager._create_task_agent(task)

call_kwargs = mock_agent_class.call_args.kwargs
# Should have 0 tools since only workflow was requested and it's excluded
assert len(call_kwargs["tools"]) == 0

def test_create_task_agent_without_workflow_in_registry(self, mock_parent_agent):
"""Test that task agent creation works normally when workflow is not in parent registry."""
# mock_parent_agent does not have 'workflow' in registry by default
assert "workflow" not in mock_parent_agent.tool_registry.registry

with patch("strands_tools.workflow.Agent") as mock_agent_class:
mock_agent_class.return_value = MagicMock()

manager = workflow_module.WorkflowManager(mock_parent_agent)
manager.parent_agent = mock_parent_agent

# No tools specified - inherits all from parent
task = {
"task_id": "test_task",
"description": "Test task",
}

manager._create_task_agent(task)

call_kwargs = mock_agent_class.call_args.kwargs
# Should have all 8 original tools
assert len(call_kwargs["tools"]) == 8


class TestWorkflowEdgeCases:
"""Test edge cases and error conditions."""
Expand Down
Loading