diff --git a/src/strands_tools/workflow.py b/src/strands_tools/workflow.py index c9dd0555..902bacea 100644 --- a/src/strands_tools/workflow.py +++ b/src/strands_tools/workflow.py @@ -328,15 +328,23 @@ def _create_task_agent(self, task: Dict) -> Agent: filtered_tools = [] if task_tools and self.parent_agent and hasattr(self.parent_agent, "tool_registry"): # Filter parent agent tools to only include specified tool names + # ALWAYS exclude 'workflow' tool to prevent recursion available_tools = self.parent_agent.tool_registry.registry for tool_name in task_tools: + if tool_name == "workflow": + logger.warning("Excluding 'workflow' tool from task agent to prevent recursion") + continue if tool_name in available_tools: filtered_tools.append(available_tools[tool_name]) else: logger.warning(f"Tool '{tool_name}' not found in parent agent's tool registry") elif self.parent_agent and hasattr(self.parent_agent, "tool_registry"): - # Inherit all tools from parent if none specified - filtered_tools = list(self.parent_agent.tool_registry.registry.values()) + # Inherit all tools from parent EXCEPT the workflow tool to prevent recursion + for tool_name, tool_obj in self.parent_agent.tool_registry.registry.items(): + if tool_name == "workflow": + logger.debug("Automatically excluding 'workflow' tool from task agent to prevent recursion") + continue + filtered_tools.append(tool_obj) # Configure model selected_model = None diff --git a/tests/test_workflow.py b/tests/test_workflow.py index bfc35b3f..ef320210 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -572,6 +572,100 @@ def test_get_ready_tasks_with_completed_dependencies(self, mock_parent_agent): assert len(ready_tasks) == 1 assert ready_tasks[0]["task_id"] == "task2" + def test_create_task_agent_excludes_workflow_tool_explicit(self, mock_parent_agent): + """Test that explicitly requesting 'workflow' tool in task tools list is excluded to prevent recursion.""" + mock_parent_agent.tool_registry.registry["workflow"] = MagicMock() + + with patch("strands_tools.workflow.Agent") as mock_agent_class: + mock_agent_class.return_value = MagicMock() + + manager = workflow_module.WorkflowManager(mock_parent_agent) + manager.parent_agent = mock_parent_agent + + task = { + "task_id": "test_task", + "description": "Test task", + "tools": ["calculator", "workflow"], + } + + manager._create_task_agent(task) + + call_kwargs = mock_agent_class.call_args.kwargs + # Should have only 1 tool (calculator), workflow excluded + assert len(call_kwargs["tools"]) == 1 + assert call_kwargs["tools"][0] == mock_parent_agent.tool_registry.registry["calculator"] + + def test_create_task_agent_excludes_workflow_tool_inherit_all(self, mock_parent_agent): + """Test that workflow tool is automatically excluded when inheriting all tools from parent.""" + mock_parent_agent.tool_registry.registry["workflow"] = MagicMock() + + with patch("strands_tools.workflow.Agent") as mock_agent_class: + mock_agent_class.return_value = MagicMock() + + manager = workflow_module.WorkflowManager(mock_parent_agent) + manager.parent_agent = mock_parent_agent + + # No tools specified - should inherit all except workflow + task = { + "task_id": "test_task", + "description": "Test task", + } + + manager._create_task_agent(task) + + call_kwargs = mock_agent_class.call_args.kwargs + tool_count = len(call_kwargs["tools"]) + # Parent has 8 original tools + workflow = 9 total, workflow excluded = 8 + assert tool_count == 8 + # Verify workflow tool is not in the list + workflow_tool = mock_parent_agent.tool_registry.registry["workflow"] + assert workflow_tool not in call_kwargs["tools"] + + def test_create_task_agent_only_workflow_tool_requested(self, mock_parent_agent): + """Test that requesting only 'workflow' tool results in empty tools list.""" + mock_parent_agent.tool_registry.registry["workflow"] = MagicMock() + + with patch("strands_tools.workflow.Agent") as mock_agent_class: + mock_agent_class.return_value = MagicMock() + + manager = workflow_module.WorkflowManager(mock_parent_agent) + manager.parent_agent = mock_parent_agent + + task = { + "task_id": "test_task", + "description": "Test task", + "tools": ["workflow"], + } + + manager._create_task_agent(task) + + call_kwargs = mock_agent_class.call_args.kwargs + # Should have 0 tools since only workflow was requested and it's excluded + assert len(call_kwargs["tools"]) == 0 + + def test_create_task_agent_without_workflow_in_registry(self, mock_parent_agent): + """Test that task agent creation works normally when workflow is not in parent registry.""" + # mock_parent_agent does not have 'workflow' in registry by default + assert "workflow" not in mock_parent_agent.tool_registry.registry + + with patch("strands_tools.workflow.Agent") as mock_agent_class: + mock_agent_class.return_value = MagicMock() + + manager = workflow_module.WorkflowManager(mock_parent_agent) + manager.parent_agent = mock_parent_agent + + # No tools specified - inherits all from parent + task = { + "task_id": "test_task", + "description": "Test task", + } + + manager._create_task_agent(task) + + call_kwargs = mock_agent_class.call_args.kwargs + # Should have all 8 original tools + assert len(call_kwargs["tools"]) == 8 + class TestWorkflowEdgeCases: """Test edge cases and error conditions."""