diff --git a/pydantic_ai_slim/pydantic_ai/_mcp.py b/pydantic_ai_slim/pydantic_ai/_mcp.py index f00230ee07..1729e4c225 100644 --- a/pydantic_ai_slim/pydantic_ai/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/_mcp.py @@ -91,7 +91,7 @@ def add_msg( 'user', mcp_types.ImageContent( type='image', - data=base64.b64decode(chunk.data).decode(), + data=base64.b64encode(chunk.data).decode(), mimeType=chunk.media_type, ), ) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ac3cfeae5c..227b8e1399 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -36,6 +36,7 @@ from mcp.shared import exceptions as mcp_exceptions from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage + from mcp.shared.session import RequestResponder except ImportError as _import_error: raise ImportError( 'Please install the `mcp` package to use the MCP server, ' @@ -226,12 +227,21 @@ class ServerCapabilities: prompts: bool = False """Whether the server offers any prompt templates.""" + prompts_list_changed: bool = False + """Whether the server will emit notifications when the list of prompts changes.""" + resources: bool = False """Whether the server offers any resources to read.""" + resources_list_changed: bool = False + """Whether the server will emit notifications when the list of resources changes.""" + tools: bool = False """Whether the server offers any tools to call.""" + tools_list_changed: bool = False + """Whether the server will emit notifications when the list of tools changes.""" + completions: bool = False """Whether the server offers autocompletion suggestions for prompts and resources.""" @@ -244,12 +254,18 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC Args: mcp_capabilities: The MCP SDK ServerCapabilities object. """ + prompts_cap = mcp_capabilities.prompts + resources_cap = mcp_capabilities.resources + tools_cap = mcp_capabilities.tools return cls( experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None, logging=mcp_capabilities.logging is not None, - prompts=mcp_capabilities.prompts is not None, - resources=mcp_capabilities.resources is not None, - tools=mcp_capabilities.tools is not None, + prompts=prompts_cap is not None, + prompts_list_changed=bool(prompts_cap.listChanged) if prompts_cap else False, + resources=resources_cap is not None, + resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False, + tools=tools_cap is not None, + tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False, completions=mcp_capabilities.completions is not None, ) @@ -319,6 +335,26 @@ class MCPServer(AbstractToolset[Any], ABC): elicitation_callback: ElicitationFnT | None = None """Callback function to handle elicitation requests from the server.""" + cache_tools: bool + """Whether to cache the list of tools. + + When enabled (default), tools are fetched once and cached until either: + - The server sends a `notifications/tools/list_changed` notification + - The connection is closed + + Set to `False` for servers that change tools dynamically without sending notifications. + """ + + cache_resources: bool + """Whether to cache the list of resources. 
+ + When enabled (default), resources are fetched once and cached until either: + - The server sends a `notifications/resources/list_changed` notification + - The connection is closed + + Set to `False` for servers that change resources dynamically without sending notifications. + """ + _id: str | None _enter_lock: Lock = field(compare=False) @@ -332,6 +368,9 @@ class MCPServer(AbstractToolset[Any], ABC): _server_capabilities: ServerCapabilities _instructions: str | None + _cached_tools: list[mcp_types.Tool] | None + _cached_resources: list[Resource] | None + def __init__( self, tool_prefix: str | None = None, @@ -344,6 +383,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, *, id: str | None = None, ): @@ -357,6 +398,8 @@ def __init__( self.sampling_model = sampling_model self.max_retries = max_retries self.elicitation_callback = elicitation_callback + self.cache_tools = cache_tools + self.cache_resources = cache_resources self._id = id or tool_prefix @@ -366,6 +409,8 @@ def __post_init__(self): self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None + self._cached_tools = None + self._cached_resources = None @abstractmethod @asynccontextmanager @@ -430,13 +475,22 @@ def instructions(self) -> str | None: async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. - Note: - - We don't cache tools as they might change. - - We also don't subscribe to the server to avoid complexity. + Tools are cached by default, with cache invalidation on: + - `notifications/tools/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_tools=False` for servers that change tools without sending notifications. """ - async with self: # Ensure server is running - result = await self._client.list_tools() - return result.tools + async with self: + if self.cache_tools: + if self._cached_tools is not None: + return self._cached_tools + result = await self._client.list_tools() + self._cached_tools = result.tools + return result.tools + else: + result = await self._client.list_tools() + return result.tools async def direct_call_tool( self, @@ -542,21 +596,31 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: async def list_resources(self) -> list[Resource]: """Retrieve resources that are currently present on the server. - Note: - - We don't cache resources as they might change. - - We also don't subscribe to resource changes to avoid complexity. + Resources are cached by default, with cache invalidation on: + - `notifications/resources/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_resources=False` for servers that change resources without sending notifications. Raises: MCPError: If the server returns an error. 
""" - async with self: # Ensure server is running + async with self: if not self.capabilities.resources: return [] try: - result = await self._client.list_resources() + if self.cache_resources: + if self._cached_resources is not None: + return self._cached_resources + result = await self._client.list_resources() + resources = [Resource.from_mcp_sdk(r) for r in result.resources] + self._cached_resources = resources + return resources + else: + result = await self._client.list_resources() + return [Resource.from_mcp_sdk(r) for r in result.resources] except mcp_exceptions.McpError as e: raise MCPError.from_mcp_sdk(e) from e - return [Resource.from_mcp_sdk(r) for r in result.resources] async def list_resource_templates(self) -> list[ResourceTemplate]: """Retrieve resource templates that are currently present on the server. @@ -628,6 +692,7 @@ async def __aenter__(self) -> Self: elicitation_callback=self.elicitation_callback, logging_callback=self.log_handler, read_timeout_seconds=timedelta(seconds=self.read_timeout), + message_handler=self._handle_notification, ) self._client = await exit_stack.enter_async_context(client) @@ -651,6 +716,8 @@ async def __aexit__(self, *args: Any) -> bool | None: if self._running_count == 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None + self._cached_tools = None + self._cached_resources = None @property def is_running(self) -> bool: @@ -680,6 +747,19 @@ async def _sampling_callback( model=self.sampling_model.model_name, ) + async def _handle_notification( + self, + message: RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] + | mcp_types.ServerNotification + | Exception, + ) -> None: + """Handle notifications from the MCP server, invalidating caches as needed.""" + if isinstance(message, mcp_types.ServerNotification): # pragma: no branch + if isinstance(message.root, mcp_types.ToolListChangedNotification): + self._cached_tools = None + elif isinstance(message.root, mcp_types.ResourceListChangedNotification): + self._cached_resources = None + async def _map_tool_result_part( self, part: mcp_types.ContentBlock ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: @@ -776,6 +856,8 @@ class MCPServerStdio(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -794,6 +876,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, id: str | None = None, ): """Build a new MCP server. @@ -813,6 +897,10 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. + See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. + cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. 
""" self.command = command @@ -831,6 +919,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) @@ -930,6 +1020,8 @@ class _MCPServerHTTP(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -948,6 +1040,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, **_deprecated_kwargs: Any, ): """Build a new MCP server. @@ -967,6 +1061,10 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. + See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. + cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. """ if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: @@ -997,6 +1095,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 8ba9b9997f..cd042dc76a 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -235,6 +235,19 @@ async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> s return f'User {result.action}ed the elicitation' +async def hidden_tool() -> str: + """A tool that is hidden by default.""" + return 'I was hidden!' + + +@mcp.tool() +async def enable_hidden_tool(ctx: Context[ServerSession, None]) -> str: + """Enable the hidden tool, triggering a ToolListChangedNotification.""" + mcp._tool_manager.add_tool(hidden_tool) # pyright: ignore[reportPrivateUsage] + await ctx.session.send_tool_list_changed() + return 'Hidden tool enabled' + + @mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] async def set_logging_level(level: str) -> None: global log_level diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 221ad37548..d53ad47551 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -57,7 +57,7 @@ TextContent, ) - from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response + from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response, map_from_pai_messages from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import OpenAIChatModel @@ -95,7 +95,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -156,7 +156,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) async def test_process_tool_call(run_context: 
RunContext[int]) -> int:
@@ -1525,6 +1525,49 @@ def test_map_from_mcp_params_model_response():
     )


+def test_map_from_pai_messages_with_binary_content():
+    """Test that map_from_pai_messages correctly converts binary image content to MCP format.
+
+    Note: `data` holds raw bytes (e.g. b'raw_image_bytes'), not a base64 string;
+    map_from_pai_messages base64-encodes it to produce the string MCP expects.
+    """
+
+    message = ModelRequest(
+        parts=[
+            UserPromptPart(content='text message'),
+            UserPromptPart(content=[BinaryContent(data=b'raw_image_bytes', media_type='image/png')]),
+            # TODO uncomment when audio content is supported
+            # UserPromptPart(content=[BinaryContent(data=b'raw_audio_bytes', media_type='audio/wav'), 'text after audio']),
+        ]
+    )
+    system_prompt, sampling_msgs = map_from_pai_messages([message])
+    assert system_prompt == ''
+    assert [m.model_dump(by_alias=True) for m in sampling_msgs] == snapshot(
+        [
+            {'role': 'user', 'content': {'type': 'text', 'text': 'text message', 'annotations': None, '_meta': None}},
+            {
+                'role': 'user',
+                'content': {
+                    'type': 'image',
+                    'data': 'cmF3X2ltYWdlX2J5dGVz',
+                    'mimeType': 'image/png',
+                    'annotations': None,
+                    '_meta': None,
+                },
+            },
+        ]
+    )
+
+    # Unsupported content type raises NotImplementedError
+    message_with_video = ModelRequest(
+        parts=[UserPromptPart(content=[BinaryContent(data=b'raw_video_bytes', media_type='video/mp4')])]
+    )
+    with pytest.raises(
+        NotImplementedError, match="Unsupported content type: "
+    ):
+        map_from_pai_messages([message_with_video])
+
+
 def test_map_from_model_response():
     with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'):
         map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')]))
@@ -2007,3 +2050,139 @@ async def test_custom_http_client_not_closed():
         assert len(tools) > 0

     assert not custom_http_client.is_closed
+
+
+# ============================================================================
+# Tool and Resource Caching Tests
+# ============================================================================
+
+
+async def test_tools_caching_enabled_by_default() -> None:
+    """Test that list_tools() caches results by default."""
+    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+    async with server:
+        # First call - should fetch from server and cache
+        tools1 = await server.list_tools()
+        assert len(tools1) > 0
+        assert server._cached_tools is not None  # pyright: ignore[reportPrivateUsage]
+
+        # Second call - should return cached value (cache is still populated)
+        tools2 = await server.list_tools()
+        assert tools2 == tools1
+        assert server._cached_tools is not None  # pyright: ignore[reportPrivateUsage]
+
+
+async def test_tools_no_caching_when_disabled() -> None:
+    """Test that list_tools() does not cache when cache_tools=False."""
+    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_tools=False)
+    async with server:
+        # First call - should not populate cache
+        tools1 = await server.list_tools()
+        assert len(tools1) > 0
+        assert server._cached_tools is None  # pyright: ignore[reportPrivateUsage]
+
+        # Second call - cache should still be None
+        tools2 = await server.list_tools()
+        assert tools2 == tools1
+        assert server._cached_tools is None  # pyright: ignore[reportPrivateUsage]
+
+
+async def test_tools_cache_invalidation_on_notification() -> None:
+    """Test that tools cache is invalidated when ToolListChangedNotification is received."""
+    server = MCPServerStdio('python', ['-m', 
'tests.mcp_server']) + async with server: + # Get initial tools - hidden_tool should NOT be present (it's disabled at startup) + tools1 = await server.list_tools() + tool_names1 = [t.name for t in tools1] + assert 'hidden_tool' not in tool_names1 + assert 'enable_hidden_tool' in tool_names1 + + # Enable the hidden tool (server sends ToolListChangedNotification) + await server.direct_call_tool('enable_hidden_tool', {}) + + # Get tools again - hidden_tool should now be present (cache was invalidated) + tools2 = await server.list_tools() + tool_names2 = [t.name for t in tools2] + assert 'hidden_tool' in tool_names2 + + +async def test_resources_caching_enabled_by_default() -> None: + """Test that list_resources() caches results by default.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources + + # First call - should fetch from server and cache + resources1 = await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + # Second call - should return cached value (cache is still populated) + resources2 = await server.list_resources() + assert resources2 == resources1 + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + +async def test_resources_no_caching_when_disabled() -> None: + """Test that list_resources() does not cache when cache_resources=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_resources=False) + async with server: + assert server.capabilities.resources + + # First call - should not populate cache + resources1 = await server.list_resources() + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + # Second call - cache should still be None + resources2 = await server.list_resources() + assert resources2 == resources1 + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + +async def test_resources_cache_invalidation_on_notification() -> None: + """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" + from mcp.types import ResourceListChangedNotification, ServerNotification + + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources + + # Populate cache + await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a resource list changed notification + notification = ServerNotification(ResourceListChangedNotification()) + await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + +async def test_cache_cleared_on_connection_close() -> None: + """Test that caches are cleared when the connection is closed.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + # First connection + async with server: + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + # After exiting, cache should be cleared by __aexit__ + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + + # Reconnect and verify cache starts empty + async with server: + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + # Fetch again to populate + await server.list_tools() + assert server._cached_tools is not None # pyright: 
ignore[reportPrivateUsage] + + +async def test_server_capabilities_list_changed_fields() -> None: + """Test that ServerCapabilities correctly parses listChanged fields.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + caps = server.capabilities + assert isinstance(caps.prompts_list_changed, bool) + assert isinstance(caps.tools_list_changed, bool) + assert isinstance(caps.resources_list_changed, bool)
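
For reference, a minimal usage sketch of the caching behaviour introduced in this diff (illustrative only; it reuses the `tests.mcp_server` fixture from this patch and only the `cache_tools`/`cache_resources` parameters added here):

import asyncio

from pydantic_ai.mcp import MCPServerStdio

# Default behaviour: the tool/resource listings are fetched once and cached until the
# server sends a notifications/tools/list_changed (or .../resources/list_changed)
# notification, or until the connection is closed.
cached_server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])

# For servers that change their tools/resources without emitting notifications,
# disable caching so every listing call goes back to the server.
uncached_server = MCPServerStdio(
    'python',
    ['-m', 'tests.mcp_server'],
    cache_tools=False,
    cache_resources=False,
)


async def main() -> None:
    async with cached_server:
        first = await cached_server.list_tools()  # network round-trip, result cached
        second = await cached_server.list_tools()  # served from the cache
        assert second == first


if __name__ == '__main__':
    asyncio.run(main())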