Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
version: v2
inputs:
- git_repo: https://github.com/a2aproject/A2A.git
ref: main
ref: transports
subdir: specification/grpc
managed:
enabled: true
Expand All @@ -21,11 +21,11 @@ plugins:
# Generate python protobuf related code
# Generates *_pb2.py files, one for each .proto
- remote: buf.build/protocolbuffers/python:v29.3
out: src/a2a/grpc
out: src/a2a/types
# Generate python service code.
# Generates *_pb2_grpc.py
- remote: buf.build/grpc/python
out: src/a2a/grpc
out: src/a2a/types
# Generates *_pb2.pyi files.
- remote: buf.build/protocolbuffers/pyi
out: src/a2a/grpc
out: src/a2a/types
17 changes: 12 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"pydantic>=2.11.3",
"protobuf>=5.29.5",
"google-api-core>=1.26.0",
"json-rpc>=1.15.0",
]

classifiers = [
Expand Down Expand Up @@ -114,7 +115,7 @@ explicit = true

[tool.mypy]
plugins = ["pydantic.mypy"]
exclude = ["src/a2a/grpc/"]
exclude = ["src/a2a/types/a2a_pb2\\.py", "src/a2a/types/a2a_pb2_grpc\\.py"]
disable_error_code = [
"import-not-found",
"annotation-unchecked",
Expand All @@ -134,7 +135,8 @@ exclude = [
"**/node_modules",
"**/venv",
"**/.venv",
"src/a2a/grpc/",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]
reportMissingImports = "none"
reportMissingModuleSource = "none"
Expand All @@ -145,7 +147,8 @@ omit = [
"*/tests/*",
"*/site-packages/*",
"*/__init__.py",
"src/a2a/grpc/*",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]

[tool.coverage.report]
Expand Down Expand Up @@ -257,7 +260,9 @@ exclude = [
"node_modules",
"venv",
"*/migrations/*",
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
"tests/**",
]

Expand Down Expand Up @@ -311,7 +316,9 @@ inline-quotes = "single"

[tool.ruff.format]
exclude = [
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
]
docstring-code-format = true
docstring-code-line-length = "dynamic"
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from a2a.client.auth.credentials import CredentialService
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types import (
from a2a.types.a2a_pb2 import (
AgentCard,
APIKeySecurityScheme,
HTTPAuthSecurityScheme,
Expand Down
131 changes: 68 additions & 63 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, AsyncGenerator
from typing import Any

from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
ClientEvent,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.errors import A2AClientInvalidStateError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types import (
from a2a.types.a2a_pb2 import (
AgentCard,
GetTaskPushNotificationConfigParams,
Message,
MessageSendConfiguration,
MessageSendParams,
SendMessageConfiguration,
SendMessageRequest,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
SubscribeToTaskRequest,
CancelTaskRequest,
TaskPushNotificationConfig,
TaskQueryParams,
GetTaskRequest,
TaskStatusUpdateEvent,
StreamResponse,
SetTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
)


Expand All @@ -50,7 +54,7 @@ async def send_message(
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent | Message]:
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the agent.

This method handles both streaming and non-streaming (polling) interactions
Expand All @@ -64,9 +68,9 @@ async def send_message(
extensions: List of extensions to be activated.

Yields:
An async iterator of `ClientEvent` or a final `Message` response.
An async iterator of `ClientEvent`
"""
config = MessageSendConfiguration(
config = SendMessageConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
push_notification_config=(
Expand All @@ -75,67 +79,67 @@ async def send_message(
else None
),
)
params = MessageSendParams(
message=request, configuration=config, metadata=request_metadata
sendMessageRequest = SendMessageRequest(
request=request, configuration=config, metadata=request_metadata
)

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
params, context=context, extensions=extensions
)
result = (
(response, None) if isinstance(response, Task) else response
sendMessageRequest, context=context, extensions=extensions
)
await self.consume(result, self._card)
yield result
return

tracker = ClientTaskManager()
stream = self._transport.send_message_streaming(
params, context=context, extensions=extensions
)
# In non-streaming case we convert to a StreamResponse so that the
# client always sees the same iterator.
stream_response = StreamResponse()
client_event: ClientEvent
if response.HasField("task"):
stream_response.task = response.task
client_event = (stream_response, response.task)

first_event = await anext(stream)
# The response from a server may be either exactly one Message or a
# series of Task updates. Separate out the first message for special
# case handling, which allows us to simplify further stream processing.
if isinstance(first_event, Message):
await self.consume(first_event, self._card)
yield first_event
return
elif response.HasField("message"):
stream_response.msg = response.msg
client_event = (stream_response, None)

yield await self._process_response(tracker, first_event)
await self.consume(client_event, self._card)
yield client_event
return

async for event in stream:
yield await self._process_response(tracker, event)
stream = self._transport.send_message_streaming(
sendMessageRequest, context=context, extensions=extensions
)
async for client_event in self._process_stream(stream):
yield client_event

async def _process_response(
self,
tracker: ClientTaskManager,
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
) -> ClientEvent:
if isinstance(event, Message):
raise A2AClientInvalidStateError(
'received a streamed Message from server after first response; this is not supported'
)
await tracker.process(event)
task = tracker.get_task_or_raise()
update = None if isinstance(event, Task) else event
client_event = (task, update)
await self.consume(client_event, self._card)
return client_event
async def _process_stream(self, stream: AsyncIterator[StreamResponse]) -> AsyncGenerator[ClientEvent]:
tracker = ClientTaskManager()
async for stream_response in stream:
client_event: ClientEvent
# When we get a message in the stream then we don't expect any
# further messages so yield and return
if stream_response.HasField("message"):
client_event = (stream_response, None)
await self.consume(client_event, self._card)
yield client_event
return

# Otherwise track the task / task update then yield to the client
await tracker.process(stream_response)
updated_task = tracker.get_task_or_raise()
client_event = (stream_response, updated_task)
await self.consume(client_event, self._card)
yield client_event

async def get_task(
self,
request: TaskQueryParams,
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task.

Args:
request: The `TaskQueryParams` object specifying the task ID.
request: The `GetTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.

Expand All @@ -148,15 +152,15 @@ async def get_task(

async def cancel_task(
self,
request: TaskIdParams,
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task.

Args:
request: The `TaskIdParams` object specifying the task ID.
request: The `CancelTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.

Expand All @@ -169,7 +173,7 @@ async def cancel_task(

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
request: SetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
Expand All @@ -190,7 +194,7 @@ async def set_task_callback(

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
Expand All @@ -209,9 +213,9 @@ async def get_task_callback(
request, context=context, extensions=extensions
)

async def resubscribe(
async def subscribe(
self,
request: TaskIdParams,
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
Expand Down Expand Up @@ -240,12 +244,13 @@ async def resubscribe(
# Note: resubscribe can only be called on an existing task. As such,
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
async for event in self._transport.resubscribe(
stream = self._transport.subscribe(
request, context=context, extensions=extensions
):
yield await self._process_response(tracker, event)
)
async for client_event in self._process_stream(stream):
yield client_event

async def get_card(
async def get_extended_agent_card(
self,
*,
context: ClientCallContext | None = None,
Expand All @@ -263,7 +268,7 @@ async def get_card(
Returns:
The `AgentCard` for the agent.
"""
card = await self._transport.get_card(
card = await self._transport.get_extended_agent_card(
context=context, extensions=extensions
)
self._card = card
Expand Down
5 changes: 3 additions & 2 deletions src/a2a/client/card_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

from pydantic import ValidationError

from google.protobuf.json_format import ParseDict
from a2a.client.errors import (
A2AClientHTTPError,
A2AClientJSONError,
)
from a2a.types import (
from a2a.types.a2a_pb2 import (
AgentCard,
)
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
Expand Down Expand Up @@ -85,7 +86,7 @@ async def get_agent_card(
target_url,
agent_card_data,
)
agent_card = AgentCard.model_validate(agent_card_data)
agent_card = ParseDict(agent_card_data, AgentCard())
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(
e.response.status_code,
Expand Down
Loading