Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/aws-strands/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ag_ui_strands"
version = "0.1.7"
version = "0.1.8"
authors = [
{ name = "AG-UI Contributors" }
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ToolResultContext,
PredictStateMapping,
SessionManagerProvider,
ToolStreamEventHandler,
)

__all__ = [
Expand All @@ -29,5 +30,6 @@
"ToolResultContext",
"PredictStateMapping",
"SessionManagerProvider",
"ToolStreamEventHandler",
]

20 changes: 18 additions & 2 deletions integrations/aws-strands/python/src/ag_ui_strands/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,25 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[Any]:
elif "tool_stream_event" in event:
tool_stream = event["tool_stream_event"]
stream_data = tool_stream.get("data", {})
_tse_tool_use = tool_stream.get("tool_use", {})
_tse_tool_name = _tse_tool_use.get("name", "")
_tse_tool_use_id = _tse_tool_use.get("toolUseId", "")
_tse_behavior = self.config.tool_behaviors.get(_tse_tool_name) if _tse_tool_name else None

# Emit state snapshot if tool yielded state
if isinstance(stream_data, dict) and "state" in stream_data:
if _tse_behavior and _tse_behavior.tool_stream_event_handler:
try:
async for _tse_event in _tse_behavior.tool_stream_event_handler(
_tse_tool_use_id, stream_data
):
if _tse_event is not None:
yield _tse_event
except Exception as _tse_exc:
logger.warning(
f"tool_stream_event_handler failed for {_tse_tool_name}: {_tse_exc}",
exc_info=True,
)
elif isinstance(stream_data, dict) and "state" in stream_data:
# Default behaviour: emit state snapshot when tool yields {"state": ...}
yield StateSnapshotEvent(
type=EventType.STATE_SNAPSHOT,
snapshot=stream_data["state"],
Expand Down
12 changes: 12 additions & 0 deletions integrations/aws-strands/python/src/ag_ui_strands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ class ToolResultContext(ToolCallContext):
CustomResultHandler = Callable[[ToolResultContext], AsyncIterator[Any]]
StateContextBuilder = Callable[[RunAgentInput, str], str]
SessionManagerProvider = Callable[[RunAgentInput], Awaitable[Optional[SessionManager]] | Optional[SessionManager]]
ToolStreamEventHandler = Callable[[str, Any], AsyncIterator[Any]]
"""Handler for raw tool_stream_event data emitted by async-generator tools.

Called with (tool_use_id: str, stream_data: Any) for every intermediate event
yielded by the tool while it is executing. The handler may yield zero or more
AG-UI Event objects which are forwarded directly into the top-level event stream.

When a handler is registered for a tool, the default behaviour of emitting a
StateSnapshotEvent for ``{"state": ...}`` payloads is suppressed for that tool.
The handler is responsible for any state updates it wants to emit.
"""


@dataclass
Expand Down Expand Up @@ -78,6 +89,7 @@ class ToolBehavior:
state_from_args: Optional[StateFromArgs] = None
state_from_result: Optional[StateFromResult] = None
custom_result_handler: Optional[CustomResultHandler] = None
tool_stream_event_handler: Optional[ToolStreamEventHandler] = None


@dataclass
Expand Down