Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/strands_tools/a2a_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import httpx
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
from a2a.types import AgentCard, Message, Part, PushNotificationConfig, Role, TextPart
from strands import tool
from strands import tool, ToolContext
from strands.types.tools import AgentTool

DEFAULT_TIMEOUT = 300 # set request timeout to 5 minutes
Expand Down Expand Up @@ -250,9 +250,13 @@ async def _list_discovered_agents(self) -> dict[str, Any]:
"total_count": 0,
}

@tool
@tool(context=True)
async def a2a_send_message(
self, message_text: str, target_agent_url: str, message_id: str | None = None
self,
message_text: str,
target_agent_url: str,
message_id: str | None = None,
tool_context: ToolContext | None = None,
) -> dict[str, Any]:
"""
Send a message to a specific A2A agent and return the response.
Expand All @@ -275,10 +279,20 @@ async def a2a_send_message(
- message_id: The message ID used
- target_agent_url: The agent URL that was contacted
"""
return await self._send_message(message_text, target_agent_url, message_id)
# Retrieve a2a_tool_metadata from tool_context.
# The Strands framework automatically injects tool_context when the tool is invoked.
a2a_tool_metadata = None
if tool_context is not None and tool_context.invocation_state:
a2a_tool_metadata = tool_context.invocation_state.get("metadata")

return await self._send_message(message_text, target_agent_url, message_id, a2a_tool_metadata)

