From 383cadcddbf0e1398ee3e075479584722ab1036a Mon Sep 17 00:00:00 2001 From: Containerized Agent Date: Fri, 13 Feb 2026 19:57:54 +0000 Subject: [PATCH 1/2] feat(plugins): Add Plugin base class with auto-discovery Introduces the Plugin abstraction for the Strands Agents SDK that allows bundling @tool and @hook decorated methods into a single reusable unit. Key changes: - New src/strands/plugins/ package with Plugin base class - Plugin auto-discovers @tool and @hook decorated methods on subclasses - Plugin.tools and Plugin.hooks are filterable lists - Plugin.init_plugin(agent) lifecycle callback for post-registration setup - Agent.__init__ accepts plugins=[] parameter to register plugin tools/hooks - Includes @hook decorator (from PR #1581) for decorator-based hooks - Exports Plugin and hook from top-level strands package Usage: class MyPlugin(Plugin): name = 'my-plugin' @tool def my_tool(self, x: str) -> str: return x @hook def on_invoke(self, event: BeforeInvocationEvent) -> None: pass agent = Agent(plugins=[MyPlugin()]) --- src/strands/__init__.py | 4 + src/strands/agent/agent.py | 18 + src/strands/hooks/__init__.py | 27 +- src/strands/hooks/decorator.py | 286 ++++++++++++++++ src/strands/plugins/__init__.py | 13 + src/strands/plugins/plugin.py | 141 ++++++++ tests/strands/plugins/__init__.py | 0 tests/strands/plugins/test_plugin.py | 296 +++++++++++++++++ .../plugins/test_plugin_agent_integration.py | 309 ++++++++++++++++++ 9 files changed, 1090 insertions(+), 4 deletions(-) create mode 100644 src/strands/hooks/decorator.py create mode 100644 src/strands/plugins/__init__.py create mode 100644 src/strands/plugins/plugin.py create mode 100644 tests/strands/plugins/__init__.py create mode 100644 tests/strands/plugins/test_plugin.py create mode 100644 tests/strands/plugins/test_plugin_agent_integration.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 6026d4240..4c54ae0f7 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,6 +4,8 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy +from .hooks.decorator import hook +from .plugins.plugin import Plugin from .tools.decorator import tool from .types.tools import ToolContext @@ -11,8 +13,10 @@ "Agent", "AgentBase", "agent", + "hook", "models", "ModelRetryStrategy", + "Plugin", "tool", "ToolContext", "types", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 567a92b4a..7c76313f9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -31,6 +31,7 @@ from ..tools._tool_helpers import generate_missing_tool_result_content if TYPE_CHECKING: + from ..plugins.plugin import Plugin from ..tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( @@ -129,6 +130,7 @@ def __init__( structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, + plugins: "list[Plugin] | None" = None, ): """Initialize the Agent with the specified configuration. @@ -186,6 +188,12 @@ def __init__( retry_strategy: Strategy for retrying model calls on throttling or other transient errors. Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. Implement a custom HookProvider for custom retry logic, or pass None to disable retries. + plugins: List of Plugin instances to register with the agent. + Each plugin can contribute tools and hooks via ``@tool`` and ``@hook`` + decorated methods, which are auto-discovered and registered during + agent initialization. An optional ``init_plugin(agent)`` callback is + invoked after registration for any additional setup. + Defaults to None. Raises: ValueError: If agent id contains path separators. @@ -302,6 +310,16 @@ def __init__( if hooks: for hook in hooks: self.hooks.add_hook(hook) + + # Process plugins: register discovered tools, hooks, and call init_plugin + if plugins: + for plugin in plugins: + for plugin_tool in plugin.tools: + self.tool_registry.register_tool(plugin_tool) + for plugin_hook in plugin.hooks: + self.hooks.add_hook(plugin_hook) + plugin.init_plugin(self) + self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) @property diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..ff8494a64 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -5,7 +5,7 @@ built-in SDK components and user code to react to or modify agent behavior through strongly-typed event callbacks. -Example Usage: +Example Usage with Class-Based Hooks: ```python from strands.hooks import HookProvider, HookRegistry from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent @@ -25,10 +25,24 @@ def log_end(self, event: AfterInvocationEvent) -> None: agent = Agent(hooks=[LoggingHooks()]) ``` -This replaces the older callback_handler approach with a more composable, -type-safe system that supports multiple subscribers per event type. +Example Usage with Decorator-Based Hooks: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` + +This module supports both the class-based HookProvider approach and the newer +decorator-based @hook approach for maximum flexibility. """ +from .decorator import DecoratedFunctionHook, hook from .events import ( AfterInvocationEvent, AfterModelCallEvent, @@ -48,6 +62,10 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ + # Decorator + "hook", + "DecoratedFunctionHook", + # Events "AgentInitializedEvent", "BeforeInvocationEvent", "BeforeToolCallEvent", @@ -56,12 +74,13 @@ def log_end(self, event: AfterInvocationEvent) -> None: "AfterModelCallEvent", "AfterInvocationEvent", "MessageAddedEvent", + # Registry "HookEvent", "HookProvider", "HookCallback", "HookRegistry", - "HookEvent", "BaseHookEvent", + # Multi-agent events "AfterMultiAgentInvocationEvent", "AfterNodeCallEvent", "BeforeMultiAgentInvocationEvent", diff --git a/src/strands/hooks/decorator.py b/src/strands/hooks/decorator.py new file mode 100644 index 000000000..22643b233 --- /dev/null +++ b/src/strands/hooks/decorator.py @@ -0,0 +1,286 @@ +"""Hook decorator for defining hooks as functions. + +This module provides the @hook decorator that transforms Python functions into +HookProvider implementations with automatic event type detection from type hints. + +Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` +""" + +import functools +import inspect +import types +from collections.abc import Callable +from dataclasses import dataclass +from typing import ( + Any, + Generic, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) + +from .registry import BaseHookEvent, HookCallback, HookProvider, HookRegistry + +TEvent = TypeVar("TEvent", bound=BaseHookEvent) + + +@dataclass +class HookMetadata: + """Metadata extracted from a decorated hook function. + + Attributes: + name: The name of the hook function. + description: Description extracted from the function's docstring. + event_types: List of event types this hook handles. + is_async: Whether the hook function is async. + """ + + name: str + description: str + event_types: list[type[BaseHookEvent]] + is_async: bool + + +class FunctionHookMetadata: + """Helper class to extract and manage function metadata for hook decoration.""" + + def __init__( + self, + func: Callable[..., Any], + ) -> None: + """Initialize with the function to process. + + Args: + func: The function to extract metadata from. + """ + self.func = func + self.signature = inspect.signature(func) + + # Validate and extract event types + self._event_types = self._resolve_event_types() + self._validate_event_types() + + def _resolve_event_types(self) -> list[type[BaseHookEvent]]: + """Resolve event types from type hints. + + Returns: + List of event types this hook handles. + + Raises: + ValueError: If no event type can be determined. + """ + # Try to extract from type hints + try: + type_hints = get_type_hints(self.func) + except Exception: + # get_type_hints can fail for various reasons (forward refs, etc.) + type_hints = {} + + # Find the first parameter's type hint (should be the event) + # Skip 'self' and 'cls' for class methods + params = list(self.signature.parameters.values()) + event_params = [p for p in params if p.name not in ("self", "cls")] + + if not event_params: + raise ValueError( + f"Hook function '{self.func.__name__}' must have at least one parameter for the event with a type hint." + ) + + first_param = event_params[0] + event_type = type_hints.get(first_param.name) + + if event_type is None: + # Check annotation directly (for cases where get_type_hints fails) + if first_param.annotation is not inspect.Parameter.empty: + event_type = first_param.annotation + else: + raise ValueError(f"Hook function '{self.func.__name__}' must have a type hint for the event parameter.") + + # Handle Union types (e.g., BeforeToolCallEvent | AfterToolCallEvent) + return self._extract_event_types_from_annotation(event_type) + + def _is_union_type(self, annotation: Any) -> bool: + """Check if annotation is a Union type (typing.Union or types.UnionType).""" + origin = get_origin(annotation) + if origin is Union: + return True + + # Python 3.10+ uses types.UnionType for `A | B` syntax + if isinstance(annotation, types.UnionType): + return True + + return False + + def _extract_event_types_from_annotation(self, annotation: Any) -> list[type[BaseHookEvent]]: + """Extract event types from a type annotation.""" + # Handle Union types (Union[A, B] or A | B) + if self._is_union_type(annotation): + args = get_args(annotation) + event_types = [] + for arg in args: + # Skip NoneType in Optional[X] + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, BaseHookEvent): + event_types.append(arg) + else: + raise ValueError(f"All types in Union must be subclasses of BaseHookEvent, got {arg}") + return event_types + + # Single type + if isinstance(annotation, type) and issubclass(annotation, BaseHookEvent): + return [annotation] + + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {annotation}") + + def _validate_event_types(self) -> None: + """Validate that all event types are valid.""" + if not self._event_types: + raise ValueError(f"Hook function '{self.func.__name__}' must handle at least one event type.") + + for event_type in self._event_types: + if not isinstance(event_type, type) or not issubclass(event_type, BaseHookEvent): + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {event_type}") + + def extract_metadata(self) -> HookMetadata: + """Extract metadata from the function to create hook specification.""" + return HookMetadata( + name=self.func.__name__, + description=inspect.getdoc(self.func) or self.func.__name__, + event_types=self._event_types, + is_async=inspect.iscoroutinefunction(self.func), + ) + + @property + def event_types(self) -> list[type[BaseHookEvent]]: + """Get the event types this hook handles.""" + return self._event_types + + +class DecoratedFunctionHook(HookProvider, Generic[TEvent]): + """A HookProvider that wraps a function decorated with @hook.""" + + _func: Callable[[TEvent], Any] + _metadata: FunctionHookMetadata + _hook_metadata: HookMetadata + + def __init__( + self, + func: Callable[[TEvent], Any], + metadata: FunctionHookMetadata, + ): + """Initialize the decorated function hook. + + Args: + func: The original function being decorated. + metadata: The FunctionHookMetadata object with extracted function information. + """ + self._func = func + self._metadata = metadata + self._hook_metadata = metadata.extract_metadata() + + # Preserve function metadata + functools.update_wrapper(wrapper=self, wrapped=self._func) + + def __get__(self, instance: Any, obj_type: type[Any] | None = None) -> "DecoratedFunctionHook[TEvent]": + """Descriptor protocol implementation for proper method binding.""" + if instance is not None and not inspect.ismethod(self._func): + # Create a bound method + bound_func = self._func.__get__(instance, instance.__class__) + return DecoratedFunctionHook(bound_func, self._metadata) + + return self + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callback functions for specific event types.""" + callback = cast(HookCallback[BaseHookEvent], self._func) + for event_type in self._metadata.event_types: + registry.add_callback(event_type, callback) + + def __call__(self, event: TEvent) -> Any: + """Allow direct invocation for testing.""" + return self._func(event) + + @property + def name(self) -> str: + """Get the name of the hook.""" + return self._hook_metadata.name + + @property + def description(self) -> str: + """Get the description of the hook.""" + return self._hook_metadata.description + + @property + def event_types(self) -> list[type[BaseHookEvent]]: + """Get the event types this hook handles.""" + return self._hook_metadata.event_types + + @property + def is_async(self) -> bool: + """Check if this hook is async.""" + return self._hook_metadata.is_async + + def __repr__(self) -> str: + """Return a string representation of the hook.""" + event_names = [e.__name__ for e in self._hook_metadata.event_types] + return f"DecoratedFunctionHook({self._hook_metadata.name}, events={event_names})" + + +# Type variable for the decorated function +F = TypeVar("F", bound=Callable[..., Any]) + + +def hook( + func: F | None = None, +) -> DecoratedFunctionHook[Any] | Callable[[F], DecoratedFunctionHook[Any]]: + """Decorator that transforms a function into a HookProvider. + + The decorated function can be passed directly to Agent(hooks=[...]). + Event types are automatically detected from the function's type hints. + + Args: + func: The function to decorate. + + Returns: + A DecoratedFunctionHook that implements HookProvider. + + Raises: + ValueError: If no event type can be determined from type hints. + ValueError: If event types are not subclasses of BaseHookEvent. + + Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` + """ + + def decorator(f: F) -> DecoratedFunctionHook[Any]: + hook_meta = FunctionHookMetadata(f) + return DecoratedFunctionHook(f, hook_meta) + + if func is None: + return decorator + + return decorator(func) diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py new file mode 100644 index 000000000..e3ed8b781 --- /dev/null +++ b/src/strands/plugins/__init__.py @@ -0,0 +1,13 @@ +"""Plugin system for Strands Agents SDK. + +This package exposes the :class:`Plugin` base class that allows tool and hook +methods to be bundled together and registered with an agent in one step. + +See :mod:`strands.plugins.plugin` for full documentation and examples. +""" + +from .plugin import Plugin + +__all__ = [ + "Plugin", +] diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py new file mode 100644 index 000000000..8a3705608 --- /dev/null +++ b/src/strands/plugins/plugin.py @@ -0,0 +1,141 @@ +"""Plugin base class for Strands Agents SDK. + +This module provides the ``Plugin`` base class that enables self-contained bundles +of tools and hooks to be registered with an agent in a single step. Decorated +methods are auto-discovered so that plugin authors only need to apply the +standard ``@tool`` and ``@hook`` decorators to their methods. + +Example: + ```python + from strands import Agent, Plugin, tool + from strands.hooks import BeforeInvocationEvent, hook + + class GreeterPlugin(Plugin): + name = "greeter" + + @tool + def greet(self, who: str) -> str: + '''Say hello.''' + return f"Hello, {who}!" + + @hook + def log_invocation(self, event: BeforeInvocationEvent) -> None: + '''Log every invocation.''' + print("Invocation starting") + + agent = Agent(plugins=[GreeterPlugin()]) + ``` +""" + +import logging +from typing import TYPE_CHECKING, Any + +from ..hooks.decorator import DecoratedFunctionHook +from ..hooks.registry import HookProvider +from ..tools.decorator import DecoratedFunctionTool + +if TYPE_CHECKING: + from ..agent.agent import Agent + +logger = logging.getLogger(__name__) + + +class Plugin: + """Base class for agent plugins with auto-discovery of ``@tool`` and ``@hook`` methods. + + Subclasses declare tools and hooks by decorating methods with ``@tool`` and + ``@hook``. When a ``Plugin`` instance is constructed, it scans its own + methods and collects the decorated ones into the :pyattr:`tools` and + :pyattr:`hooks` lists. These lists can be filtered or replaced before + the plugin is handed to an ``Agent``. + + Attributes: + name: A human-readable identifier for the plugin. Subclasses should + override this with a meaningful value. + """ + + name: str = "" + + def __init__(self) -> None: + """Initialize the plugin and auto-discover ``@tool`` / ``@hook`` methods.""" + self._tools: list[DecoratedFunctionTool[..., Any]] = [] + self._hooks: list[HookProvider] = [] + self._discover_tools_and_hooks() + + # -- public properties --------------------------------------------------- + + @property + def tools(self) -> list[DecoratedFunctionTool[..., Any]]: + """The list of auto-discovered (or manually set) tools for this plugin.""" + return self._tools + + @tools.setter + def tools(self, value: list[DecoratedFunctionTool[..., Any]]) -> None: + self._tools = value + + @property + def hooks(self) -> list[HookProvider]: + """The list of auto-discovered (or manually set) hooks for this plugin.""" + return self._hooks + + @hooks.setter + def hooks(self, value: list[HookProvider]) -> None: + self._hooks = value + + # -- lifecycle ----------------------------------------------------------- + + def init_plugin(self, agent: "Agent") -> None: + """Optional hook called after the plugin's tools and hooks are registered. + + Override this method to perform additional setup that requires the + fully-constructed agent (e.g. mutating ``agent.system_prompt``). + + Args: + agent: The agent that this plugin has been registered with. + """ + + # -- internals ----------------------------------------------------------- + + def _discover_tools_and_hooks(self) -> None: + """Scan instance methods for ``@tool`` and ``@hook`` decorators. + + The scan iterates over the *class* attributes (via ``type(self)``) so + that descriptors are seen in their raw form. The *instance* attribute + is then fetched to obtain a properly bound version. + """ + seen: set[str] = set() + for cls in type(self).__mro__: + for attr_name, cls_attr in vars(cls).items(): + if attr_name.startswith("_") or attr_name in seen: + continue + seen.add(attr_name) + + if isinstance(cls_attr, DecoratedFunctionTool): + # Accessing through the instance triggers __get__ which + # returns a properly bound DecoratedFunctionTool. + bound_tool: DecoratedFunctionTool[..., Any] = getattr(self, attr_name) + self._tools.append(bound_tool) + logger.debug( + "plugin=%s | discovered tool: %s", + self.name or type(self).__name__, + bound_tool.tool_name, + ) + + elif isinstance(cls_attr, DecoratedFunctionHook): + # Accessing through the instance triggers __get__ which + # returns a properly bound DecoratedFunctionHook. + bound_hook: DecoratedFunctionHook[Any] = getattr(self, attr_name) + self._hooks.append(bound_hook) + logger.debug( + "plugin=%s | discovered hook: %s", + self.name or type(self).__name__, + bound_hook.name, + ) + + def __repr__(self) -> str: + """Return a developer-friendly representation.""" + return ( + f"Plugin(name={self.name!r}, " + f"tools=[{', '.join(t.tool_name for t in self._tools)}], " + f"hooks=[{', '.join(h.name if hasattr(h, 'name') else type(h).__name__ for h in self._hooks)}])" + ) diff --git a/tests/strands/plugins/__init__.py b/tests/strands/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/plugins/test_plugin.py b/tests/strands/plugins/test_plugin.py new file mode 100644 index 000000000..217016bc0 --- /dev/null +++ b/tests/strands/plugins/test_plugin.py @@ -0,0 +1,296 @@ +"""Tests for the Plugin base class auto-discovery and properties.""" + +import unittest.mock + +from strands.hooks import BeforeInvocationEvent, BeforeToolCallEvent +from strands.hooks.decorator import DecoratedFunctionHook, hook +from strands.plugins.plugin import Plugin +from strands.tools.decorator import DecoratedFunctionTool, tool + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +class SimplePlugin(Plugin): + """Plugin with one tool and one hook.""" + + name = "simple" + + @tool + def greet(self, who: str) -> str: + """Say hello.""" + return f"Hello, {who}!" + + @hook + def on_invocation(self, event: BeforeInvocationEvent) -> None: + """Track invocations.""" + + +class ToolOnlyPlugin(Plugin): + """Plugin with tools but no hooks.""" + + name = "tool-only" + + @tool + def add(self, a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + @tool + def multiply(self, a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + +class HookOnlyPlugin(Plugin): + """Plugin with hooks but no tools.""" + + name = "hook-only" + + @hook + def before_invocation(self, event: BeforeInvocationEvent) -> None: + """Before invocation hook.""" + + @hook + def before_tool_call(self, event: BeforeToolCallEvent) -> None: + """Before tool call hook.""" + + +class EmptyPlugin(Plugin): + """Plugin with nothing decorated.""" + + name = "empty" + + def regular_method(self) -> str: + return "not a tool" + + +class MultiEventHookPlugin(Plugin): + """Plugin with a hook that listens to multiple event types.""" + + name = "multi-event" + + @hook + def on_event(self, event: BeforeInvocationEvent | BeforeToolCallEvent) -> None: + """Handle multiple events.""" + + +class InheritedPlugin(SimplePlugin): + """Plugin that inherits from SimplePlugin and adds its own tool.""" + + name = "inherited" + + @tool + def farewell(self, who: str) -> str: + """Say goodbye.""" + return f"Goodbye, {who}!" + + +class InitPlugin(Plugin): + """Plugin that overrides init_plugin.""" + + name = "init-test" + initialized_with: "unittest.mock.Mock | None" = None + + def init_plugin(self, agent: "unittest.mock.Mock") -> None: # type: ignore[override] + self.initialized_with = agent + + +# --------------------------------------------------------------------------- +# Tests: auto-discovery +# --------------------------------------------------------------------------- + + +class TestPluginAutoDiscovery: + """Test that @tool and @hook methods are discovered correctly.""" + + def test_discovers_tools(self) -> None: + plugin = SimplePlugin() + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "greet" + + def test_discovers_hooks(self) -> None: + plugin = SimplePlugin() + assert len(plugin.hooks) == 1 + hook_provider = plugin.hooks[0] + assert isinstance(hook_provider, DecoratedFunctionHook) + assert hook_provider.name == "on_invocation" + + def test_tool_only_plugin(self) -> None: + plugin = ToolOnlyPlugin() + assert len(plugin.tools) == 2 + tool_names = {t.tool_name for t in plugin.tools} + assert tool_names == {"add", "multiply"} + assert len(plugin.hooks) == 0 + + def test_hook_only_plugin(self) -> None: + plugin = HookOnlyPlugin() + assert len(plugin.tools) == 0 + assert len(plugin.hooks) == 2 + hook_names = {h.name for h in plugin.hooks} + assert hook_names == {"before_invocation", "before_tool_call"} + + def test_empty_plugin(self) -> None: + plugin = EmptyPlugin() + assert len(plugin.tools) == 0 + assert len(plugin.hooks) == 0 + + def test_multi_event_hook(self) -> None: + plugin = MultiEventHookPlugin() + assert len(plugin.hooks) == 1 + hook_provider = plugin.hooks[0] + assert isinstance(hook_provider, DecoratedFunctionHook) + assert set(hook_provider.event_types) == {BeforeInvocationEvent, BeforeToolCallEvent} + + def test_inherited_plugin_discovers_parent_and_child(self) -> None: + plugin = InheritedPlugin() + tool_names = {t.tool_name for t in plugin.tools} + assert "greet" in tool_names + assert "farewell" in tool_names + + hook_names = {h.name for h in plugin.hooks} + assert "on_invocation" in hook_names + + +# --------------------------------------------------------------------------- +# Tests: tools are properly bound +# --------------------------------------------------------------------------- + + +class TestPluginToolBinding: + """Test that discovered tools are properly bound to the plugin instance.""" + + def test_tool_is_callable(self) -> None: + plugin = SimplePlugin() + result = plugin.tools[0]("world") + assert result == "Hello, world!" + + def test_each_instance_gets_its_own_tools(self) -> None: + p1 = SimplePlugin() + p2 = SimplePlugin() + # Each instance should have distinct tool lists + assert p1.tools is not p2.tools + + def test_tool_is_decorated_function_tool(self) -> None: + plugin = SimplePlugin() + assert isinstance(plugin.tools[0], DecoratedFunctionTool) + + +# --------------------------------------------------------------------------- +# Tests: hooks are properly bound +# --------------------------------------------------------------------------- + + +class TestPluginHookBinding: + """Test that discovered hooks are properly bound to the plugin instance.""" + + def test_hook_is_hook_provider(self) -> None: + plugin = SimplePlugin() + from strands.hooks.registry import HookProvider + + assert isinstance(plugin.hooks[0], HookProvider) + + def test_hook_is_decorated_function_hook(self) -> None: + plugin = SimplePlugin() + assert isinstance(plugin.hooks[0], DecoratedFunctionHook) + + +# --------------------------------------------------------------------------- +# Tests: property setters (filtering) +# --------------------------------------------------------------------------- + + +class TestPluginPropertySetters: + """Test that tools and hooks lists can be replaced for filtering.""" + + def test_tools_setter(self) -> None: + plugin = ToolOnlyPlugin() + assert len(plugin.tools) == 2 + + # Filter down to just 'add' + plugin.tools = [t for t in plugin.tools if t.tool_name == "add"] + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "add" + + def test_hooks_setter(self) -> None: + plugin = HookOnlyPlugin() + assert len(plugin.hooks) == 2 + + # Filter down to just 'before_invocation' + plugin.hooks = [h for h in plugin.hooks if h.name == "before_invocation"] + assert len(plugin.hooks) == 1 + + def test_tools_setter_empty(self) -> None: + plugin = SimplePlugin() + plugin.tools = [] + assert len(plugin.tools) == 0 + + def test_hooks_setter_empty(self) -> None: + plugin = SimplePlugin() + plugin.hooks = [] + assert len(plugin.hooks) == 0 + + +# --------------------------------------------------------------------------- +# Tests: init_plugin +# --------------------------------------------------------------------------- + + +class TestPluginInit: + """Test the init_plugin lifecycle callback.""" + + def test_init_plugin_default_is_noop(self) -> None: + plugin = SimplePlugin() + # Should not raise + plugin.init_plugin(unittest.mock.Mock()) + + def test_init_plugin_receives_agent(self) -> None: + plugin = InitPlugin() + mock_agent = unittest.mock.Mock() + plugin.init_plugin(mock_agent) + assert plugin.initialized_with is mock_agent + + +# --------------------------------------------------------------------------- +# Tests: name attribute +# --------------------------------------------------------------------------- + + +class TestPluginName: + """Test the name attribute.""" + + def test_name_from_class_attribute(self) -> None: + plugin = SimplePlugin() + assert plugin.name == "simple" + + def test_default_name_is_empty(self) -> None: + class Unnamed(Plugin): + pass + + plugin = Unnamed() + assert plugin.name == "" + + +# --------------------------------------------------------------------------- +# Tests: repr +# --------------------------------------------------------------------------- + + +class TestPluginRepr: + """Test the __repr__ method.""" + + def test_repr_includes_name(self) -> None: + plugin = SimplePlugin() + r = repr(plugin) + assert "simple" in r + + def test_repr_includes_tool_names(self) -> None: + plugin = SimplePlugin() + r = repr(plugin) + assert "greet" in r + + def test_repr_includes_hook_names(self) -> None: + plugin = SimplePlugin() + r = repr(plugin) + assert "on_invocation" in r diff --git a/tests/strands/plugins/test_plugin_agent_integration.py b/tests/strands/plugins/test_plugin_agent_integration.py new file mode 100644 index 000000000..b75a86a1d --- /dev/null +++ b/tests/strands/plugins/test_plugin_agent_integration.py @@ -0,0 +1,309 @@ +"""Integration tests for Plugin registration with Agent.""" + +import unittest.mock +from typing import Any + +from strands.agent.agent import Agent +from strands.hooks import BeforeInvocationEvent, BeforeToolCallEvent +from strands.hooks.decorator import hook +from strands.plugins.plugin import Plugin +from strands.tools.decorator import tool + +# --------------------------------------------------------------------------- +# Test plugins +# --------------------------------------------------------------------------- + + +class MathPlugin(Plugin): + """Plugin that provides math tools.""" + + name = "math" + + @tool + def add(self, a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + @tool + def subtract(self, a: int, b: int) -> int: + """Subtract b from a.""" + return a - b + + +class LoggingPlugin(Plugin): + """Plugin that provides logging hooks.""" + + name = "logging" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + @hook + def log_invocation(self, event: BeforeInvocationEvent) -> None: + """Log invocations.""" + self.invocations.append("invocation_started") + + +class FullPlugin(Plugin): + """Plugin with both tools and hooks.""" + + name = "full" + + def __init__(self) -> None: + super().__init__() + self.init_called = False + + @tool + def echo(self, text: str) -> str: + """Echo the input.""" + return text + + @hook + def before_invocation(self, event: BeforeInvocationEvent) -> None: + """Before invocation.""" + + def init_plugin(self, agent: Any) -> None: + self.init_called = True + + +class SystemPromptPlugin(Plugin): + """Plugin that modifies agent system prompt during init.""" + + name = "system-prompt" + + def init_plugin(self, agent: Any) -> None: + current = agent.system_prompt or "" + agent.system_prompt = current + "\nYou are also a helpful math tutor." + + +# --------------------------------------------------------------------------- +# Tests: tool registration +# --------------------------------------------------------------------------- + + +class TestPluginToolRegistration: + """Test that plugin tools are registered with the agent.""" + + def test_plugin_tools_registered(self) -> None: + plugin = MathPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + assert "add" in agent.tool_names + assert "subtract" in agent.tool_names + + def test_plugin_tools_combined_with_agent_tools(self) -> None: + @tool + def standalone_tool(x: str) -> str: + """A standalone tool.""" + return x + + plugin = MathPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + tools=[standalone_tool], + plugins=[plugin], + ) + assert "standalone_tool" in agent.tool_names + assert "add" in agent.tool_names + assert "subtract" in agent.tool_names + + def test_multiple_plugins_tools_registered(self) -> None: + plugin1 = MathPlugin() + + class StringPlugin(Plugin): + name = "string" + + @tool + def upper(self, text: str) -> str: + """Uppercase text.""" + return text.upper() + + plugin2 = StringPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin1, plugin2], + ) + assert "add" in agent.tool_names + assert "subtract" in agent.tool_names + assert "upper" in agent.tool_names + + +# --------------------------------------------------------------------------- +# Tests: hook registration +# --------------------------------------------------------------------------- + + +class TestPluginHookRegistration: + """Test that plugin hooks are registered with the agent.""" + + def test_plugin_hooks_registered(self) -> None: + plugin = LoggingPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + # Verify the hook is in the registry by checking callbacks exist for the event type + callbacks = list(agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) + # At least one callback should be from our plugin + assert any(True for _ in callbacks), "Expected at least one BeforeInvocationEvent callback" + + def test_plugin_hooks_combined_with_agent_hooks(self) -> None: + @hook + def standalone_hook(event: BeforeToolCallEvent) -> None: + """A standalone hook.""" + + plugin = LoggingPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + hooks=[standalone_hook], + plugins=[plugin], + ) + before_invocation_cbs = list(agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) + before_tool_cbs = list(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) + assert len(before_invocation_cbs) >= 1 + assert len(before_tool_cbs) >= 1 + + +# --------------------------------------------------------------------------- +# Tests: init_plugin +# --------------------------------------------------------------------------- + + +class TestPluginInitCallback: + """Test that init_plugin is called during agent initialization.""" + + def test_init_plugin_called(self) -> None: + plugin = FullPlugin() + assert not plugin.init_called + Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + assert plugin.init_called + + def test_init_plugin_can_modify_agent(self) -> None: + plugin = SystemPromptPlugin() + agent = Agent( + model=unittest.mock.MagicMock(), + system_prompt="You are helpful.", + plugins=[plugin], + ) + assert "math tutor" in agent.system_prompt + + def test_init_plugin_called_for_each_plugin(self) -> None: + init_calls: list[str] = [] + + class P1(Plugin): + name = "p1" + + def init_plugin(self, agent: Any) -> None: + init_calls.append("p1") + + class P2(Plugin): + name = "p2" + + def init_plugin(self, agent: Any) -> None: + init_calls.append("p2") + + Agent( + model=unittest.mock.MagicMock(), + plugins=[P1(), P2()], + ) + assert init_calls == ["p1", "p2"] + + +# --------------------------------------------------------------------------- +# Tests: filtering before passing to agent +# --------------------------------------------------------------------------- + + +class TestPluginFiltering: + """Test that plugin tools/hooks can be filtered before agent creation.""" + + def test_filter_tools_before_agent(self) -> None: + plugin = MathPlugin() + # Keep only 'add' + plugin.tools = [t for t in plugin.tools if t.tool_name == "add"] + + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + assert "add" in agent.tool_names + assert "subtract" not in agent.tool_names + + def test_remove_all_tools_before_agent(self) -> None: + plugin = MathPlugin() + plugin.tools = [] + + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + # None of the plugin tools should be registered + assert "add" not in agent.tool_names + assert "subtract" not in agent.tool_names + + def test_filter_hooks_before_agent(self) -> None: + plugin = LoggingPlugin() + plugin.hooks = [] + + # Hooks list was cleared so no plugin-specific callbacks added + # (there may be built-in callbacks from conversation_manager, retry, etc.) + # Just verify it doesn't crash + Agent( + model=unittest.mock.MagicMock(), + plugins=[plugin], + ) + + +# --------------------------------------------------------------------------- +# Tests: no plugins +# --------------------------------------------------------------------------- + + +class TestNoPlugins: + """Test that agent works normally when no plugins are provided.""" + + def test_none_plugins(self) -> None: + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=None, + ) + # Should work fine, no extra tools + assert agent.tool_names is not None + + def test_empty_plugins_list(self) -> None: + agent = Agent( + model=unittest.mock.MagicMock(), + plugins=[], + ) + assert agent.tool_names is not None + + def test_omitted_plugins(self) -> None: + agent = Agent( + model=unittest.mock.MagicMock(), + ) + assert agent.tool_names is not None + + +# --------------------------------------------------------------------------- +# Tests: top-level imports +# --------------------------------------------------------------------------- + + +class TestTopLevelImports: + """Test that Plugin and hook are importable from top-level strands package.""" + + def test_import_plugin(self) -> None: + from strands import Plugin as P + + assert P is Plugin + + def test_import_hook(self) -> None: + from strands import hook as h + + assert h is hook From 9d3af92b0052840ba44139c087a7032c217dada0 Mon Sep 17 00:00:00 2001 From: Containerized Agent Date: Fri, 13 Feb 2026 20:04:39 +0000 Subject: [PATCH 2/2] feat(plugins): Add **kwargs to init_plugin for forward compatibility Plugin.init_plugin now accepts **kwargs so that future SDK versions can pass additional keyword arguments without breaking existing plugin subclasses. Updated all test overrides to follow the same pattern. --- src/strands/plugins/plugin.py | 7 ++++++- tests/strands/plugins/test_plugin.py | 3 ++- tests/strands/plugins/test_plugin_agent_integration.py | 8 ++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 8a3705608..846930b47 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -84,14 +84,19 @@ def hooks(self, value: list[HookProvider]) -> None: # -- lifecycle ----------------------------------------------------------- - def init_plugin(self, agent: "Agent") -> None: + def init_plugin(self, agent: "Agent", **kwargs: Any) -> None: """Optional hook called after the plugin's tools and hooks are registered. Override this method to perform additional setup that requires the fully-constructed agent (e.g. mutating ``agent.system_prompt``). + Subclasses should always accept ``**kwargs`` so that additional keyword + arguments can be introduced in future SDK versions without breaking + existing plugins. + Args: agent: The agent that this plugin has been registered with. + **kwargs: Reserved for future use. """ # -- internals ----------------------------------------------------------- diff --git a/tests/strands/plugins/test_plugin.py b/tests/strands/plugins/test_plugin.py index 217016bc0..0e5a68c5e 100644 --- a/tests/strands/plugins/test_plugin.py +++ b/tests/strands/plugins/test_plugin.py @@ -1,6 +1,7 @@ """Tests for the Plugin base class auto-discovery and properties.""" import unittest.mock +from typing import Any from strands.hooks import BeforeInvocationEvent, BeforeToolCallEvent from strands.hooks.decorator import DecoratedFunctionHook, hook @@ -93,7 +94,7 @@ class InitPlugin(Plugin): name = "init-test" initialized_with: "unittest.mock.Mock | None" = None - def init_plugin(self, agent: "unittest.mock.Mock") -> None: # type: ignore[override] + def init_plugin(self, agent: "unittest.mock.Mock", **kwargs: Any) -> None: # type: ignore[override] self.initialized_with = agent diff --git a/tests/strands/plugins/test_plugin_agent_integration.py b/tests/strands/plugins/test_plugin_agent_integration.py index b75a86a1d..bfbf4d0fd 100644 --- a/tests/strands/plugins/test_plugin_agent_integration.py +++ b/tests/strands/plugins/test_plugin_agent_integration.py @@ -63,7 +63,7 @@ def echo(self, text: str) -> str: def before_invocation(self, event: BeforeInvocationEvent) -> None: """Before invocation.""" - def init_plugin(self, agent: Any) -> None: + def init_plugin(self, agent: Any, **kwargs: Any) -> None: self.init_called = True @@ -72,7 +72,7 @@ class SystemPromptPlugin(Plugin): name = "system-prompt" - def init_plugin(self, agent: Any) -> None: + def init_plugin(self, agent: Any, **kwargs: Any) -> None: current = agent.system_prompt or "" agent.system_prompt = current + "\nYou are also a helpful math tutor." @@ -199,13 +199,13 @@ def test_init_plugin_called_for_each_plugin(self) -> None: class P1(Plugin): name = "p1" - def init_plugin(self, agent: Any) -> None: + def init_plugin(self, agent: Any, **kwargs: Any) -> None: init_calls.append("p1") class P2(Plugin): name = "p2" - def init_plugin(self, agent: Any) -> None: + def init_plugin(self, agent: Any, **kwargs: Any) -> None: init_calls.append("p2") Agent(