diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 5719bc1b..fac7ecad 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -47,6 +47,7 @@ async def send_message( self, request: Message, *, + configuration: MessageSendConfiguration | None = None, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, extensions: list[str] | None = None, @@ -59,6 +60,7 @@ async def send_message( Args: request: The message to send to the agent. + configuration: Optional per-call overrides for message sending behavior. context: The client call context. request_metadata: Extensions Metadata attached to the request. extensions: List of extensions to be activated. @@ -66,7 +68,7 @@ async def send_message( Yields: An async iterator of `ClientEvent` or a final `Message` response. """ - config = MessageSendConfiguration( + base_config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, push_notification_config=( @@ -75,6 +77,15 @@ async def send_message( else None ), ) + if configuration is not None: + update_data = configuration.model_dump( + exclude_unset=True, + by_alias=False, + ) + config = base_config.model_copy(update=update_data) + else: + config = base_config + params = MessageSendParams( message=request, configuration=config, metadata=request_metadata ) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index f5ab2543..7aa47902 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -9,6 +9,7 @@ AgentCapabilities, AgentCard, Message, + MessageSendConfiguration, Part, Role, Task, @@ -125,3 +126,78 @@ async def test_send_message_non_streaming_agent_capability_false( assert not mock_transport.send_message_streaming.called assert len(events) == 1 assert events[0][0].id == 'task-789' + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_non_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = False + mock_transport.send_message.return_value = Task( + id='task-cfg-ns-1', + context_id='ctx-cfg-ns-1', + status=TaskStatus(state=TaskState.completed), + ) + + cfg = MessageSendConfiguration( + history_length=2, + blocking=False, + accepted_output_modes=['application/json'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-ns-1' + + params = mock_transport.send_message.call_args[0][0] + assert params.configuration.history_length == 2 + assert params.configuration.blocking is False + assert params.configuration.accepted_output_modes == ['application/json'] + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = True + base_client._card.capabilities.streaming = True + + async def create_stream(*args, **kwargs): + yield Task( + id='task-cfg-s-1', + context_id='ctx-cfg-s-1', + status=TaskStatus(state=TaskState.completed), + ) + + mock_transport.send_message_streaming.return_value = create_stream() + + cfg = MessageSendConfiguration( + history_length=0, + blocking=True, + accepted_output_modes=['text/plain'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-s-1' + + params = mock_transport.send_message_streaming.call_args[0][0] + assert params.configuration.history_length == 0 + assert params.configuration.blocking is True + assert params.configuration.accepted_output_modes == ['text/plain']