async def _send_message(
self, message_text: str, target_agent_url: str, message_id: str | None = None
self,
message_text: str,
target_agent_url: str,
message_id: str | None = None,
a2a_tool_metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Internal async implementation for send_message."""

Expand All @@ -303,7 +317,7 @@ async def _send_message(
logger.info(f"Sending message to {target_agent_url}")

# With streaming=False, this will yield exactly one result
async for event in client.send_message(message):
async for event in client.send_message(message, request_metadata=a2a_tool_metadata):
if isinstance(event, Message):
# Direct message response
return {
Expand Down
136 changes: 130 additions & 6 deletions tests/test_a2a_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from a2a.types import Message
from strands.types.tools import ToolContext

from strands_tools.a2a_client import DEFAULT_TIMEOUT, A2AClientToolProvider

Expand Down Expand Up @@ -328,7 +329,7 @@ async def test_send_message_with_message_id():
result = await provider.a2a_send_message("Hello", "http://test.com", "test_id")

assert result == expected_result
mock_send_message.assert_called_once_with("Hello", "http://test.com", "test_id")
mock_send_message.assert_called_once_with("Hello", "http://test.com", "test_id", None)


@pytest.mark.asyncio
Expand All @@ -348,7 +349,7 @@ async def test_send_message_without_message_id():
result = await provider.a2a_send_message("Hello", "http://test.com")

assert result == expected_result
mock_send_message.assert_called_once_with("Hello", "http://test.com", None)
mock_send_message.assert_called_once_with("Hello", "http://test.com", None, None)


@pytest.mark.asyncio
Expand Down Expand Up @@ -379,7 +380,7 @@ async def test_send_message_success(mock_ensure, mock_factory, mock_discover, mo
mock_response = Mock(spec=Message)
mock_response.model_dump.return_value = {"result": "success"}

async def mock_send_message_iter(message):
async def mock_send_message_iter(message, **kwargs):
yield mock_response

mock_client.send_message = mock_send_message_iter
Expand Down Expand Up @@ -509,7 +510,7 @@ async def test_send_message_task_response(mock_ensure, mock_factory, mock_discov
mock_update_event = Mock()
mock_update_event.model_dump.return_value = {"event": "finished"}

async def mock_send_message_iter(message):
async def mock_send_message_iter(message, **kwargs):
yield (mock_task, mock_update_event)

mock_client.send_message = mock_send_message_iter
Expand Down Expand Up @@ -556,7 +557,7 @@ async def test_send_message_task_response_no_update(mock_ensure, mock_factory, m
mock_task = Mock()
mock_task.model_dump.return_value = {"task_id": "123", "status": "completed"}

async def mock_send_message_iter(message):
async def mock_send_message_iter(message, **kwargs):
yield (mock_task, None)

mock_client.send_message = mock_send_message_iter
Expand All @@ -570,3 +571,126 @@ async def mock_send_message_iter(message):
"target_agent_url": "http://test.com",
}
assert result == expected


@pytest.mark.asyncio
async def test_send_message_with_metadata_from_tool_context():
"""Test a2a_send_message extracts metadata from tool_context and passes to _send_message."""
provider = A2AClientToolProvider()
a2a_tool_metadata = {"session_id": "abc", "priority": "high"}

# Create mock with invocation_state as a dict
mock_tool_context = MagicMock()
mock_tool_context.invocation_state = {"metadata": a2a_tool_metadata}

with patch.object(provider, "_send_message") as mock_send_message:
mock_send_message.return_value = {"status": "success"}

await provider.a2a_send_message("Hello", "http://test.com", "id1", mock_tool_context)

mock_send_message.assert_called_once_with("Hello", "http://test.com", "id1", a2a_tool_metadata)


@pytest.mark.asyncio
@patch("strands_tools.a2a_client.uuid4")
@patch.object(A2AClientToolProvider, "_discover_agent_card")
@patch.object(A2AClientToolProvider, "_get_client_factory")
@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents")
async def test_send_message_metadata_passed_to_client(mock_ensure, mock_factory, mock_discover, mock_uuid):
"""Test _send_message passes metadata as request_metadata to client.send_message."""
provider = A2AClientToolProvider()
a2a_tool_metadata = {"trace_id": "xyz"}

mock_message_uuid = Mock()
mock_message_uuid.hex = "msg_123"
mock_uuid.return_value = mock_message_uuid

mock_agent_card = Mock()
mock_discover.return_value = mock_agent_card

mock_client_factory = Mock()
mock_client = Mock()
mock_factory.return_value = mock_client_factory
mock_client_factory.create.return_value = mock_client

mock_response = Mock(spec=Message)
mock_response.model_dump.return_value = {"result": "ok"}

captured_kwargs = {}

async def mock_send_message_iter(message, **kwargs):
captured_kwargs.update(kwargs)
yield mock_response

mock_client.send_message = mock_send_message_iter

await provider._send_message("Hello", "http://test.com", None, a2a_tool_metadata)

assert captured_kwargs["request_metadata"] == a2a_tool_metadata


@pytest.mark.asyncio
@patch("strands_tools.a2a_client.uuid4")
@patch.object(A2AClientToolProvider, "_discover_agent_card")
@patch.object(A2AClientToolProvider, "_get_client_factory")
@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents")
async def test_send_message_none_metadata_passed_to_client(mock_ensure, mock_factory, mock_discover, mock_uuid):
"""Test _send_message passes None metadata when not provided."""
provider = A2AClientToolProvider()

mock_message_uuid = Mock()
mock_message_uuid.hex = "msg_456"
mock_uuid.return_value = mock_message_uuid

mock_agent_card = Mock()
mock_discover.return_value = mock_agent_card

mock_client_factory = Mock()
mock_client = Mock()
mock_factory.return_value = mock_client_factory
mock_client_factory.create.return_value = mock_client

mock_response = Mock(spec=Message)
mock_response.model_dump.return_value = {"result": "ok"}

captured_kwargs = {}

async def mock_send_message_iter(message, **kwargs):
captured_kwargs.update(kwargs)
yield mock_response

mock_client.send_message = mock_send_message_iter

await provider._send_message("Hello", "http://test.com")

assert captured_kwargs["request_metadata"] is None


@pytest.mark.asyncio
async def test_send_message_without_tool_context():
"""Test a2a_send_message passes None metadata when tool_context is not provided."""
provider = A2AClientToolProvider()

with patch.object(provider, "_send_message") as mock_send_message:
mock_send_message.return_value = {"status": "success"}

await provider.a2a_send_message("Hello", "http://test.com", "id1", None)

mock_send_message.assert_called_once_with("Hello", "http://test.com", "id1", None)


@pytest.mark.asyncio
async def test_send_message_with_empty_tool_context():
"""Test a2a_send_message passes None metadata when tool_context has no metadata."""
provider = A2AClientToolProvider()

# Create mock with invocation_state as empty dict
mock_tool_context = MagicMock()
mock_tool_context.invocation_state = {}

with patch.object(provider, "_send_message") as mock_send_message:
mock_send_message.return_value = {"status": "success"}

await provider.a2a_send_message("Hello", "http://test.com", "id1", mock_tool_context)

mock_send_message.assert_called_once_with("Hello", "http://test.com", "id1", None)
Loading