Skip to content
5 changes: 4 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,11 @@ strands-agents/
│ │
│ ├── plugins/ # Plugin system
│ │ ├── plugin.py # Plugin base class
│ │ ├── multiagent_plugin.py # MultiAgentPlugin base class
│ │ ├── decorator.py # @hook decorator
│ │ └── registry.py # PluginRegistry for tracking plugins
│ │ ├── registry.py # PluginRegistry for tracking agent plugins
│ │ ├── multiagent_registry.py # Registry for tracking orchestrator plugins
│ │ └── _discovery.py # Shared hook/tool discovery utilities
│ │
│ ├── handlers/ # Event handlers
│ │ └── callback_handler.py # Callback handling
Expand Down
3 changes: 2 additions & 1 deletion src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .agent.agent import Agent
from .agent.base import AgentBase
from .event_loop._retry import ModelRetryStrategy
from .plugins import Plugin
from .plugins import MultiAgentPlugin, Plugin
from .tools.decorator import tool
from .types._snapshot import Snapshot
from .types.tools import ToolContext
Expand All @@ -17,6 +17,7 @@
"agent",
"models",
"ModelRetryStrategy",
"MultiAgentPlugin",
"Plugin",
"Skill",
"Snapshot",
Expand Down
15 changes: 15 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .._async import run_async
from ..agent import AgentResult
from ..hooks.registry import HookCallback
from ..interrupt import Interrupt
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput
Expand Down Expand Up @@ -254,6 +255,20 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
"""Restore orchestrator state from a session dict."""
raise NotImplementedError

def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
Comment thread
zastrowm marked this conversation as resolved.
"""Register a hook callback with the orchestrator.

Subclasses that support hooks should override this method to register
the callback with their hook registry.

Args:
callback: The callback function to invoke when events of this type occur.
event_type: The class type(s) of events this callback should handle.
Can be a single type, a list of types, or None to infer from
the callback's first parameter type hint.
"""
raise NotImplementedError(f"{type(self).__name__} must implement add_hook() to support plugins")

def _parse_trace_attributes(
self, attributes: Mapping[str, AttributeValue] | None = None
) -> dict[str, AttributeValue]:
Expand Down
33 changes: 32 additions & 1 deletion src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
BeforeNodeCallEvent,
MultiAgentInitializedEvent,
)
from ..hooks.registry import HookProvider, HookRegistry
from ..hooks.registry import HookCallback, HookProvider, HookRegistry
from ..interrupt import Interrupt, _InterruptState
from ..plugins.multiagent_plugin import MultiAgentPlugin
from ..plugins.multiagent_registry import _MultiAgentPluginRegistry
from ..session import SessionManager
from ..telemetry import get_tracer
from ..types._events import (
Expand Down Expand Up @@ -253,6 +255,7 @@ def __init__(self) -> None:
self._id: str = _DEFAULT_GRAPH_ID
self._session_manager: SessionManager | None = None
self._hooks: list[HookProvider] | None = None
self._plugins: list[MultiAgentPlugin] | None = None

def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode:
"""Add an AgentBase or MultiAgentBase instance as a node to the graph."""
Expand Down Expand Up @@ -370,6 +373,15 @@ def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
self._hooks = hooks
return self

def set_plugins(self, plugins: list[MultiAgentPlugin]) -> "GraphBuilder":
"""Set plugins for the graph.

Args:
plugins: List of multi-agent plugins for extending graph behavior
"""
self._plugins = plugins
return self

def build(self) -> "Graph":
"""Build and validate the graph with configured settings."""
if not self.nodes:
Expand Down Expand Up @@ -398,6 +410,7 @@ def build(self) -> "Graph":
session_manager=self._session_manager,
hooks=self._hooks,
id=self._id,
plugins=self._plugins,
)

def _validate_graph(self) -> None:
Expand Down Expand Up @@ -429,6 +442,7 @@ def __init__(
hooks: list[HookProvider] | None = None,
id: str = _DEFAULT_GRAPH_ID,
trace_attributes: Mapping[str, AttributeValue] | None = None,
plugins: list[MultiAgentPlugin] | None = None,
) -> None:
"""Initialize Graph with execution limits and reset behavior.

Expand All @@ -444,6 +458,7 @@ def __init__(
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
id: Unique graph id (default: None)
trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None)
plugins: List of multi-agent plugins for extending graph behavior (default: None)
"""
super().__init__()

