From 8c9078982fc748dc18ee8a55de8b012d13dd7175 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:10:09 -0500 Subject: [PATCH 01/10] add tool and resource caching for mcp servers that support change notifications --- pydantic_ai_slim/pydantic_ai/mcp.py | 77 +++++++++++-- tests/test_mcp.py | 164 ++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+), 12 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ac3cfeae5c..051a2f5c20 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -232,6 +232,12 @@ class ServerCapabilities: tools: bool = False """Whether the server offers any tools to call.""" + tools_list_changed: bool = False + """Whether the server will emit notifications when the list of tools changes.""" + + resources_list_changed: bool = False + """Whether the server will emit notifications when the list of resources changes.""" + completions: bool = False """Whether the server offers autocompletion suggestions for prompts and resources.""" @@ -244,12 +250,16 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC Args: mcp_capabilities: The MCP SDK ServerCapabilities object. """ + tools_cap = mcp_capabilities.tools + resources_cap = mcp_capabilities.resources return cls( experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None, logging=mcp_capabilities.logging is not None, prompts=mcp_capabilities.prompts is not None, - resources=mcp_capabilities.resources is not None, - tools=mcp_capabilities.tools is not None, + resources=resources_cap is not None, + tools=tools_cap is not None, + tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False, + resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False, completions=mcp_capabilities.completions is not None, ) @@ -332,6 +342,11 @@ class MCPServer(AbstractToolset[Any], ABC): _server_capabilities: ServerCapabilities _instructions: str | None + _cached_tools: list[mcp_types.Tool] | None + _tools_cache_valid: bool + _cached_resources: list[Resource] | None + _resources_cache_valid: bool + def __init__( self, tool_prefix: str | None = None, @@ -366,6 +381,10 @@ def __post_init__(self): self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None + self._cached_tools = None + self._tools_cache_valid = False + self._cached_resources = None + self._resources_cache_valid = False @abstractmethod @asynccontextmanager @@ -430,13 +449,23 @@ def instructions(self) -> str | None: async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. - Note: - - We don't cache tools as they might change. - - We also don't subscribe to the server to avoid complexity. + Tools are cached when the server advertises `tools.listChanged` capability, + with cache invalidation on tool change notifications and reconnection. """ async with self: # Ensure server is running - result = await self._client.list_tools() - return result.tools + # Only cache if server supports listChanged notifications + if self._server_capabilities.tools_list_changed: + if self._cached_tools is not None and self._tools_cache_valid: + return self._cached_tools + + result = await self._client.list_tools() + self._cached_tools = result.tools + self._tools_cache_valid = True + return result.tools + else: + # Server doesn't support notifications, always fetch fresh + result = await self._client.list_tools() + return result.tools async def direct_call_tool( self, @@ -542,9 +571,8 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: async def list_resources(self) -> list[Resource]: """Retrieve resources that are currently present on the server. - Note: - - We don't cache resources as they might change. - - We also don't subscribe to resource changes to avoid complexity. + Resources are cached when the server advertises `resources.listChanged` capability, + with cache invalidation on resource change notifications and reconnection. Raises: MCPError: If the server returns an error. @@ -553,10 +581,21 @@ async def list_resources(self) -> list[Resource]: if not self.capabilities.resources: return [] try: - result = await self._client.list_resources() + # caching logic same as list_tools + if self._server_capabilities.resources_list_changed: + if self._cached_resources is not None and self._resources_cache_valid: + return self._cached_resources + + result = await self._client.list_resources() + resources = [Resource.from_mcp_sdk(r) for r in result.resources] + self._cached_resources = resources + self._resources_cache_valid = True + return resources + else: + result = await self._client.list_resources() + return [Resource.from_mcp_sdk(r) for r in result.resources] except mcp_exceptions.McpError as e: raise MCPError.from_mcp_sdk(e) from e - return [Resource.from_mcp_sdk(r) for r in result.resources] async def list_resource_templates(self) -> list[ResourceTemplate]: """Retrieve resource templates that are currently present on the server. @@ -619,6 +658,12 @@ async def __aenter__(self) -> Self: """ async with self._enter_lock: if self._running_count == 0: + # Invalidate caches on fresh connection + self._cached_tools = None + self._tools_cache_valid = False + self._cached_resources = None + self._resources_cache_valid = False + async with AsyncExitStack() as exit_stack: self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams()) client = ClientSession( @@ -628,6 +673,7 @@ async def __aenter__(self) -> Self: elicitation_callback=self.elicitation_callback, logging_callback=self.log_handler, read_timeout_seconds=timedelta(seconds=self.read_timeout), + message_handler=self._handle_notification, ) self._client = await exit_stack.enter_async_context(client) @@ -680,6 +726,13 @@ async def _sampling_callback( model=self.sampling_model.model_name, ) + async def _handle_notification(self, message: Any) -> None: + """Handle notifications from the MCP server, invalidating caches as needed.""" + if isinstance(message, mcp_types.ToolListChangedNotification): + self._tools_cache_valid = False + elif isinstance(message, mcp_types.ResourceListChangedNotification): + self._resources_cache_valid = False + async def _map_tool_result_part( self, part: mcp_types.ContentBlock ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: diff --git a/tests/test_mcp.py b/tests/test_mcp.py index bd9db6602d..d2f0fe7237 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -54,7 +54,9 @@ ElicitRequestParams, ElicitResult, ImageContent, + ResourceListChangedNotification, TextContent, + ToolListChangedNotification, ) from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response @@ -1987,3 +1989,165 @@ async def test_custom_http_client_not_closed(): assert len(tools) > 0 assert not custom_http_client.is_closed + + +# ============================================================================ +# Tool and Resource Caching Tests +# ============================================================================ + + +async def test_tools_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None: + """Test that list_tools() caches results when server supports listChanged notifications.""" + async with mcp_server: + # Mock the server capabilities to indicate listChanged is supported + mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] + + # First call - should fetch from server and cache + tools1 = await mcp_server.list_tools() + assert len(tools1) > 0 + assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] + + # Mock _client.list_tools to track if it's called again + original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage] + call_count = 0 + + async def mock_list_tools(): # pragma: no cover + nonlocal call_count + call_count += 1 + return await original_list_tools() + + mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + + # Second call - should return cached value without calling server + tools2 = await mcp_server.list_tools() + assert tools2 == tools1 + assert call_count == 0 # list_tools should not have been called + + +async def test_tools_no_caching_without_list_changed_capability(mcp_server: MCPServerStdio) -> None: + """Test that list_tools() always fetches fresh when server doesn't support listChanged.""" + async with mcp_server: + # Verify the server doesn't advertise listChanged by default + # (this depends on the test MCP server implementation) + mcp_server._server_capabilities.tools_list_changed = False # pyright: ignore[reportPrivateUsage] + + # First call + tools1 = await mcp_server.list_tools() + assert len(tools1) > 0 + + # Mock _client.list_tools to track calls + original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage] + call_count = 0 + + async def mock_list_tools(): + nonlocal call_count + call_count += 1 + return await original_list_tools() + + mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + + # Second call - should fetch fresh since no listChanged capability + tools2 = await mcp_server.list_tools() + assert tools2 == tools1 + assert call_count == 1 # list_tools should have been called + + +async def test_tools_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: + """Test that tools cache is invalidated when ToolListChangedNotification is received.""" + async with mcp_server: + # Enable caching + mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] + + # Populate cache + await mcp_server.list_tools() + assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a tool list changed notification + notification = ToolListChangedNotification() + await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated + assert mcp_server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage] + + # Cached tools are still present but marked invalid + assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + +async def test_resources_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None: + """Test that list_resources() caches results when server supports listChanged notifications.""" + async with mcp_server: + # Mock the server capabilities to indicate listChanged is supported + mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage] + + # First call - should fetch from server and cache + if mcp_server.capabilities.resources: # pragma: no branch + resources1 = await mcp_server.list_resources() + assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage] + + # Mock _client.list_resources to track if it's called again + original_list_resources = mcp_server._client.list_resources # pyright: ignore[reportPrivateUsage] + call_count = 0 + + async def mock_list_resources(): # pragma: no cover + nonlocal call_count + call_count += 1 + return await original_list_resources() + + mcp_server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + + # Second call - should return cached value without calling server + resources2 = await mcp_server.list_resources() + assert resources2 == resources1 + assert call_count == 0 # list_resources should not have been called + + +async def test_resources_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: + """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" + async with mcp_server: + # Enable caching + mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage] + + # Populate cache (if server supports resources) + if mcp_server.capabilities.resources: # pragma: no branch + await mcp_server.list_resources() + assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a resource list changed notification + notification = ResourceListChangedNotification() + await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated + assert mcp_server._resources_cache_valid is False # pyright: ignore[reportPrivateUsage] + + +async def test_cache_invalidation_on_reconnection() -> None: + """Test that caches are cleared when reconnecting to the server.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + # First connection + async with server: + server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + assert server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] + + # After exiting, the server is no longer running + # but cache state persists until next connection + + # Reconnect + async with server: + # Cache should be cleared on fresh connection + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + assert server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage] + + +async def test_server_capabilities_list_changed_fields() -> None: + """Test that ServerCapabilities correctly parses listChanged fields.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + # Test that capabilities are accessible + caps = server.capabilities + assert isinstance(caps.tools_list_changed, bool) + assert isinstance(caps.resources_list_changed, bool) From 57e94ffdefe9488153f0b58c2e2f616540f16e97 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:31:07 -0500 Subject: [PATCH 02/10] apply review comments --- pydantic_ai_slim/pydantic_ai/mcp.py | 100 +++++++++++++------ tests/test_mcp.py | 148 ++++++++++++++-------------- 2 files changed, 144 insertions(+), 104 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 051a2f5c20..6b47d03691 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -226,6 +226,9 @@ class ServerCapabilities: prompts: bool = False """Whether the server offers any prompt templates.""" + prompts_list_changed: bool = False + """Whether the server will emit notifications when the list of prompts changes.""" + resources: bool = False """Whether the server offers any resources to read.""" @@ -250,16 +253,18 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC Args: mcp_capabilities: The MCP SDK ServerCapabilities object. """ - tools_cap = mcp_capabilities.tools + prompts_cap = mcp_capabilities.prompts resources_cap = mcp_capabilities.resources + tools_cap = mcp_capabilities.tools return cls( experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None, logging=mcp_capabilities.logging is not None, - prompts=mcp_capabilities.prompts is not None, + prompts=prompts_cap is not None, + prompts_list_changed=bool(prompts_cap.listChanged) if prompts_cap else False, resources=resources_cap is not None, + resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False, tools=tools_cap is not None, tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False, - resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False, completions=mcp_capabilities.completions is not None, ) @@ -329,6 +334,26 @@ class MCPServer(AbstractToolset[Any], ABC): elicitation_callback: ElicitationFnT | None = None """Callback function to handle elicitation requests from the server.""" + cache_tools: bool + """Whether to cache the list of tools. + + When enabled (default), tools are fetched once and cached until either: + - The server sends a `notifications/tools/list_changed` notification + - The connection is closed + + Set to `False` for servers that change tools dynamically without sending notifications. + """ + + cache_resources: bool + """Whether to cache the list of resources. + + When enabled (default), resources are fetched once and cached until either: + - The server sends a `notifications/resources/list_changed` notification + - The connection is closed + + Set to `False` for servers that change resources dynamically without sending notifications. + """ + _id: str | None _enter_lock: Lock = field(compare=False) @@ -343,9 +368,7 @@ class MCPServer(AbstractToolset[Any], ABC): _instructions: str | None _cached_tools: list[mcp_types.Tool] | None - _tools_cache_valid: bool _cached_resources: list[Resource] | None - _resources_cache_valid: bool def __init__( self, @@ -359,6 +382,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, *, id: str | None = None, ): @@ -372,6 +397,8 @@ def __init__( self.sampling_model = sampling_model self.max_retries = max_retries self.elicitation_callback = elicitation_callback + self.cache_tools = cache_tools + self.cache_resources = cache_resources self._id = id or tool_prefix @@ -382,9 +409,7 @@ def __post_init__(self): self._running_count = 0 self._exit_stack = None self._cached_tools = None - self._tools_cache_valid = False self._cached_resources = None - self._resources_cache_valid = False @abstractmethod @asynccontextmanager @@ -449,21 +474,20 @@ def instructions(self) -> str | None: async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. - Tools are cached when the server advertises `tools.listChanged` capability, - with cache invalidation on tool change notifications and reconnection. + Tools are cached by default, with cache invalidation on: + - `notifications/tools/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_tools=False` for servers that change tools without sending notifications. """ - async with self: # Ensure server is running - # Only cache if server supports listChanged notifications - if self._server_capabilities.tools_list_changed: - if self._cached_tools is not None and self._tools_cache_valid: + async with self: + if self.cache_tools: + if self._cached_tools is not None: return self._cached_tools - result = await self._client.list_tools() self._cached_tools = result.tools - self._tools_cache_valid = True return result.tools else: - # Server doesn't support notifications, always fetch fresh result = await self._client.list_tools() return result.tools @@ -571,25 +595,25 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: async def list_resources(self) -> list[Resource]: """Retrieve resources that are currently present on the server. - Resources are cached when the server advertises `resources.listChanged` capability, - with cache invalidation on resource change notifications and reconnection. + Resources are cached by default, with cache invalidation on: + - `notifications/resources/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_resources=False` for servers that change resources without sending notifications. Raises: MCPError: If the server returns an error. """ - async with self: # Ensure server is running + async with self: if not self.capabilities.resources: return [] try: - # caching logic same as list_tools - if self._server_capabilities.resources_list_changed: - if self._cached_resources is not None and self._resources_cache_valid: + if self.cache_resources: + if self._cached_resources is not None: return self._cached_resources - result = await self._client.list_resources() resources = [Resource.from_mcp_sdk(r) for r in result.resources] self._cached_resources = resources - self._resources_cache_valid = True return resources else: result = await self._client.list_resources() @@ -658,12 +682,6 @@ async def __aenter__(self) -> Self: """ async with self._enter_lock: if self._running_count == 0: - # Invalidate caches on fresh connection - self._cached_tools = None - self._tools_cache_valid = False - self._cached_resources = None - self._resources_cache_valid = False - async with AsyncExitStack() as exit_stack: self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams()) client = ClientSession( @@ -697,6 +715,8 @@ async def __aexit__(self, *args: Any) -> bool | None: if self._running_count == 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None + self._cached_tools = None + self._cached_resources = None @property def is_running(self) -> bool: @@ -729,9 +749,9 @@ async def _sampling_callback( async def _handle_notification(self, message: Any) -> None: """Handle notifications from the MCP server, invalidating caches as needed.""" if isinstance(message, mcp_types.ToolListChangedNotification): - self._tools_cache_valid = False + self._cached_tools = None elif isinstance(message, mcp_types.ResourceListChangedNotification): - self._resources_cache_valid = False + self._cached_resources = None async def _map_tool_result_part( self, part: mcp_types.ContentBlock @@ -829,6 +849,8 @@ class MCPServerStdio(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -847,6 +869,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, id: str | None = None, ): """Build a new MCP server. @@ -866,6 +890,8 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. + cache_resources: Whether to cache the list of resources. id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. """ self.command = command @@ -884,6 +910,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) @@ -983,6 +1011,8 @@ class _MCPServerHTTP(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -1001,6 +1031,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, **_deprecated_kwargs: Any, ): """Build a new MCP server. @@ -1020,6 +1052,8 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. + cache_resources: Whether to cache the list of resources. """ if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: @@ -1050,6 +1084,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index d2f0fe7237..8f2c377fb5 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1996,17 +1996,13 @@ async def test_custom_http_client_not_closed(): # ============================================================================ -async def test_tools_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None: - """Test that list_tools() caches results when server supports listChanged notifications.""" +async def test_tools_caching_enabled_by_default(mcp_server: MCPServerStdio) -> None: + """Test that list_tools() caches results by default.""" async with mcp_server: - # Mock the server capabilities to indicate listChanged is supported - mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] - # First call - should fetch from server and cache tools1 = await mcp_server.list_tools() assert len(tools1) > 0 assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] - assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] # Mock _client.list_tools to track if it's called again original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage] @@ -2022,22 +2018,19 @@ async def mock_list_tools(): # pragma: no cover # Second call - should return cached value without calling server tools2 = await mcp_server.list_tools() assert tools2 == tools1 - assert call_count == 0 # list_tools should not have been called - + assert call_count == 0 -async def test_tools_no_caching_without_list_changed_capability(mcp_server: MCPServerStdio) -> None: - """Test that list_tools() always fetches fresh when server doesn't support listChanged.""" - async with mcp_server: - # Verify the server doesn't advertise listChanged by default - # (this depends on the test MCP server implementation) - mcp_server._server_capabilities.tools_list_changed = False # pyright: ignore[reportPrivateUsage] +async def test_tools_no_caching_when_disabled() -> None: + """Test that list_tools() always fetches fresh when cache_tools=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_tools=False) + async with server: # First call - tools1 = await mcp_server.list_tools() + tools1 = await server.list_tools() assert len(tools1) > 0 # Mock _client.list_tools to track calls - original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage] + original_list_tools = server._client.list_tools # pyright: ignore[reportPrivateUsage] call_count = 0 async def mock_list_tools(): @@ -2045,109 +2038,120 @@ async def mock_list_tools(): call_count += 1 return await original_list_tools() - mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] - # Second call - should fetch fresh since no listChanged capability - tools2 = await mcp_server.list_tools() + # Second call - should fetch fresh since caching is disabled + tools2 = await server.list_tools() assert tools2 == tools1 - assert call_count == 1 # list_tools should have been called + assert call_count == 1 + + +async def test_resources_no_caching_when_disabled() -> None: + """Test that list_resources() always fetches fresh when cache_resources=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_resources=False) + async with server: + assert server.capabilities.resources + # First call + resources1 = await server.list_resources() + + # Mock _client.list_resources to track calls + original_list_resources = server._client.list_resources # pyright: ignore[reportPrivateUsage] + call_count = 0 + + async def mock_list_resources(): + nonlocal call_count + call_count += 1 + return await original_list_resources() + + server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + + # Second call - should fetch fresh since caching is disabled + resources2 = await server.list_resources() + assert resources2 == resources1 + assert call_count == 1 async def test_tools_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: """Test that tools cache is invalidated when ToolListChangedNotification is received.""" async with mcp_server: - # Enable caching - mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] - # Populate cache await mcp_server.list_tools() - assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] + assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] # Simulate receiving a tool list changed notification notification = ToolListChangedNotification() await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] - # Cache should be invalidated - assert mcp_server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage] + # Cache should be invalidated (set to None) + assert mcp_server._cached_tools is None # pyright: ignore[reportPrivateUsage] - # Cached tools are still present but marked invalid - assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] - -async def test_resources_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None: - """Test that list_resources() caches results when server supports listChanged notifications.""" +async def test_resources_caching_enabled_by_default(mcp_server: MCPServerStdio) -> None: + """Test that list_resources() caches results by default.""" async with mcp_server: - # Mock the server capabilities to indicate listChanged is supported - mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage] - + assert mcp_server.capabilities.resources # First call - should fetch from server and cache - if mcp_server.capabilities.resources: # pragma: no branch - resources1 = await mcp_server.list_resources() - assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage] + resources1 = await mcp_server.list_resources() + assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - # Mock _client.list_resources to track if it's called again - original_list_resources = mcp_server._client.list_resources # pyright: ignore[reportPrivateUsage] - call_count = 0 + # Mock _client.list_resources to track if it's called again + original_list_resources = mcp_server._client.list_resources # pyright: ignore[reportPrivateUsage] + call_count = 0 - async def mock_list_resources(): # pragma: no cover - nonlocal call_count - call_count += 1 - return await original_list_resources() + async def mock_list_resources(): # pragma: no cover + nonlocal call_count + call_count += 1 + return await original_list_resources() - mcp_server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + mcp_server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] - # Second call - should return cached value without calling server - resources2 = await mcp_server.list_resources() - assert resources2 == resources1 - assert call_count == 0 # list_resources should not have been called + # Second call - should return cached value without calling server + resources2 = await mcp_server.list_resources() + assert resources2 == resources1 + assert call_count == 0 async def test_resources_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" async with mcp_server: - # Enable caching - mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage] - - # Populate cache (if server supports resources) - if mcp_server.capabilities.resources: # pragma: no branch - await mcp_server.list_resources() - assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage] + assert mcp_server.capabilities.resources + # Populate cache + await mcp_server.list_resources() + assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - # Simulate receiving a resource list changed notification - notification = ResourceListChangedNotification() - await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + # Simulate receiving a "resource list changed" notification + notification = ResourceListChangedNotification() + await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] - # Cache should be invalidated - assert mcp_server._resources_cache_valid is False # pyright: ignore[reportPrivateUsage] + # Cache should be invalidated + assert mcp_server._cached_resources is None # pyright: ignore[reportPrivateUsage] -async def test_cache_invalidation_on_reconnection() -> None: - """Test that caches are cleared when reconnecting to the server.""" +async def test_cache_cleared_on_connection_close() -> None: + """Test that caches are cleared when the connection is closed.""" server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) # First connection async with server: - server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage] await server.list_tools() assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] - assert server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage] - # After exiting, the server is no longer running - # but cache state persists until next connection + # After exiting, cache should be cleared by __aexit__ + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] - # Reconnect + # Reconnect and verify cache starts empty async with server: - # Cache should be cleared on fresh connection assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] - assert server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage] + # Fetch again to populate + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] async def test_server_capabilities_list_changed_fields() -> None: """Test that ServerCapabilities correctly parses listChanged fields.""" server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - # Test that capabilities are accessible caps = server.capabilities + assert isinstance(caps.prompts_list_changed, bool) assert isinstance(caps.tools_list_changed, bool) assert isinstance(caps.resources_list_changed, bool) From fd9ddfb554112fc8f1c0422c649b123ee84a19a5 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 27 Nov 2025 17:35:00 -0500 Subject: [PATCH 03/10] reorder and add cahced prompts field --- pydantic_ai_slim/pydantic_ai/mcp.py | 11 ++++++++--- tests/test_mcp.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 6b47d03691..d07b2aa440 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -232,15 +232,15 @@ class ServerCapabilities: resources: bool = False """Whether the server offers any resources to read.""" + resources_list_changed: bool = False + """Whether the server will emit notifications when the list of resources changes.""" + tools: bool = False """Whether the server offers any tools to call.""" tools_list_changed: bool = False """Whether the server will emit notifications when the list of tools changes.""" - resources_list_changed: bool = False - """Whether the server will emit notifications when the list of resources changes.""" - completions: bool = False """Whether the server offers autocompletion suggestions for prompts and resources.""" @@ -369,6 +369,7 @@ class MCPServer(AbstractToolset[Any], ABC): _cached_tools: list[mcp_types.Tool] | None _cached_resources: list[Resource] | None + _cached_prompts: list[mcp_types.Prompt] | None def __init__( self, @@ -410,6 +411,7 @@ def __post_init__(self): self._exit_stack = None self._cached_tools = None self._cached_resources = None + self._cached_prompts = None @abstractmethod @asynccontextmanager @@ -717,6 +719,7 @@ async def __aexit__(self, *args: Any) -> bool | None: self._exit_stack = None self._cached_tools = None self._cached_resources = None + self._cached_prompts = None @property def is_running(self) -> bool: @@ -752,6 +755,8 @@ async def _handle_notification(self, message: Any) -> None: self._cached_tools = None elif isinstance(message, mcp_types.ResourceListChangedNotification): self._cached_resources = None + elif isinstance(message, mcp_types.PromptListChangedNotification): + self._cached_prompts = None async def _map_tool_result_part( self, part: mcp_types.ContentBlock diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 1db1ef0a5c..d2a601cfcf 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -54,6 +54,7 @@ ElicitRequestParams, ElicitResult, ImageContent, + PromptListChangedNotification, ResourceListChangedNotification, TextContent, ToolListChangedNotification, @@ -2147,6 +2148,21 @@ async def test_resources_cache_invalidation_on_notification(mcp_server: MCPServe assert mcp_server._cached_resources is None # pyright: ignore[reportPrivateUsage] +async def test_prompts_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: + """Test that prompts cache is invalidated when PromptListChangedNotification is received.""" + async with mcp_server: + # Manually set a cached value (no list_prompts() method exists yet) + mcp_server._cached_prompts = [] # pyright: ignore[reportPrivateUsage] + assert mcp_server._cached_prompts is not None # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a "prompt list changed" notification + notification = PromptListChangedNotification() + await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated + assert mcp_server._cached_prompts is None # pyright: ignore[reportPrivateUsage] + + async def test_cache_cleared_on_connection_close() -> None: """Test that caches are cleared when the connection is closed.""" server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) From 86e204f7e398ad0afa6166feb424b4297b4d5cab Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:46:13 -0500 Subject: [PATCH 04/10] remove cahed prompts and replace mocked tests --- pydantic_ai_slim/pydantic_ai/mcp.py | 9 +- tests/test_mcp.py | 163 ++++++++++------------------ 2 files changed, 61 insertions(+), 111 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index d07b2aa440..6e713f6a11 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -369,7 +369,6 @@ class MCPServer(AbstractToolset[Any], ABC): _cached_tools: list[mcp_types.Tool] | None _cached_resources: list[Resource] | None - _cached_prompts: list[mcp_types.Prompt] | None def __init__( self, @@ -411,7 +410,6 @@ def __post_init__(self): self._exit_stack = None self._cached_tools = None self._cached_resources = None - self._cached_prompts = None @abstractmethod @asynccontextmanager @@ -719,7 +717,6 @@ async def __aexit__(self, *args: Any) -> bool | None: self._exit_stack = None self._cached_tools = None self._cached_resources = None - self._cached_prompts = None @property def is_running(self) -> bool: @@ -755,8 +752,6 @@ async def _handle_notification(self, message: Any) -> None: self._cached_tools = None elif isinstance(message, mcp_types.ResourceListChangedNotification): self._cached_resources = None - elif isinstance(message, mcp_types.PromptListChangedNotification): - self._cached_prompts = None async def _map_tool_result_part( self, part: mcp_types.ContentBlock @@ -896,7 +891,9 @@ def __init__( max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. cache_tools: Whether to cache the list of tools. + See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. """ self.command = command @@ -1058,7 +1055,9 @@ def __init__( max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. cache_tools: Whether to cache the list of tools. + See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. """ if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: diff --git a/tests/test_mcp.py b/tests/test_mcp.py index d2a601cfcf..1130c54186 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -54,10 +54,7 @@ ElicitRequestParams, ElicitResult, ImageContent, - PromptListChangedNotification, - ResourceListChangedNotification, TextContent, - ToolListChangedNotification, ) from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response @@ -2017,150 +2014,104 @@ async def test_custom_http_client_not_closed(): # ============================================================================ -async def test_tools_caching_enabled_by_default(mcp_server: MCPServerStdio) -> None: +async def test_tools_caching_enabled_by_default() -> None: """Test that list_tools() caches results by default.""" - async with mcp_server: + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: # First call - should fetch from server and cache - tools1 = await mcp_server.list_tools() + tools1 = await server.list_tools() assert len(tools1) > 0 - assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] - - # Mock _client.list_tools to track if it's called again - original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage] - call_count = 0 - - async def mock_list_tools(): # pragma: no cover - nonlocal call_count - call_count += 1 - return await original_list_tools() - - mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] - # Second call - should return cached value without calling server - tools2 = await mcp_server.list_tools() + # Second call - should return cached value (cache is still populated) + tools2 = await server.list_tools() assert tools2 == tools1 - assert call_count == 0 + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] async def test_tools_no_caching_when_disabled() -> None: - """Test that list_tools() always fetches fresh when cache_tools=False.""" + """Test that list_tools() does not cache when cache_tools=False.""" server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_tools=False) async with server: - # First call + # First call - should not populate cache tools1 = await server.list_tools() assert len(tools1) > 0 + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] - # Mock _client.list_tools to track calls - original_list_tools = server._client.list_tools # pyright: ignore[reportPrivateUsage] - call_count = 0 - - async def mock_list_tools(): - nonlocal call_count - call_count += 1 - return await original_list_tools() - - server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] - - # Second call - should fetch fresh since caching is disabled + # Second call - cache should still be None tools2 = await server.list_tools() assert tools2 == tools1 - assert call_count == 1 - - -async def test_resources_no_caching_when_disabled() -> None: - """Test that list_resources() always fetches fresh when cache_resources=False.""" - server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_resources=False) - async with server: - assert server.capabilities.resources - # First call - resources1 = await server.list_resources() - - # Mock _client.list_resources to track calls - original_list_resources = server._client.list_resources # pyright: ignore[reportPrivateUsage] - call_count = 0 - - async def mock_list_resources(): - nonlocal call_count - call_count += 1 - return await original_list_resources() - - server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] - - # Second call - should fetch fresh since caching is disabled - resources2 = await server.list_resources() - assert resources2 == resources1 - assert call_count == 1 + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] -async def test_tools_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: +async def test_tools_cache_invalidation_on_notification() -> None: """Test that tools cache is invalidated when ToolListChangedNotification is received.""" - async with mcp_server: + from mcp.types import ToolListChangedNotification + + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: # Populate cache - await mcp_server.list_tools() - assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] # Simulate receiving a tool list changed notification notification = ToolListChangedNotification() - await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] # Cache should be invalidated (set to None) - assert mcp_server._cached_tools is None # pyright: ignore[reportPrivateUsage] + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] -async def test_resources_caching_enabled_by_default(mcp_server: MCPServerStdio) -> None: +async def test_resources_caching_enabled_by_default() -> None: """Test that list_resources() caches results by default.""" - async with mcp_server: - assert mcp_server.capabilities.resources + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources + # First call - should fetch from server and cache - resources1 = await mcp_server.list_resources() - assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + resources1 = await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - # Mock _client.list_resources to track if it's called again - original_list_resources = mcp_server._client.list_resources # pyright: ignore[reportPrivateUsage] - call_count = 0 + # Second call - should return cached value (cache is still populated) + resources2 = await server.list_resources() + assert resources2 == resources1 + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - async def mock_list_resources(): # pragma: no cover - nonlocal call_count - call_count += 1 - return await original_list_resources() - mcp_server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue] +async def test_resources_no_caching_when_disabled() -> None: + """Test that list_resources() does not cache when cache_resources=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_resources=False) + async with server: + assert server.capabilities.resources - # Second call - should return cached value without calling server - resources2 = await mcp_server.list_resources() + # First call - should not populate cache + resources1 = await server.list_resources() + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + # Second call - cache should still be None + resources2 = await server.list_resources() assert resources2 == resources1 - assert call_count == 0 + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] -async def test_resources_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: +async def test_resources_cache_invalidation_on_notification() -> None: """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" - async with mcp_server: - assert mcp_server.capabilities.resources - # Populate cache - await mcp_server.list_resources() - assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - - # Simulate receiving a "resource list changed" notification - notification = ResourceListChangedNotification() - await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] - - # Cache should be invalidated - assert mcp_server._cached_resources is None # pyright: ignore[reportPrivateUsage] + from mcp.types import ResourceListChangedNotification + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources -async def test_prompts_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None: - """Test that prompts cache is invalidated when PromptListChangedNotification is received.""" - async with mcp_server: - # Manually set a cached value (no list_prompts() method exists yet) - mcp_server._cached_prompts = [] # pyright: ignore[reportPrivateUsage] - assert mcp_server._cached_prompts is not None # pyright: ignore[reportPrivateUsage] + # Populate cache + await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] - # Simulate receiving a "prompt list changed" notification - notification = PromptListChangedNotification() - await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + # Simulate receiving a resource list changed notification + notification = ResourceListChangedNotification() + await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] # Cache should be invalidated - assert mcp_server._cached_prompts is None # pyright: ignore[reportPrivateUsage] + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] async def test_cache_cleared_on_connection_close() -> None: From 1d6041f3c0f9af4d597cd0f1aef04058c443bc7d Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 30 Nov 2025 08:33:52 -0500 Subject: [PATCH 05/10] add live mcp tool change notification and narrow arg types --- pydantic_ai_slim/pydantic_ai/mcp.py | 17 ++++++++++++----- tests/mcp_server.py | 13 +++++++++++++ tests/test_mcp.py | 29 +++++++++++++++-------------- 3 files changed, 40 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 6e713f6a11..227b8e1399 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -36,6 +36,7 @@ from mcp.shared import exceptions as mcp_exceptions from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage + from mcp.shared.session import RequestResponder except ImportError as _import_error: raise ImportError( 'Please install the `mcp` package to use the MCP server, ' @@ -746,12 +747,18 @@ async def _sampling_callback( model=self.sampling_model.model_name, ) - async def _handle_notification(self, message: Any) -> None: + async def _handle_notification( + self, + message: RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] + | mcp_types.ServerNotification + | Exception, + ) -> None: """Handle notifications from the MCP server, invalidating caches as needed.""" - if isinstance(message, mcp_types.ToolListChangedNotification): - self._cached_tools = None - elif isinstance(message, mcp_types.ResourceListChangedNotification): - self._cached_resources = None + if isinstance(message, mcp_types.ServerNotification): # pragma: no branch + if isinstance(message.root, mcp_types.ToolListChangedNotification): + self._cached_tools = None + elif isinstance(message.root, mcp_types.ResourceListChangedNotification): + self._cached_resources = None async def _map_tool_result_part( self, part: mcp_types.ContentBlock diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 8ba9b9997f..cd042dc76a 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -235,6 +235,19 @@ async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> s return f'User {result.action}ed the elicitation' +async def hidden_tool() -> str: + """A tool that is hidden by default.""" + return 'I was hidden!' + + +@mcp.tool() +async def enable_hidden_tool(ctx: Context[ServerSession, None]) -> str: + """Enable the hidden tool, triggering a ToolListChangedNotification.""" + mcp._tool_manager.add_tool(hidden_tool) # pyright: ignore[reportPrivateUsage] + await ctx.session.send_tool_list_changed() + return 'Hidden tool enabled' + + @mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] async def set_logging_level(level: str) -> None: global log_level diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 1130c54186..7b1d152ea2 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -95,7 +95,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -156,7 +156,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) async def test_process_tool_call(run_context: RunContext[int]) -> int: @@ -2046,20 +2046,21 @@ async def test_tools_no_caching_when_disabled() -> None: async def test_tools_cache_invalidation_on_notification() -> None: """Test that tools cache is invalidated when ToolListChangedNotification is received.""" - from mcp.types import ToolListChangedNotification - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - # Populate cache - await server.list_tools() - assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + # Get initial tools - hidden_tool should NOT be present (it's disabled at startup) + tools1 = await server.list_tools() + tool_names1 = [t.name for t in tools1] + assert 'hidden_tool' not in tool_names1 + assert 'enable_hidden_tool' in tool_names1 - # Simulate receiving a tool list changed notification - notification = ToolListChangedNotification() - await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + # Enable the hidden tool (server sends ToolListChangedNotification) + await server.direct_call_tool('enable_hidden_tool', {}) - # Cache should be invalidated (set to None) - assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + # Get tools again - hidden_tool should now be present (cache was invalidated) + tools2 = await server.list_tools() + tool_names2 = [t.name for t in tools2] + assert 'hidden_tool' in tool_names2 async def test_resources_caching_enabled_by_default() -> None: @@ -2096,7 +2097,7 @@ async def test_resources_no_caching_when_disabled() -> None: async def test_resources_cache_invalidation_on_notification() -> None: """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" - from mcp.types import ResourceListChangedNotification + from mcp.types import ResourceListChangedNotification, ServerNotification server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: @@ -2107,7 +2108,7 @@ async def test_resources_cache_invalidation_on_notification() -> None: assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] # Simulate receiving a resource list changed notification - notification = ResourceListChangedNotification() + notification = ServerNotification(ResourceListChangedNotification()) await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] # Cache should be invalidated From a715fc4ae7aea78d59e44a3592be0348db5f3ac3 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 30 Nov 2025 11:24:53 -0500 Subject: [PATCH 06/10] fix mcp image content b64decode bug - add support for audio --- pydantic_ai_slim/pydantic_ai/_mcp.py | 12 +++++++-- pydantic_ai_slim/pydantic_ai/mcp.py | 6 +---- tests/test_mcp.py | 40 +++++++++++++++++++++++++++- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_mcp.py b/pydantic_ai_slim/pydantic_ai/_mcp.py index f00230ee07..56c9c9f123 100644 --- a/pydantic_ai_slim/pydantic_ai/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/_mcp.py @@ -91,11 +91,19 @@ def add_msg( 'user', mcp_types.ImageContent( type='image', - data=base64.b64decode(chunk.data).decode(), + data=base64.b64encode(chunk.data).decode(), + mimeType=chunk.media_type, + ), + ) + elif isinstance(chunk, messages.BinaryContent) and chunk.is_audio: + add_msg( + 'user', + mcp_types.AudioContent( + type='audio', + data=base64.b64encode(chunk.data).decode(), mimeType=chunk.media_type, ), ) - # TODO(Marcelo): Add support for audio content. else: raise NotImplementedError(f'Unsupported content type: {type(chunk)}') else: diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 227b8e1399..b3f5837816 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -776,11 +776,7 @@ async def _map_tool_result_part( elif isinstance(part, mcp_types.ImageContent): return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) elif isinstance(part, mcp_types.AudioContent): - # NOTE: The FastMCP server doesn't support audio content. - # See for more details. - return messages.BinaryContent( - data=base64.b64decode(part.data), media_type=part.mimeType - ) # pragma: no cover + return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) elif isinstance(part, mcp_types.EmbeddedResource): resource = part.resource return self._get_content(resource) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 7b1d152ea2..b03e17a933 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -57,7 +57,7 @@ TextContent, ) - from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response + from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response, map_from_pai_messages from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import OpenAIChatModel @@ -1525,6 +1525,44 @@ def test_map_from_mcp_params_model_response(): ) +def test_map_from_pai_messages_with_binary_content(): + """Test that map_from_pai_messages correctly converts image and audio content to MCP format.""" + message = ModelRequest( + parts=[ + UserPromptPart(content='text message'), + UserPromptPart(content=[BinaryContent(data=b'img', media_type='image/png')]), + UserPromptPart(content=[BinaryContent(data=b'audio', media_type='audio/wav')]), + ] + ) + system_prompt, sampling_msgs = map_from_pai_messages([message]) + assert system_prompt == '' + assert [m.model_dump(by_alias=True) for m in sampling_msgs] == snapshot( + [ + {'role': 'user', 'content': {'type': 'text', 'text': 'text message', 'annotations': None, '_meta': None}}, + { + 'role': 'user', + 'content': { + 'type': 'image', + 'data': 'aW1n', + 'mimeType': 'image/png', + 'annotations': None, + '_meta': None, + }, + }, + { + 'role': 'user', + 'content': { + 'type': 'audio', + 'data': 'YXVkaW8=', + 'mimeType': 'audio/wav', + 'annotations': None, + '_meta': None, + }, + }, + ] + ) + + def test_map_from_model_response(): with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'): map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')])) From 970ca03699dc45740e65cf9cc7ba331c3bf1b938 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 30 Nov 2025 12:38:17 -0500 Subject: [PATCH 07/10] decode was right, fix test --- pydantic_ai_slim/pydantic_ai/_mcp.py | 39 +++++++++++++++------------- tests/test_mcp.py | 28 ++++++++++++++++---- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_mcp.py b/pydantic_ai_slim/pydantic_ai/_mcp.py index 56c9c9f123..1eeb0abae0 100644 --- a/pydantic_ai_slim/pydantic_ai/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/_mcp.py @@ -86,24 +86,27 @@ def add_msg( for chunk in part.content: if isinstance(chunk, str): add_msg('user', mcp_types.TextContent(type='text', text=chunk)) - elif isinstance(chunk, messages.BinaryContent) and chunk.is_image: - add_msg( - 'user', - mcp_types.ImageContent( - type='image', - data=base64.b64encode(chunk.data).decode(), - mimeType=chunk.media_type, - ), - ) - elif isinstance(chunk, messages.BinaryContent) and chunk.is_audio: - add_msg( - 'user', - mcp_types.AudioContent( - type='audio', - data=base64.b64encode(chunk.data).decode(), - mimeType=chunk.media_type, - ), - ) + elif isinstance(chunk, messages.BinaryContent): + # `BinaryContent.data` are base64-encoded bytes. + base64_data = base64.b64decode(chunk.data).decode() + if chunk.is_image: + add_msg( + 'user', + mcp_types.ImageContent( + type='image', + data=base64_data, + mimeType=chunk.media_type, + ), + ) + elif chunk.is_audio: + add_msg( + 'user', + mcp_types.AudioContent( + type='audio', + data=base64_data, + mimeType=chunk.media_type, + ), + ) else: raise NotImplementedError(f'Unsupported content type: {type(chunk)}') else: diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b03e17a933..8f8fd96470 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -50,6 +50,7 @@ from mcp.client.session import ClientSession from mcp.shared.context import RequestContext from mcp.types import ( + AudioContent, CreateMessageRequestParams, ElicitRequestParams, ElicitResult, @@ -1526,12 +1527,20 @@ def test_map_from_mcp_params_model_response(): def test_map_from_pai_messages_with_binary_content(): - """Test that map_from_pai_messages correctly converts image and audio content to MCP format.""" + """Test that map_from_pai_messages correctly converts image and audio content to MCP format. + + Note: BinaryContent.data is base64-encoded bytes (e.g., base64.b64encode(b'raw')). + map_from_pai_messages decodes this to get the base64 string for MCP. + """ + # BinaryContent.data is base64-encoded bytes + image_data = base64.b64encode(b'raw_image_bytes') + audio_data = base64.b64encode(b'raw_audio_bytes') + message = ModelRequest( parts=[ UserPromptPart(content='text message'), - UserPromptPart(content=[BinaryContent(data=b'img', media_type='image/png')]), - UserPromptPart(content=[BinaryContent(data=b'audio', media_type='audio/wav')]), + UserPromptPart(content=[BinaryContent(data=image_data, media_type='image/png')]), + UserPromptPart(content=[BinaryContent(data=audio_data, media_type='audio/wav')]), ] ) system_prompt, sampling_msgs = map_from_pai_messages([message]) @@ -1543,7 +1552,7 @@ def test_map_from_pai_messages_with_binary_content(): 'role': 'user', 'content': { 'type': 'image', - 'data': 'aW1n', + 'data': 'raw_image_bytes', 'mimeType': 'image/png', 'annotations': None, '_meta': None, @@ -1553,7 +1562,7 @@ def test_map_from_pai_messages_with_binary_content(): 'role': 'user', 'content': { 'type': 'audio', - 'data': 'YXVkaW8=', + 'data': 'raw_audio_bytes', 'mimeType': 'audio/wav', 'annotations': None, '_meta': None, @@ -1568,6 +1577,15 @@ def test_map_from_model_response(): map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')])) +async def test_map_tool_result_part_audio(): + """Test that _map_tool_result_part correctly handles AudioContent.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + audio_content = AudioContent(type='audio', data='ZmFrZV9hdWRpb19kYXRh', mimeType='audio/mpeg') + async with server: + result = await server._map_tool_result_part(audio_content) # pyright: ignore[reportPrivateUsage] + assert result == BinaryContent(data=b'fake_audio_data', media_type='audio/mpeg') + + async def test_elicitation_callback_functionality(run_context: RunContext[int]): """Test that elicitation callback is actually called and works.""" # Track callback execution From c6ac978e10a259a9cd5496e460db34aac64689b3 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 30 Nov 2025 13:02:23 -0500 Subject: [PATCH 08/10] cover the continuation of the loop branch --- tests/test_mcp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 8f8fd96470..14acf6574e 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1540,7 +1540,8 @@ def test_map_from_pai_messages_with_binary_content(): parts=[ UserPromptPart(content='text message'), UserPromptPart(content=[BinaryContent(data=image_data, media_type='image/png')]), - UserPromptPart(content=[BinaryContent(data=audio_data, media_type='audio/wav')]), + # cover the loop continuation branch + UserPromptPart(content=[BinaryContent(data=audio_data, media_type='audio/wav'), 'text after audio']), ] ) system_prompt, sampling_msgs = map_from_pai_messages([message]) @@ -1568,6 +1569,10 @@ def test_map_from_pai_messages_with_binary_content(): '_meta': None, }, }, + { + 'role': 'user', + 'content': {'type': 'text', 'text': 'text after audio', 'annotations': None, '_meta': None}, + }, ] ) From 75da1b780f20357714132103c7ce0698cdab49e0 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 30 Nov 2025 13:02:23 -0500 Subject: [PATCH 09/10] cover the continuation of the loop branch --- pydantic_ai_slim/pydantic_ai/_mcp.py | 2 ++ tests/test_mcp.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/_mcp.py b/pydantic_ai_slim/pydantic_ai/_mcp.py index 1eeb0abae0..ce1b41b3d7 100644 --- a/pydantic_ai_slim/pydantic_ai/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/_mcp.py @@ -107,6 +107,8 @@ def add_msg( mimeType=chunk.media_type, ), ) + else: + raise NotImplementedError(f'Unsupported binary content type: {chunk.media_type}') else: raise NotImplementedError(f'Unsupported content type: {type(chunk)}') else: diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 14acf6574e..7bf8f2a848 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1576,6 +1576,14 @@ def test_map_from_pai_messages_with_binary_content(): ] ) + # Unsupported binary content type raises NotImplementedError + video_data = base64.b64encode(b'raw_video_bytes') + message_with_video = ModelRequest( + parts=[UserPromptPart(content=[BinaryContent(data=video_data, media_type='video/mp4')])] + ) + with pytest.raises(NotImplementedError, match='Unsupported binary content type: video/mp4'): + map_from_pai_messages([message_with_video]) + def test_map_from_model_response(): with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'): From b97fa5eea3a9b455255f063f86e35d7e588dd6ec Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:47:16 -0500 Subject: [PATCH 10/10] narrow wording --- pydantic_ai_slim/pydantic_ai/_mcp.py | 2 +- tests/test_mcp.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_mcp.py b/pydantic_ai_slim/pydantic_ai/_mcp.py index ce1b41b3d7..05a79a095d 100644 --- a/pydantic_ai_slim/pydantic_ai/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/_mcp.py @@ -87,7 +87,7 @@ def add_msg( if isinstance(chunk, str): add_msg('user', mcp_types.TextContent(type='text', text=chunk)) elif isinstance(chunk, messages.BinaryContent): - # `BinaryContent.data` are base64-encoded bytes. + # `chunk.data` in this case are base64-encoded bytes. base64_data = base64.b64decode(chunk.data).decode() if chunk.is_image: add_msg( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 7bf8f2a848..b86c8adcb4 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1529,10 +1529,9 @@ def test_map_from_mcp_params_model_response(): def test_map_from_pai_messages_with_binary_content(): """Test that map_from_pai_messages correctly converts image and audio content to MCP format. - Note: BinaryContent.data is base64-encoded bytes (e.g., base64.b64encode(b'raw')). + Note: `data` in this case are base64-encoded bytes (e.g., base64.b64encode(b'raw')). map_from_pai_messages decodes this to get the base64 string for MCP. """ - # BinaryContent.data is base64-encoded bytes image_data = base64.b64encode(b'raw_image_bytes') audio_data = base64.b64encode(b'raw_audio_bytes')