diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index 9dfe7cc752..96ffd2a303 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -14,6 +14,10 @@ from __future__ import annotations +import inspect +from typing import Any +from typing import Callable + from google.genai import types from typing_extensions import override @@ -21,6 +25,7 @@ from .function_tool import FunctionTool from .tool_configs import BaseToolConfig from .tool_configs import ToolArgsConfig +from .tool_context import ToolContext try: from crewai.tools import BaseTool as CrewaiBaseTool @@ -61,6 +66,94 @@ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str): elif tool.description: self.description = tool.description + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Override run_async to handle CrewAI-specific parameter filtering. + + CrewAI tools use **kwargs pattern, so we need special parameter filtering + logic that allows all parameters to pass through while removing only + reserved parameters like 'self' and 'tool_context'. + + Note: 'tool_context' is removed from the initial args dictionary to prevent + duplicates, but is re-added if the function signature explicitly requires it + as a parameter. + """ + # Preprocess arguments (includes Pydantic model conversion) + args_to_call = self._preprocess_args(args) + + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + + # Check if function accepts **kwargs + has_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + + if has_kwargs: + # For functions with **kwargs, we pass all arguments. We defensively + # remove arguments like `self` that are managed by the framework and not + # intended to be passed through **kwargs. + args_to_call.pop('self', None) + # We also remove `tool_context` that might have been passed in `args`, + # as it will be explicitly injected later if it's a valid parameter. + args_to_call.pop('tool_context', None) + else: + # For functions without **kwargs, use the original filtering. + args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} + + # Inject tool_context if it's an explicit parameter. This will add it + # or overwrite any value that might have been passed in `args`. + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + # Check for missing mandatory arguments + mandatory_args = self._get_mandatory_args() + missing_mandatory_args = [ + arg for arg in mandatory_args if arg not in args_to_call + ] + + if missing_mandatory_args: + missing_mandatory_args_str = '\n'.join(missing_mandatory_args) + error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present: +{missing_mandatory_args_str} +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + return {'error': error_str} + + # Handle tool confirmation if required + if isinstance(self._require_confirmation, Callable): + require_confirmation = await self._invoke_callable( + self._require_confirmation, args_to_call + ) + else: + require_confirmation = bool(self._require_confirmation) + + if require_confirmation: + if not tool_context.tool_confirmation: + args_to_show = args_to_call.copy() + if 'tool_context' in args_to_show: + args_to_show.pop('tool_context') + + tool_context.request_confirmation( + hint=( + f'Please approve or reject the tool call {self.name}() by' + ' responding with a FunctionResponse with an expected' + ' ToolConfirmation payload.' + ), + ) + return { + 'error': ( + 'This tool call requires confirmation, please approve or' + ' reject.' + ) + } + elif not tool_context.tool_confirmation.confirmed: + return {'error': 'This tool call is rejected.'} + + return await self._invoke_callable(self.func, args_to_call) + @override def _get_declaration(self) -> types.FunctionDeclaration: """Build the function declaration for the tool.""" diff --git a/tests/unittests/tools/test_crewai_tool.py b/tests/unittests/tools/test_crewai_tool.py new file mode 100644 index 0000000000..112f26515f --- /dev/null +++ b/tests/unittests/tools/test_crewai_tool.py @@ -0,0 +1,173 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.sessions.session import Session +from google.adk.tools.crewai_tool import CrewaiTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +@pytest.fixture +def mock_tool_context() -> ToolContext: + """Fixture that provides a mock ToolContext for testing.""" + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + return ToolContext(invocation_context=mock_invocation_context) + + +def _simple_crewai_tool(*args, **kwargs): + """Simple CrewAI-style tool that accepts any keyword arguments.""" + return { + "search_query": kwargs.get("search_query"), + "other_param": kwargs.get("other_param"), + } + + +def _crewai_tool_with_context(tool_context: ToolContext, *args, **kwargs): + """CrewAI tool with explicit tool_context parameter.""" + return { + "search_query": kwargs.get("search_query"), + "tool_context_present": bool(tool_context), + } + + +class MockCrewaiBaseTool: + """Mock CrewAI BaseTool for testing.""" + + def __init__(self, run_func, name="mock_tool", description="Mock tool"): + self.run = run_func + self.name = name + self.description = description + self.args_schema = MagicMock() + self.args_schema.model_json_schema.return_value = { + "type": "object", + "properties": { + "search_query": { + "type": "string", + "description": "Search query" + } + } + } + + +def test_crewai_tool_initialization(): + """Test CrewaiTool initialization with various parameters.""" + mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool) + + # Test with custom name and description + tool = CrewaiTool( + mock_crewai_tool, + name="custom_search_tool", + description="Custom search tool description" + ) + + assert tool.name == "custom_search_tool" + assert tool.description == "Custom search tool description" + assert tool.tool == mock_crewai_tool + + +def test_crewai_tool_initialization_with_tool_defaults(): + """Test CrewaiTool initialization using tool's default name and description.""" + mock_crewai_tool = MockCrewaiBaseTool( + _simple_crewai_tool, + name="Serper Dev Tool", + description="Search the internet with Serper" + ) + + # Test with empty name and description (should use tool defaults) + tool = CrewaiTool(mock_crewai_tool, name="", description="") + + assert tool.name == "serper_dev_tool" # Spaces replaced with underscores, lowercased + assert tool.description == "Search the internet with Serper" + + +@pytest.mark.asyncio +async def test_crewai_tool_basic_functionality(mock_tool_context): + """Test basic CrewaiTool functionality with **kwargs parameter passing.""" + mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool) + tool = CrewaiTool(mock_crewai_tool, name="test_tool", description="Test tool") + + # Test that **kwargs parameters are passed through correctly + result = await tool.run_async( + args={"search_query": "test query", "other_param": "test value"}, + tool_context=mock_tool_context, + ) + + assert result["search_query"] == "test query" + assert result["other_param"] == "test value" + + +@pytest.mark.asyncio +async def test_crewai_tool_with_tool_context(mock_tool_context): + """Test CrewaiTool with a tool that has explicit tool_context parameter.""" + mock_crewai_tool = MockCrewaiBaseTool(_crewai_tool_with_context) + tool = CrewaiTool(mock_crewai_tool, name="context_tool", description="Context tool") + + # Test that tool_context is properly injected + result = await tool.run_async( + args={"search_query": "test query"}, + tool_context=mock_tool_context, + ) + + assert result["search_query"] == "test query" + assert result["tool_context_present"] is True + + +@pytest.mark.asyncio +async def test_crewai_tool_parameter_filtering(mock_tool_context): + """Test that CrewaiTool filters parameters for non-**kwargs functions.""" + + def explicit_params_func(arg1: str, arg2: int): + """Function with explicit parameters (no **kwargs).""" + return {"arg1": arg1, "arg2": arg2} + + mock_crewai_tool = MockCrewaiBaseTool(explicit_params_func) + tool = CrewaiTool(mock_crewai_tool, name="explicit_tool", description="Explicit tool") + + # Test that unexpected parameters are filtered out + result = await tool.run_async( + args={ + "arg1": "test", + "arg2": 42, + "unexpected_param": "should_be_filtered" + }, + tool_context=mock_tool_context, + ) + + assert result == {"arg1": "test", "arg2": 42} + # Verify unexpected parameter was filtered out + assert "unexpected_param" not in result + + +@pytest.mark.asyncio +async def test_crewai_tool_get_declaration(): + """Test that CrewaiTool properly builds function declarations.""" + mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool) + tool = CrewaiTool(mock_crewai_tool, name="test_tool", description="Test tool") + + # Test function declaration generation + declaration = tool._get_declaration() + + # Verify the declaration object structure and content + assert declaration is not None + assert declaration.name == "test_tool" + assert declaration.description == "Test tool" + assert declaration.parameters is not None + + # Verify that the args_schema was used to build the declaration + mock_crewai_tool.args_schema.model_json_schema.assert_called_once() \ No newline at end of file diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index e7854a2c87..0469b7ac23 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -22,6 +22,17 @@ import pytest +@pytest.fixture +def mock_tool_context() -> ToolContext: + """Fixture that provides a mock ToolContext for testing.""" + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + return ToolContext(invocation_context=mock_invocation_context) + + + + def function_for_testing_with_no_args(): """Function for testing with no args.""" pass @@ -394,3 +405,32 @@ def sample_func(arg1: str): tool_context=tool_context_mock, ) assert result == {"received_arg": "hello"} + + + + +@pytest.mark.asyncio +async def test_run_async_parameter_filtering(mock_tool_context): + """Test that parameter filtering works correctly for functions with explicit parameters.""" + + def explicit_params_func(arg1: str, arg2: int): + """Function with explicit parameters (no **kwargs).""" + return {"arg1": arg1, "arg2": arg2} + + tool = FunctionTool(explicit_params_func) + + # Test that unexpected parameters are still filtered out for non-kwargs functions + result = await tool.run_async( + args={ + "arg1": "test", + "arg2": 42, + "unexpected_param": "should_be_filtered" + }, + tool_context=mock_tool_context, + ) + + assert result == {"arg1": "test", "arg2": 42} + # Explicitly verify that unexpected_param was filtered out and not passed to the function + assert "unexpected_param" not in result + +