Expand All @@ -469,12 +484,28 @@ def __init__(
for hook in hooks:
self.hooks.add_hook(hook)

self._plugin_registry = _MultiAgentPluginRegistry(self)
if plugins:
for plugin in plugins:
self._plugin_registry.add_and_init(plugin)

self._resume_next_nodes: list[GraphNode] = []
self._resume_from_session = False
self.id = id

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
"""Register a hook callback with the graph.

Args:
callback: The callback function to invoke when events of this type occur.
event_type: The class type(s) of events this callback should handle.
Can be a single type, a list of types, or None to infer from
the callback's first parameter type hint.
"""
self.hooks.add_callback(event_type, callback)

def __call__(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
Expand Down
22 changes: 21 additions & 1 deletion src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
BeforeNodeCallEvent,
MultiAgentInitializedEvent,
)
from ..hooks.registry import HookProvider, HookRegistry
from ..hooks.registry import HookCallback, HookProvider, HookRegistry
from ..interrupt import Interrupt, _InterruptState
from ..plugins.multiagent_plugin import MultiAgentPlugin
from ..plugins.multiagent_registry import _MultiAgentPluginRegistry
from ..session import SessionManager
from ..telemetry import get_tracer
from ..tools.decorator import tool
Expand Down Expand Up @@ -249,6 +251,7 @@ def __init__(
hooks: list[HookProvider] | None = None,
id: str = _DEFAULT_SWARM_ID,
trace_attributes: Mapping[str, AttributeValue] | None = None,
plugins: list[MultiAgentPlugin] | None = None,
) -> None:
"""Initialize Swarm with agents and configuration.

Expand All @@ -267,6 +270,7 @@ def __init__(
session_manager: Session manager for persisting graph state and execution history (default: None)
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None)
plugins: List of multi-agent plugins for extending swarm behavior (default: None)
"""
super().__init__()
self.id = id
Expand Down Expand Up @@ -299,12 +303,28 @@ def __init__(
if self.session_manager:
self.hooks.add_hook(self.session_manager)

self._plugin_registry = _MultiAgentPluginRegistry(self)
if plugins:
for plugin in plugins:
self._plugin_registry.add_and_init(plugin)

self._resume_from_session = False

self._setup_swarm(nodes)
self._inject_swarm_tools()
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
"""Register a hook callback with the swarm.

Args:
callback: The callback function to invoke when events of this type occur.
event_type: The class type(s) of events this callback should handle.
Can be a single type, a list of types, or None to infer from
the callback's first parameter type hint.
"""
self.hooks.add_callback(event_type, callback)

def __call__(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> SwarmResult:
Expand Down
7 changes: 5 additions & 2 deletions src/strands/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Plugin system for extending agent functionality.
"""Plugin system for extending agent and orchestrator functionality.

This module provides a composable mechanism for building objects that can
extend agent behavior through automatic hook and tool registration.
extend agent and multi-agent orchestrator behavior through automatic hook
and tool registration.
"""

from .decorator import hook
from .multiagent_plugin import MultiAgentPlugin
from .plugin import Plugin

__all__ = [
"MultiAgentPlugin",
"Plugin",
"hook",
]
103 changes: 103 additions & 0 deletions src/strands/plugins/_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Shared utility for discovering decorated methods on plugin instances.

This module provides helper functions used by both Plugin and MultiAgentPlugin
to scan for @hook (and optionally @tool) decorated methods, and shared registry
utilities for plugin initialization and hook registration.
"""

import inspect
import logging
from collections.abc import Awaitable, Callable
from typing import Any, cast

from .._async import run_async
from ..hooks.registry import HookCallback
from ..tools.decorator import DecoratedFunctionTool

logger = logging.getLogger(__name__)


def _discover_methods(instance: object, plugin_name: str, predicate: Callable[[object], bool], label: str) -> list[Any]:
"""Scan an instance's class hierarchy for methods matching a predicate.

Walks the MRO in reverse so parent class methods come first, but child
overrides win (only the child's version is included).

Args:
instance: The plugin instance to scan.
plugin_name: The plugin name (used for debug logging).
predicate: Function that returns True for attributes to collect.
label: Label for debug logging (e.g., "hook", "tool").

Returns:
List of matching bound methods/descriptors in declaration order.
"""
results: list[Any] = []
seen: set[str] = set()

for cls in reversed(type(instance).__mro__):
for attr_name in cls.__dict__:
if attr_name in seen:
continue
seen.add(attr_name)

try:
bound = getattr(instance, attr_name)
except Exception:
continue

if predicate(bound):
results.append(bound)
logger.debug("plugin=<%s>, %s=<%s> | discovered", plugin_name, label, attr_name)

return results


def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]:
"""Scan an instance's class hierarchy for @hook decorated methods.

Args:
instance: The plugin instance to scan.
plugin_name: The plugin name (used for debug logging).

Returns:
List of bound hook callback methods in declaration order.
"""
return _discover_methods(
instance,
plugin_name,
predicate=lambda bound: hasattr(bound, "_hook_event_types") and callable(bound),
label="hook",
)


def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]:
"""Scan an instance's class hierarchy for @tool decorated methods.

Args:
instance: The plugin instance to scan.
plugin_name: The plugin name (used for debug logging).

Returns:
List of DecoratedFunctionTool instances in declaration order.
"""
return _discover_methods(
instance,
plugin_name,
predicate=lambda bound: isinstance(bound, DecoratedFunctionTool),
label="tool",
)


def call_init_method(init_method: Callable[..., Any], target: Any) -> None:
"""Call a plugin's init method, handling both sync and async implementations.

Args:
init_method: The init_agent or init_multi_agent method to call.
target: The agent or orchestrator instance to pass to the init method.
"""
if inspect.iscoroutinefunction(init_method):
async_init = cast(Callable[..., Awaitable[None]], init_method)
run_async(lambda: async_init(target))
else:
init_method(target)
Loading
Loading