diff --git a/agent_sdks/python/src/a2ui/a2a.py b/agent_sdks/python/src/a2ui/a2a.py index 9c5e25cf7..143c971e3 100644 --- a/agent_sdks/python/src/a2ui/a2a.py +++ b/agent_sdks/python/src/a2ui/a2a.py @@ -16,11 +16,17 @@ from typing import Any, Optional, List from a2a.server.agent_execution import RequestContext -from a2a.types import AgentExtension, Part, DataPart, TextPart +from a2a.types import ( + AgentExtension, + AgentCard, + Part, + DataPart, + TextPart, +) logger = logging.getLogger(__name__) -A2UI_EXTENSION_URI = "https://a2ui.org/a2a-extension/a2ui/v0.8" +A2UI_EXTENSION_BASE_URI = "https://a2ui.org/a2a-extension/a2ui" AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY = "supportedCatalogIds" AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY = "acceptsInlineCatalogs" @@ -78,12 +84,14 @@ def get_a2ui_datapart(part: Part) -> Optional[DataPart]: def get_a2ui_agent_extension( + version: str, accepts_inline_catalogs: bool = False, supported_catalog_ids: List[str] = [], ) -> AgentExtension: """Creates the A2UI AgentExtension configuration. Args: + version: The version of the A2UI extension to use. accepts_inline_catalogs: Whether the agent accepts inline catalogs. supported_catalog_ids: All pre-defined catalogs the agent is known to support. @@ -100,7 +108,7 @@ def get_a2ui_agent_extension( params[AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY] = supported_catalog_ids return AgentExtension( - uri=A2UI_EXTENSION_URI, + uri=f"{A2UI_EXTENSION_BASE_URI}/v{version}", description="Provides agent driven UI using the A2UI JSON format.", params=params if params else None, ) @@ -151,20 +159,70 @@ def parse_response_to_parts( return parts -def try_activate_a2ui_extension(context: RequestContext) -> bool: +def _agent_extensions(agent_card: AgentCard) -> List[str]: + """Returns the A2UI extension URIs supported by the agent.""" + extensions = [] + if ( + agent_card + and hasattr(agent_card, "capabilities") + and agent_card.capabilities + and hasattr(agent_card.capabilities, "extensions") + and agent_card.capabilities.extensions + ): + for ext in agent_card.capabilities.extensions: + if ext.uri and ext.uri.startswith(A2UI_EXTENSION_BASE_URI): + extensions.append(ext.uri) + return extensions + + +def _requested_a2ui_extensions(context: RequestContext) -> List[str]: + """Returns the A2UI extension URIs requested by the client.""" + requested_extensions = [] + if hasattr(context, "requested_extensions") and context.requested_extensions: + requested_extensions.extend([ + ext + for ext in context.requested_extensions + if isinstance(ext, str) and ext.startswith(A2UI_EXTENSION_BASE_URI) + ]) + + if ( + hasattr(context, "message") + and context.message + and hasattr(context.message, "extensions") + and context.message.extensions + ): + requested_extensions.extend([ + ext + for ext in context.message.extensions + if isinstance(ext, str) and ext.startswith(A2UI_EXTENSION_BASE_URI) + ]) + + return requested_extensions + + +def try_activate_a2ui_extension( + context: RequestContext, agent_card: AgentCard +) -> Optional[str]: """Activates the A2UI extension if requested. Args: context: The request context to check. + agent_card: The agent card to check supported extensions. Returns: - True if activated, False otherwise. + The version string of the activated A2UI extension, or None if not activated. """ - if A2UI_EXTENSION_URI in context.requested_extensions or ( - context.message - and context.message.extensions - and A2UI_EXTENSION_URI in context.message.extensions - ): - context.add_activated_extension(A2UI_EXTENSION_URI) - return True - return False + requested_extensions = _requested_a2ui_extensions(context) + if not requested_extensions: + return None + + agent_advertised_extensions = _agent_extensions(agent_card) + if not agent_advertised_extensions: + return None + + for req_uri in requested_extensions: + if req_uri in agent_advertised_extensions: + context.add_activated_extension(req_uri) + return req_uri.replace(f"{A2UI_EXTENSION_BASE_URI}/v", "") + + return None diff --git a/agent_sdks/python/src/a2ui/adk/a2a_extension/send_a2ui_to_client_toolset.py b/agent_sdks/python/src/a2ui/adk/a2a_extension/send_a2ui_to_client_toolset.py index 1e37b8aa3..8803b71c2 100644 --- a/agent_sdks/python/src/a2ui/adk/a2a_extension/send_a2ui_to_client_toolset.py +++ b/agent_sdks/python/src/a2ui/adk/a2a_extension/send_a2ui_to_client_toolset.py @@ -105,7 +105,6 @@ async def get_examples(ctx: ReadonlyContext) -> str: from a2a import types as a2a_types from a2ui.a2a import ( - A2UI_EXTENSION_URI, create_a2ui_part, parse_response_to_parts, ) diff --git a/agent_sdks/python/tests/test_a2a.py b/agent_sdks/python/tests/test_a2a.py index 7b18b6c55..c311a3801 100644 --- a/agent_sdks/python/tests/test_a2a.py +++ b/agent_sdks/python/tests/test_a2a.py @@ -51,17 +51,19 @@ def test_non_a2ui_part(): def test_get_a2ui_agent_extension(): - agent_extension = get_a2ui_agent_extension() - assert agent_extension.uri == A2UI_EXTENSION_URI + version = "0.8" + agent_extension = get_a2ui_agent_extension(version) + assert agent_extension.uri == f"{A2UI_EXTENSION_BASE_URI}/v{version}" assert agent_extension.params is None def test_get_a2ui_agent_extension_with_accepts_inline_catalogs(): + version = "0.8" accepts_inline_catalogs = True agent_extension = get_a2ui_agent_extension( - accepts_inline_catalogs=accepts_inline_catalogs + version, accepts_inline_catalogs=accepts_inline_catalogs ) - assert agent_extension.uri == A2UI_EXTENSION_URI + assert agent_extension.uri == f"{A2UI_EXTENSION_BASE_URI}/v{version}" assert agent_extension.params is not None assert ( agent_extension.params.get(AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY) @@ -70,11 +72,12 @@ def test_get_a2ui_agent_extension_with_accepts_inline_catalogs(): def test_get_a2ui_agent_extension_with_supported_catalog_ids(): + version = "0.8" supported_catalog_ids = ["a", "b", "c"] agent_extension = get_a2ui_agent_extension( - supported_catalog_ids=supported_catalog_ids + version, supported_catalog_ids=supported_catalog_ids ) - assert agent_extension.uri == A2UI_EXTENSION_URI + assert agent_extension.uri == f"{A2UI_EXTENSION_BASE_URI}/v{version}" assert agent_extension.params is not None assert ( agent_extension.params.get(AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY) @@ -84,15 +87,26 @@ def test_get_a2ui_agent_extension_with_supported_catalog_ids(): def test_try_activate_a2ui_extension(): context = MagicMock(spec=RequestContext) - context.requested_extensions = [A2UI_EXTENSION_URI] + uri = f"{A2UI_EXTENSION_BASE_URI}/v0.8" + context.requested_extensions = [uri] - assert try_activate_a2ui_extension(context) - context.add_activated_extension.assert_called_once_with(A2UI_EXTENSION_URI) + card = MagicMock() + ext = MagicMock() + ext.uri = uri + card.capabilities.extensions = [ext] + + assert try_activate_a2ui_extension(context, card) == "0.8" + context.add_activated_extension.assert_called_once_with(uri) def test_try_activate_a2ui_extension_not_requested(): context = MagicMock(spec=RequestContext) context.requested_extensions = [] - assert not try_activate_a2ui_extension(context) + card = MagicMock() + ext = MagicMock() + ext.uri = f"{A2UI_EXTENSION_BASE_URI}/v0.8" + card.capabilities.extensions = [ext] + + assert try_activate_a2ui_extension(context, card) is None context.add_activated_extension.assert_not_called() diff --git a/samples/agent/adk/component_gallery/__main__.py b/samples/agent/adk/component_gallery/__main__.py index 050c53446..15d5d95fc 100644 --- a/samples/agent/adk/component_gallery/__main__.py +++ b/samples/agent/adk/component_gallery/__main__.py @@ -28,6 +28,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.staticfiles import StaticFiles from dotenv import load_dotenv +from a2ui.core.schema.constants import VERSION_0_8 from agent_executor import ComponentGalleryExecutor @@ -49,9 +50,12 @@ @click.option("--port", default=10005) def main(host, port): try: + extensions = [] + for v in [VERSION_0_8]: + extensions.append(get_a2ui_agent_extension(v)) capabilities = AgentCapabilities( streaming=True, - extensions=[get_a2ui_agent_extension()], + extensions=extensions, ) # Skill definition @@ -76,7 +80,7 @@ def main(host, port): skills=[skill], ) - agent_executor = ComponentGalleryExecutor(base_url=base_url) + agent_executor = ComponentGalleryExecutor(base_url=base_url, agent_card=agent_card) request_handler = DefaultRequestHandler( agent_executor=agent_executor, diff --git a/samples/agent/adk/component_gallery/a2ui_schema.py b/samples/agent/adk/component_gallery/a2ui_schema.py deleted file mode 100644 index 29c72d1d8..000000000 --- a/samples/agent/adk/component_gallery/a2ui_schema.py +++ /dev/null @@ -1,792 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# a2ui_schema.py - -A2UI_SCHEMA = r""" -{ - "title": "A2UI Message Schema", - "description": "Describes a JSON payload for an A2UI (Agent to UI) message, which is used to dynamically construct and update user interfaces. A message MUST contain exactly ONE of the action properties: 'beginRendering', 'surfaceUpdate', 'dataModelUpdate', or 'deleteSurface'.", - "type": "object", - "properties": { - "beginRendering": { - "type": "object", - "description": "Signals the client to begin rendering a surface with a root component and specific styles.", - "properties": { - "surfaceId": { - "type": "string", - "description": "The unique identifier for the UI surface to be rendered." - }, - "root": { - "type": "string", - "description": "The ID of the root component to render." - }, - "styles": { - "type": "object", - "description": "Styling information for the UI.", - "properties": { - "font": { - "type": "string", - "description": "The primary font for the UI." - }, - "primaryColor": { - "type": "string", - "description": "The primary UI color as a hexadecimal code (e.g., '#00BFFF').", - "pattern": "^#[0-9a-fA-F]{6}$" - } - } - } - }, - "required": ["root", "surfaceId"] - }, - "surfaceUpdate": { - "type": "object", - "description": "Updates a surface with a new set of components.", - "properties": { - "surfaceId": { - "type": "string", - "description": "The unique identifier for the UI surface to be updated. If you are adding a new surface this *must* be a new, unique identified that has never been used for any existing surfaces shown." - }, - "components": { - "type": "array", - "description": "A list containing all UI components for the surface.", - "minItems": 1, - "items": { - "type": "object", - "description": "Represents a *single* component in a UI widget tree. This component could be one of many supported types.", - "properties": { - "id": { - "type": "string", - "description": "The unique identifier for this component." - }, - "weight": { - "type": "number", - "description": "The relative weight of this component within a Row or Column. This corresponds to the CSS 'flex-grow' property. Note: this may ONLY be set when the component is a direct descendant of a Row or Column." - }, - "component": { - "type": "object", - "description": "A wrapper object that MUST contain exactly one key, which is the name of the component type (e.g., 'Heading'). The value is an object containing the properties for that specific component.", - "properties": { - "Text": { - "type": "object", - "properties": { - "text": { - "type": "object", - "description": "The text content to display. This can be a literal string or a reference to a value in the data model ('path', e.g., '/doc/title'). While simple Markdown formatting is supported (i.e. without HTML, images, or links), utilizing dedicated UI components is generally preferred for a richer and more structured presentation.", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "usageHint": { - "type": "string", - "description": "A hint for the base text style. One of:\n- `h1`: Largest heading.\n- `h2`: Second largest heading.\n- `h3`: Third largest heading.\n- `h4`: Fourth largest heading.\n- `h5`: Fifth largest heading.\n- `caption`: Small text for captions.\n- `body`: Standard body text.", - "enum": [ - "h1", - "h2", - "h3", - "h4", - "h5", - "caption", - "body" - ] - } - }, - "required": ["text"] - }, - "Image": { - "type": "object", - "properties": { - "url": { - "type": "object", - "description": "The URL of the image to display. This can be a literal string ('literal') or a reference to a value in the data model ('path', e.g. '/thumbnail/url').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "fit": { - "type": "string", - "description": "Specifies how the image should be resized to fit its container. This corresponds to the CSS 'object-fit' property.", - "enum": [ - "contain", - "cover", - "fill", - "none", - "scale-down" - ] - }, - "usageHint": { - "type": "string", - "description": "A hint for the image size and style. One of:\n- `icon`: Small square icon.\n- `avatar`: Circular avatar image.\n- `smallFeature`: Small feature image.\n- `mediumFeature`: Medium feature image.\n- `largeFeature`: Large feature image.\n- `header`: Full-width, full bleed, header image.", - "enum": [ - "icon", - "avatar", - "smallFeature", - "mediumFeature", - "largeFeature", - "header" - ] - } - }, - "required": ["url"] - }, - "Icon": { - "type": "object", - "properties": { - "name": { - "type": "object", - "description": "The name of the icon to display. This can be a literal string or a reference to a value in the data model ('path', e.g. '/form/submit').", - "properties": { - "literalString": { - "type": "string", - "enum": [ - "accountCircle", - "add", - "arrowBack", - "arrowForward", - "attachFile", - "calendarToday", - "call", - "camera", - "check", - "close", - "delete", - "download", - "edit", - "event", - "error", - "favorite", - "favoriteOff", - "folder", - "help", - "home", - "info", - "locationOn", - "lock", - "lockOpen", - "mail", - "menu", - "moreVert", - "moreHoriz", - "notificationsOff", - "notifications", - "payment", - "person", - "phone", - "photo", - "print", - "refresh", - "search", - "send", - "settings", - "share", - "shoppingCart", - "star", - "starHalf", - "starOff", - "upload", - "visibility", - "visibilityOff", - "warning" - ] - }, - "path": { - "type": "string" - } - } - } - }, - "required": ["name"] - }, - "Video": { - "type": "object", - "properties": { - "url": { - "type": "object", - "description": "The URL of the video to display. This can be a literal string or a reference to a value in the data model ('path', e.g. '/video/url').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - } - }, - "required": ["url"] - }, - "AudioPlayer": { - "type": "object", - "properties": { - "url": { - "type": "object", - "description": "The URL of the audio to be played. This can be a literal string ('literal') or a reference to a value in the data model ('path', e.g. '/song/url').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "description": { - "type": "object", - "description": "A description of the audio, such as a title or summary. This can be a literal string or a reference to a value in the data model ('path', e.g. '/song/title').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - } - }, - "required": ["url"] - }, - "Row": { - "type": "object", - "properties": { - "children": { - "type": "object", - "description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.", - "properties": { - "explicitList": { - "type": "array", - "items": { - "type": "string" - } - }, - "template": { - "type": "object", - "description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.", - "properties": { - "componentId": { - "type": "string" - }, - "dataBinding": { - "type": "string" - } - }, - "required": ["componentId", "dataBinding"] - } - } - }, - "distribution": { - "type": "string", - "description": "Defines the arrangement of children along the main axis (horizontally). This corresponds to the CSS 'justify-content' property.", - "enum": [ - "center", - "end", - "spaceAround", - "spaceBetween", - "spaceEvenly", - "start" - ] - }, - "alignment": { - "type": "string", - "description": "Defines the alignment of children along the cross axis (vertically). This corresponds to the CSS 'align-items' property.", - "enum": ["start", "center", "end", "stretch"] - } - }, - "required": ["children"] - }, - "Column": { - "type": "object", - "properties": { - "children": { - "type": "object", - "description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.", - "properties": { - "explicitList": { - "type": "array", - "items": { - "type": "string" - } - }, - "template": { - "type": "object", - "description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.", - "properties": { - "componentId": { - "type": "string" - }, - "dataBinding": { - "type": "string" - } - }, - "required": ["componentId", "dataBinding"] - } - } - }, - "distribution": { - "type": "string", - "description": "Defines the arrangement of children along the main axis (vertically). This corresponds to the CSS 'justify-content' property.", - "enum": [ - "start", - "center", - "end", - "spaceBetween", - "spaceAround", - "spaceEvenly" - ] - }, - "alignment": { - "type": "string", - "description": "Defines the alignment of children along the cross axis (horizontally). This corresponds to the CSS 'align-items' property.", - "enum": ["center", "end", "start", "stretch"] - } - }, - "required": ["children"] - }, - "List": { - "type": "object", - "properties": { - "children": { - "type": "object", - "description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.", - "properties": { - "explicitList": { - "type": "array", - "items": { - "type": "string" - } - }, - "template": { - "type": "object", - "description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.", - "properties": { - "componentId": { - "type": "string" - }, - "dataBinding": { - "type": "string" - } - }, - "required": ["componentId", "dataBinding"] - } - } - }, - "direction": { - "type": "string", - "description": "The direction in which the list items are laid out.", - "enum": ["vertical", "horizontal"] - }, - "alignment": { - "type": "string", - "description": "Defines the alignment of children along the cross axis.", - "enum": ["start", "center", "end", "stretch"] - } - }, - "required": ["children"] - }, - "Card": { - "type": "object", - "properties": { - "child": { - "type": "string", - "description": "The ID of the component to be rendered inside the card." - } - }, - "required": ["child"] - }, - "Tabs": { - "type": "object", - "properties": { - "tabItems": { - "type": "array", - "description": "An array of objects, where each object defines a tab with a title and a child component.", - "items": { - "type": "object", - "properties": { - "title": { - "type": "object", - "description": "The tab title. Defines the value as either a literal value or a path to data model value (e.g. '/options/title').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "child": { - "type": "string" - } - }, - "required": ["title", "child"] - } - } - }, - "required": ["tabItems"] - }, - "Divider": { - "type": "object", - "properties": { - "axis": { - "type": "string", - "description": "The orientation of the divider.", - "enum": ["horizontal", "vertical"] - } - } - }, - "Modal": { - "type": "object", - "properties": { - "entryPointChild": { - "type": "string", - "description": "The ID of the component that opens the modal when interacted with (e.g., a button)." - }, - "contentChild": { - "type": "string", - "description": "The ID of the component to be displayed inside the modal." - } - }, - "required": ["entryPointChild", "contentChild"] - }, - "Button": { - "type": "object", - "properties": { - "child": { - "type": "string", - "description": "The ID of the component to display in the button, typically a Text component." - }, - "primary": { - "type": "boolean", - "description": "Indicates if this button should be styled as the primary action." - }, - "action": { - "type": "object", - "description": "The client-side action to be dispatched when the button is clicked. It includes the action's name and an optional context payload.", - "properties": { - "name": { - "type": "string" - }, - "context": { - "type": "array", - "items": { - "type": "object", - "properties": { - "key": { - "type": "string" - }, - "value": { - "type": "object", - "description": "Defines the value to be included in the context as either a literal value or a path to a data model value (e.g. '/user/name').", - "properties": { - "path": { - "type": "string" - }, - "literalString": { - "type": "string" - }, - "literalNumber": { - "type": "number" - }, - "literalBoolean": { - "type": "boolean" - } - } - } - }, - "required": ["key", "value"] - } - } - }, - "required": ["name"] - } - }, - "required": ["child", "action"] - }, - "CheckBox": { - "type": "object", - "properties": { - "label": { - "type": "object", - "description": "The text to display next to the checkbox. Defines the value as either a literal value or a path to data model ('path', e.g. '/option/label').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "value": { - "type": "object", - "description": "The current state of the checkbox (true for checked, false for unchecked). This can be a literal boolean ('literalBoolean') or a reference to a value in the data model ('path', e.g. '/filter/open').", - "properties": { - "literalBoolean": { - "type": "boolean" - }, - "path": { - "type": "string" - } - } - } - }, - "required": ["label", "value"] - }, - "TextField": { - "type": "object", - "properties": { - "label": { - "type": "object", - "description": "The text label for the input field. This can be a literal string or a reference to a value in the data model ('path, e.g. '/user/name').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "text": { - "type": "object", - "description": "The value of the text field. This can be a literal string or a reference to a value in the data model ('path', e.g. '/user/name').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "textFieldType": { - "type": "string", - "description": "The type of input field to display.", - "enum": [ - "date", - "longText", - "number", - "shortText", - "obscured" - ] - }, - "validationRegexp": { - "type": "string", - "description": "A regular expression used for client-side validation of the input." - } - }, - "required": ["label"] - }, - "DateTimeInput": { - "type": "object", - "properties": { - "value": { - "type": "object", - "description": "The selected date and/or time value. This can be a literal string ('literalString') or a reference to a value in the data model ('path', e.g. '/user/dob').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "enableDate": { - "type": "boolean", - "description": "If true, allows the user to select a date." - }, - "enableTime": { - "type": "boolean", - "description": "If true, allows the user to select a time." - }, - "outputFormat": { - "type": "string", - "description": "The desired format for the output string after a date or time is selected." - } - }, - "required": ["value"] - }, - "MultipleChoice": { - "type": "object", - "properties": { - "selections": { - "type": "object", - "description": "The currently selected values for the component. This can be a literal array of strings or a path to an array in the data model('path', e.g. '/hotel/options').", - "properties": { - "literalArray": { - "type": "array", - "items": { - "type": "string" - } - }, - "path": { - "type": "string" - } - } - }, - "options": { - "type": "array", - "description": "An array of available options for the user to choose from.", - "items": { - "type": "object", - "properties": { - "label": { - "type": "object", - "description": "The text to display for this option. This can be a literal string or a reference to a value in the data model (e.g. '/option/label').", - "properties": { - "literalString": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, - "value": { - "type": "string", - "description": "The value to be associated with this option when selected." - } - }, - "required": ["label", "value"] - } - }, - "maxAllowedSelections": { - "type": "integer", - "description": "The maximum number of options that the user is allowed to select." - } - }, - "required": ["selections", "options"] - }, - "Slider": { - "type": "object", - "properties": { - "value": { - "type": "object", - "description": "The current value of the slider. This can be a literal number ('literalNumber') or a reference to a value in the data model ('path', e.g. '/restaurant/cost').", - "properties": { - "literalNumber": { - "type": "number" - }, - "path": { - "type": "string" - } - } - }, - "minValue": { - "type": "number", - "description": "The minimum value of the slider." - }, - "maxValue": { - "type": "number", - "description": "The maximum value of the slider." - } - }, - "required": ["value"] - } - } - } - }, - "required": ["id", "component"] - } - } - }, - "required": ["surfaceId", "components"] - }, - "dataModelUpdate": { - "type": "object", - "description": "Updates the data model for a surface.", - "properties": { - "surfaceId": { - "type": "string", - "description": "The unique identifier for the UI surface this data model update applies to." - }, - "path": { - "type": "string", - "description": "An optional path to a location within the data model (e.g., '/user/name'). If omitted, or set to '/', the entire data model will be replaced." - }, - "contents": { - "type": "array", - "description": "An array of data entries. Each entry must contain a 'key' and exactly one corresponding typed 'value*' property.", - "items": { - "type": "object", - "description": "A single data entry. Exactly one 'value*' property should be provided alongside the key.", - "properties": { - "key": { - "type": "string", - "description": "The key for this data entry." - }, - "valueString": { - "type": "string" - }, - "valueNumber": { - "type": "number" - }, - "valueBoolean": { - "type": "boolean" - }, - "valueMap": { - "description": "Represents a map as an adjacency list.", - "type": "array", - "items": { - "type": "object", - "description": "One entry in the map. Exactly one 'value*' property should be provided alongside the key.", - "properties": { - "key": { - "type": "string" - }, - "valueString": { - "type": "string" - }, - "valueNumber": { - "type": "number" - }, - "valueBoolean": { - "type": "boolean" - } - }, - "required": ["key"] - } - } - }, - "required": ["key"] - } - } - }, - "required": ["contents", "surfaceId"] - }, - "deleteSurface": { - "type": "object", - "description": "Signals the client to delete the surface identified by 'surfaceId'.", - "properties": { - "surfaceId": { - "type": "string", - "description": "The unique identifier for the UI surface to be deleted." - } - }, - "required": ["surfaceId"] - } - } -} -""" diff --git a/samples/agent/adk/component_gallery/agent_executor.py b/samples/agent/adk/component_gallery/agent_executor.py index d0bed73c7..ce9c44232 100644 --- a/samples/agent/adk/component_gallery/agent_executor.py +++ b/samples/agent/adk/component_gallery/agent_executor.py @@ -19,7 +19,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.tasks import TaskUpdater -from a2a.types import (DataPart, Part, TaskState, TextPart) +from a2a.types import (DataPart, Part, TaskState, TextPart, AgentCard) from a2a.utils import new_agent_parts_message, new_task from agent import ComponentGalleryAgent from a2ui.a2a import try_activate_a2ui_extension @@ -29,14 +29,15 @@ class ComponentGalleryExecutor(AgentExecutor): - def __init__(self, base_url: str): + def __init__(self, base_url: str, agent_card: AgentCard): self.agent = ComponentGalleryAgent(base_url=base_url) + self._agent_card = agent_card async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: query = "START" # Default start ui_event_part = None - try_activate_a2ui_extension(context) + try_activate_a2ui_extension(context, self._agent_card) if context.message and context.message.parts: for part in context.message.parts: diff --git a/samples/agent/adk/contact_lookup/__main__.py b/samples/agent/adk/contact_lookup/__main__.py index 69ac07eb0..11c2eb855 100644 --- a/samples/agent/adk/contact_lookup/__main__.py +++ b/samples/agent/adk/contact_lookup/__main__.py @@ -49,17 +49,16 @@ def main(host, port): ) base_url = f"http://{host}:{port}" - ui_agent = ContactAgent(base_url=base_url, use_ui=True) - text_agent = ContactAgent(base_url=base_url, use_ui=False) + agent = ContactAgent(base_url=base_url) - agent_executor = ContactAgentExecutor(ui_agent=ui_agent, text_agent=text_agent) + agent_executor = ContactAgentExecutor(agent=agent) request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=InMemoryTaskStore(), ) server = A2AStarletteApplication( - agent_card=ui_agent.get_agent_card(), http_handler=request_handler + agent_card=agent.agent_card, http_handler=request_handler ) import uvicorn diff --git a/samples/agent/adk/contact_lookup/agent.py b/samples/agent/adk/contact_lookup/agent.py index 2193d0c37..2f39b81d0 100644 --- a/samples/agent/adk/contact_lookup/agent.py +++ b/samples/agent/adk/contact_lookup/agent.py @@ -16,7 +16,8 @@ import logging import os from collections.abc import AsyncIterable -from typing import Any +from dataclasses import dataclass +from typing import Any, Dict, Optional import jsonschema @@ -53,38 +54,47 @@ class ContactAgent: SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] - def __init__(self, base_url: str, use_ui: bool = False): + def __init__(self, base_url: str): self.base_url = base_url - self.use_ui = use_ui - self._schema_manager = ( - A2uiSchemaManager( - version=VERSION_0_8, - catalogs=[ - BasicCatalog.get_config(version=VERSION_0_8, examples_path="examples") - ], - ) - if use_ui - else None - ) - self._agent = self._build_agent(use_ui) + self._agent_name = "contact_agent" self._user_id = "remote_agent" - self._runner = Runner( - app_name=self._agent.name, - agent=self._agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + self._text_runner: Optional[Runner] = self._build_runner(self._build_llm_agent()) + + self._schema_managers: Dict[str, A2uiSchemaManager] = {} + self._ui_runners: Dict[str, Runner] = {} + + for version in [VERSION_0_8]: + schema_manager = self._build_schema_manager(version) + self._schema_managers[version] = schema_manager + agent = self._build_llm_agent(schema_manager) + self._ui_runners[version] = self._build_runner(agent) + + self._agent_card = self._build_agent_card() + + @property + def agent_card(self) -> AgentCard: + return self._agent_card + + def _build_schema_manager(self, version: str) -> A2uiSchemaManager: + return A2uiSchemaManager( + version=version, + catalogs=[BasicCatalog.get_config(version=version, examples_path="examples")], ) - def get_agent_card(self) -> AgentCard: + def _build_agent_card(self) -> AgentCard: + extensions = [] + if self._schema_managers: + for version, sm in self._schema_managers.items(): + ext = get_a2ui_agent_extension( + version, + sm.accepts_inline_catalogs, + sm.supported_catalog_ids, + ) + extensions.append(ext) + capabilities = AgentCapabilities( streaming=True, - extensions=[ - get_a2ui_agent_extension( - self._schema_manager.accepts_inline_catalogs, - self._schema_manager.supported_catalog_ids, - ) - ], + extensions=extensions, ) skill = AgentSkill( id="find_contact", @@ -113,15 +123,26 @@ def get_agent_card(self) -> AgentCard: skills=[skill], ) + def _build_runner(self, agent: LlmAgent) -> Runner: + return Runner( + app_name=self._agent_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + def get_processing_message(self) -> str: return "Looking up contact information..." - def _build_agent(self, use_ui: bool) -> LlmAgent: + def _build_llm_agent( + self, schema_manager: Optional[A2uiSchemaManager] = None + ) -> LlmAgent: """Builds the LLM agent for the contact agent.""" LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") instruction = ( - self._schema_manager.generate_system_prompt( + schema_manager.generate_system_prompt( role_description=ROLE_DESCRIPTION, workflow_description=WORKFLOW_DESCRIPTION, ui_description=UI_DESCRIPTION, @@ -129,29 +150,43 @@ def _build_agent(self, use_ui: bool) -> LlmAgent: include_examples=True, validate_examples=False, # Use invalid examples to test retry logic ) - if use_ui + if schema_manager else get_text_prompt() ) return LlmAgent( model=LiteLlm(model=LITELLM_MODEL), - name="contact_agent", + name=self._agent_name, description="An agent that finds colleague contact info.", instruction=instruction, tools=[get_contact_info], ) - async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: + async def stream( + self, query, session_id, ui_version: Optional[str] = None + ) -> AsyncIterable[dict[str, Any]]: session_state = {"base_url": self.base_url} - session = await self._runner.session_service.get_session( - app_name=self._agent.name, + # Determine which runner to use based on whether the a2ui extension is active. + if ui_version: + runner = self._ui_runners[ui_version] + schema_manager = self._schema_managers[ui_version] + selected_catalog = ( + schema_manager.get_selected_catalog() if schema_manager else None + ) + else: + runner = self._text_runner + schema_manager = None + selected_catalog = None + + session = await runner.session_service.get_session( + app_name=self._agent_name, user_id=self._user_id, session_id=session_id, ) if session is None: - session = await self._runner.session_service.create_session( - app_name=self._agent.name, + session = await runner.session_service.create_session( + app_name=self._agent_name, user_id=self._user_id, state=session_state, session_id=session_id, @@ -165,8 +200,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: current_query_text = query # Ensure catalog schema was loaded - selected_catalog = self._schema_manager.get_selected_catalog() - if self.use_ui and not selected_catalog.catalog_schema: + if ui_version and (not selected_catalog or not selected_catalog.catalog_schema): logger.error( "--- ContactAgent.stream: A2UI_SCHEMA is not loaded. " "Cannot perform UI validation. ---" @@ -198,7 +232,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: ) final_response_content = None - async for event in self._runner.run_async( + async for event in runner.run_async( user_id=self._user_id, session_id=session.id, new_message=current_message, @@ -239,7 +273,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: is_valid = False error_message = "" - if self.use_ui: + if ui_version: logger.info( "--- ContactAgent.stream: Validating UI response (Attempt" f" {attempt})... ---" diff --git a/samples/agent/adk/contact_lookup/agent_executor.py b/samples/agent/adk/contact_lookup/agent_executor.py index f4e49aea5..f83df30fa 100644 --- a/samples/agent/adk/contact_lookup/agent_executor.py +++ b/samples/agent/adk/contact_lookup/agent_executor.py @@ -40,11 +40,8 @@ class ContactAgentExecutor(AgentExecutor): """Contact AgentExecutor Example.""" - def __init__(self, ui_agent: ContactAgent, text_agent: ContactAgent): - # Instantiate two agents: one for UI and one for text-only. - # The appropriate one will be chosen at execution time. - self.ui_agent = ui_agent - self.text_agent = text_agent + def __init__(self, agent: ContactAgent): + self._agent = agent async def execute( self, @@ -56,16 +53,16 @@ async def execute( action = None logger.info(f"--- Client requested extensions: {context.requested_extensions} ---") - use_ui = try_activate_a2ui_extension(context) + active_ui_version = try_activate_a2ui_extension(context, self._agent.agent_card) - # Determine which agent to use based on whether the a2ui extension is active. - if use_ui: - agent = self.ui_agent - logger.info("--- AGENT_EXECUTOR: A2UI extension is active. Using UI agent. ---") + if active_ui_version: + logger.info( + f"--- AGENT_EXECUTOR: A2UI extension is active (v{active_ui_version}). Using" + " UI runner. ---" + ) else: - agent = self.text_agent logger.info( - "--- AGENT_EXECUTOR: A2UI extension is not active. Using text agent. ---" + "--- AGENT_EXECUTOR: A2UI extension is not active. Using text runner. ---" ) if context.message and context.message.parts: @@ -127,7 +124,7 @@ async def execute( await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) - async for item in agent.stream(query, task.context_id): + async for item in self._agent.stream(query, task.context_id, active_ui_version): is_task_complete = item["is_task_complete"] if not is_task_complete: await updater.update_status( diff --git a/samples/agent/adk/contact_multiple_surfaces/__main__.py b/samples/agent/adk/contact_multiple_surfaces/__main__.py index cccd51e3b..537682315 100644 --- a/samples/agent/adk/contact_multiple_surfaces/__main__.py +++ b/samples/agent/adk/contact_multiple_surfaces/__main__.py @@ -50,7 +50,7 @@ def main(host, port): base_url = f"http://{host}:{port}" - agent = ContactAgent(base_url=base_url, use_ui=True) + agent = ContactAgent(base_url=base_url) agent_executor = ContactAgentExecutor(agent=agent) @@ -59,7 +59,7 @@ def main(host, port): task_store=InMemoryTaskStore(), ) server = A2AStarletteApplication( - agent_card=agent.get_agent_card(), http_handler=request_handler + agent_card=ui_agent.agent_card, http_handler=request_handler ) import uvicorn diff --git a/samples/agent/adk/contact_multiple_surfaces/agent.py b/samples/agent/adk/contact_multiple_surfaces/agent.py index 6967caaf7..4e46bdd03 100644 --- a/samples/agent/adk/contact_multiple_surfaces/agent.py +++ b/samples/agent/adk/contact_multiple_surfaces/agent.py @@ -16,37 +16,7 @@ import logging import os from collections.abc import AsyncIterable -from typing import Any - -import jsonschema -from a2ui_examples import load_floor_plan_example - -from a2a.types import ( - AgentCapabilities, - AgentCard, - AgentSkill, - DataPart, - Part, - TextPart, -) -from google.adk.agents.llm_agent import LlmAgent -from google.adk.artifacts import InMemoryArtifactService -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService -from google.adk.models.lite_llm import LiteLlm -from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService -from google.genai import types -from prompt_builder import ( - get_text_prompt, - ROLE_DESCRIPTION, - WORKFLOW_DESCRIPTION, - UI_DESCRIPTION, -) -from tools import get_contact_info -from a2ui.core.schema.constants import VERSION_0_8, A2UI_OPEN_TAG, A2UI_CLOSE_TAG -from a2ui.core.schema.manager import A2uiSchemaManager -from a2ui.core.parser.parser import parse_response, ResponsePart -from a2ui.basic_catalog.provider import BasicCatalog +from typing import Any, Dict, Optional from a2ui.core.schema.common_modifiers import remove_strict_validation from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts @@ -58,40 +28,49 @@ class ContactAgent: SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] - def __init__(self, base_url: str, use_ui: bool = False): + def __init__(self, base_url: str): self.base_url = base_url - self.use_ui = use_ui - self.schema_manager = ( - A2uiSchemaManager( - VERSION_0_8, - catalogs=[ - BasicCatalog.get_config(version=VERSION_0_8, examples_path="examples") - ], - schema_modifiers=[remove_strict_validation], - accepts_inline_catalogs=True, - ) - if use_ui - else None - ) - self._agent = self._build_agent(use_ui) + self._agent_name = "contact_agent" self._user_id = "remote_agent" - self._runner = Runner( - app_name=self._agent.name, - agent=self._agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + self._text_runner: Optional[Runner] = self._build_runner(self._build_llm_agent()) + + self._schema_managers: Dict[str, A2uiSchemaManager] = {} + self._ui_runners: Dict[str, Runner] = {} + + for version in [VERSION_0_8]: + schema_manager = self._build_schema_manager(version) + self._schema_managers[version] = schema_manager + agent = self._build_llm_agent(schema_manager) + self._ui_runners[version] = self._build_runner(agent) + + self._agent_card = self._build_agent_card() + + @property + def agent_card(self) -> AgentCard: + return self._agent_card + + def _build_schema_manager(self, version: str) -> A2uiSchemaManager: + return A2uiSchemaManager( + version=version, + catalogs=[BasicCatalog.get_config(version=version, examples_path="examples")], + schema_modifiers=[remove_strict_validation], + accepts_inline_catalogs=True, ) - def get_agent_card(self) -> AgentCard: + def _build_agent_card(self) -> AgentCard: + extensions = [] + if self._schema_managers: + for version, sm in self._schema_managers.items(): + ext = get_a2ui_agent_extension( + version, + sm.accepts_inline_catalogs, + sm.supported_catalog_ids, + ) + extensions.append(ext) + capabilities = AgentCapabilities( streaming=True, - extensions=[ - get_a2ui_agent_extension( - self.schema_manager.accepts_inline_catalogs, - self.schema_manager.supported_catalog_ids, - ) - ], + extensions=extensions, ) skill = AgentSkill( id="find_contact", @@ -120,29 +99,40 @@ def get_agent_card(self) -> AgentCard: skills=[skill], ) + def _build_runner(self, agent: LlmAgent) -> Runner: + return Runner( + app_name=self._agent_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + def get_processing_message(self) -> str: return "Looking up contact information..." - def _build_agent(self, use_ui: bool) -> LlmAgent: + def _build_llm_agent( + self, schema_manager: Optional[A2uiSchemaManager] = None + ) -> LlmAgent: """Builds the LLM agent for the contact agent.""" LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") instruction = ( - self.schema_manager.generate_system_prompt( + schema_manager.generate_system_prompt( role_description=ROLE_DESCRIPTION, workflow_description=WORKFLOW_DESCRIPTION, ui_description=UI_DESCRIPTION, include_examples=True, include_schema=True, - validate_examples=False, # Missing inline_catalogs for OrgChart and WebFrame validation + validate_examples=False, ) - if use_ui + if schema_manager else get_text_prompt() ) return LlmAgent( model=LiteLlm(model=LITELLM_MODEL), - name="contact_agent", + name=self._agent_name, description="An agent that finds colleague contact info.", instruction=instruction, tools=[get_contact_info], @@ -233,18 +223,34 @@ async def _handle_action(self, query: str) -> dict[str, Any] | None: return None async def stream( - self, query, session_id, client_ui_capabilities: dict[str, Any] | None = None + self, + query, + session_id, + client_ui_capabilities: dict[str, Any] | None = None, + ui_version: Optional[str] = None, ) -> AsyncIterable[dict[str, Any]]: session_state = {"base_url": self.base_url} - session = await self._runner.session_service.get_session( - app_name=self._agent.name, + # Determine which runner to use based on whether the a2ui extension is active. + if ui_version: + runner = self._ui_runners[ui_version] + schema_manager = self._schema_managers[ui_version] + selected_catalog = ( + schema_manager.get_selected_catalog() if schema_manager else None + ) + else: + runner = self._text_runner + schema_manager = None + selected_catalog = None + + session = await runner.session_service.get_session( + app_name=self._agent_name, user_id=self._user_id, session_id=session_id, ) if session is None: - session = await self._runner.session_service.create_session( - app_name=self._agent.name, + session = await runner.session_service.create_session( + app_name=self._agent_name, user_id=self._user_id, state=session_state, session_id=session_id, @@ -257,9 +263,51 @@ async def stream( attempt = 0 current_query_text = query - # Ensure schema was loaded - selected_catalog = self.schema_manager.get_selected_catalog(client_ui_capabilities) - if self.use_ui and not selected_catalog.catalog_schema: + if not ui_version: + # For non-UI Requests + while attempt <= max_retries: + attempt += 1 + logger.info( + f"--- ContactAgent.stream: Attempt {attempt}/{max_retries + 1} " + f"for session {session_id} (Text-only mode) ---" + ) + current_message = types.Content( + role="user", parts=[types.Part.from_text(text=current_query_text)] + ) + final_response_content = None + + async for event in runner.run_async( + user_id=self._user_id, + session_id=session.id, + new_message=current_message, + ): + if event.is_final_response(): + if event.content and event.content.parts and event.content.parts[0].text: + final_response_content = "\n".join( + [p.text for p in event.content.parts if p.text] + ) + break + + if final_response_content: + yield { + "is_task_complete": True, + "parts": [Part(root=TextPart(text=final_response_content))], + } + return + + yield { + "is_task_complete": True, + "parts": [ + Part( + root=TextPart( + text="I encountered an error and couldn't process your request." + ) + ) + ], + } + return + + if ui_version and (not selected_catalog or not selected_catalog.catalog_schema): logger.error( "--- ContactAgent.stream: A2UI_SCHEMA is not loaded. " "Cannot perform UI validation. ---" @@ -292,7 +340,7 @@ async def stream( ) final_response_content = None - async for event in self._runner.run_async( + async for event in runner.run_async( user_id=self._user_id, session_id=session.id, new_message=current_message, @@ -322,18 +370,16 @@ async def stream( "I received no response. Please try again." f"Please retry the original request: '{query}'" ) - continue # Go to next retry + continue else: - # Retries exhausted on no-response final_response_content = ( "I'm sorry, I encountered an error and couldn't process your request." ) - # Fall through to send this as a text-only error is_valid = False error_message = "" - if self.use_ui: + if ui_version: logger.info( "--- ContactAgent.stream: Validating UI response (Attempt" f" {attempt})... ---" @@ -432,4 +478,3 @@ async def stream( ) ], } - # --- End: UI Validation and Retry Logic --- diff --git a/samples/agent/adk/contact_multiple_surfaces/agent_executor.py b/samples/agent/adk/contact_multiple_surfaces/agent_executor.py index 82ce090fb..a53042f77 100644 --- a/samples/agent/adk/contact_multiple_surfaces/agent_executor.py +++ b/samples/agent/adk/contact_multiple_surfaces/agent_executor.py @@ -42,8 +42,7 @@ class ContactAgentExecutor(AgentExecutor): """Contact AgentExecutor Example.""" def __init__(self, agent: ContactAgent): - # Instantiate the UI agent. - self.ui_agent = agent + self._agent = agent async def execute( self, @@ -56,17 +55,17 @@ async def execute( client_ui_capabilities = None logger.info(f"--- Client requested extensions: {context.requested_extensions} ---") - use_ui = try_activate_a2ui_extension(context) + active_ui_version = try_activate_a2ui_extension(context, self.ui_agent.agent_card) - # Determine which agent to use based on whether the a2ui extension is active. - if use_ui: - agent = self.ui_agent - logger.info("--- AGENT_EXECUTOR: A2UI extension is active. Using UI agent. ---") + if active_ui_version: + logger.info( + f"--- AGENT_EXECUTOR: A2UI extension is active (v{active_ui_version}). Using" + " UI runner. ---" + ) else: - # Enforce A2UI extension as per review comment - error_msg = "A2UI extension is NOT active. This agent requires A2UI to function." - logger.error(f"--- AGENT_EXECUTOR: {error_msg} ---") - raise ServerError(error=UnsupportedOperationError(error_msg)) + logger.info( + "--- AGENT_EXECUTOR: A2UI extension is not active. Using text runner. ---" + ) if context.message and context.message.parts: logger.info( @@ -167,7 +166,9 @@ async def execute( await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) - async for item in agent.stream(query, task.context_id, client_ui_capabilities): + async for item in self._agent.stream( + query, task.context_id, client_ui_capabilities, active_ui_version + ): is_task_complete = item["is_task_complete"] if not is_task_complete: await updater.update_status( diff --git a/samples/agent/adk/mcp_app_proxy/__main__.py b/samples/agent/adk/mcp_app_proxy/__main__.py index eea7e22b0..ca3ea29ff 100644 --- a/samples/agent/adk/mcp_app_proxy/__main__.py +++ b/samples/agent/adk/mcp_app_proxy/__main__.py @@ -62,126 +62,13 @@ def main(host, port): lite_llm_model = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") base_url = f"http://{host}:{port}" - schema_manager = A2uiSchemaManager( - VERSION_0_8, - catalogs=[ - CatalogConfig.from_path( - name="mcp_app_proxy", - catalog_path="mcp_app_catalog.json", - ), - ], - accepts_inline_catalogs=True, - ) - - # Define get_calculator_app tool in a way that the LlmAgent can use. - async def get_calculator_app(tool_context: ToolContext): - """Fetches the calculator app.""" - # Connect to the MCP server via SSE - mcp_server_host = os.getenv("MCP_SERVER_HOST", "localhost") - mcp_server_port = os.getenv("MCP_SERVER_PORT", "8000") - sse_url = f"http://{mcp_server_host}:{mcp_server_port}/sse" - - try: - async with sse_client(sse_url) as streams: - async with ClientSession(streams[0], streams[1]) as session: - await session.initialize() - - # Read the resource - result = await session.read_resource("ui://calculator/app") - - # Package the resource as an A2UI message - if result.contents and hasattr(result.contents[0], "text"): - html_content = result.contents[0].text - encoded_html = "url_encoded:" + urllib.parse.quote(html_content) - messages = [ - { - "beginRendering": { - "surfaceId": "calculator_surface", - "root": "calculator_app_root", - }, - }, - { - "surfaceUpdate": { - "surfaceId": "calculator_surface", - "components": [{ - "id": "calculator_app_root", - "component": { - "McpApp": { - "content": {"literalString": encoded_html}, - "title": {"literalString": "Calculator"}, - "allowedTools": ["calculate"], - } - }, - }], - }, - }, - ] - tool_context.actions.skip_summarization = True - return {"validated_a2ui_json": messages} - else: - logger.error("Failed to get text content from resource") - return {"error": "Could not fetch calculator app content."} - - except Exception as e: - logger.error(f"Error fetching calculator app: {e} {traceback.format_exc()}") - return {"error": f"Failed to connect to MCP server or fetch app. Details: {e}"} - - async def calculate_via_mcp(operation: str, a: float, b: float): - """Calculates via the MCP server's Calculate tool. - - Args: - operation: The mathematical operation (e.g. 'add', 'subtract', 'multiply', 'divide'). - a: First operand. - b: Second operand. - """ - mcp_server_host = os.getenv("MCP_SERVER_HOST", "localhost") - mcp_server_port = os.getenv("MCP_SERVER_PORT", "8000") - sse_url = f"http://{mcp_server_host}:{mcp_server_port}/sse" - - try: - async with sse_client(sse_url) as streams: - async with ClientSession(streams[0], streams[1]) as session: - await session.initialize() - - result = await session.call_tool( - "calculate", arguments={"operation": operation, "a": a, "b": b} - ) - - if ( - result.content - and len(result.content) > 0 - and hasattr(result.content[0], "text") - ): - return result.content[0].text - return "No result text from MCP calculate tool." - except Exception as e: - logger.error(f"Error calling MCP calculate: {e} {traceback.format_exc()}") - return f"Error connecting to MCP server: {e}" - - tools = [get_calculator_app, calculate_via_mcp] - agent = McpAppProxyAgent( - base_url=base_url, model=LiteLlm(model=lite_llm_model), - schema_manager=schema_manager, - a2ui_enabled_provider=get_a2ui_enabled, - a2ui_catalog_provider=get_a2ui_catalog, - a2ui_examples_provider=get_a2ui_examples, - tools=tools, - ) - - runner = Runner( - app_name=agent.name, - agent=agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + base_url=base_url, ) - agent_executor = McpAppProxyAgentExecutor( base_url=base_url, - runner=runner, - schema_manager=schema_manager, + agent=agent, ) request_handler = DefaultRequestHandler( diff --git a/samples/agent/adk/mcp_app_proxy/agent.py b/samples/agent/adk/mcp_app_proxy/agent.py index 7ef04d594..9c8789d34 100644 --- a/samples/agent/adk/mcp_app_proxy/agent.py +++ b/samples/agent/adk/mcp_app_proxy/agent.py @@ -24,6 +24,8 @@ from google.adk.planners.built_in_planner import BuiltInPlanner from google.genai import types from pydantic import PrivateAttr +from .tools import get_calculator_app, calculate_via_mcp +from .a2ui_providers import get_a2ui_enabled, get_a2ui_catalog, get_a2ui_examples logger = logging.getLogger(__name__) @@ -47,56 +49,80 @@ """ -class McpAppProxyAgent(LlmAgent): +class McpAppProxyAgent: """An agent that proxies MCP Apps.""" SUPPORTED_CONTENT_TYPES: ClassVar[list[str]] = ["text", "text/plain"] - base_url: str = "" - schema_manager: A2uiSchemaManager = None - _a2ui_enabled_provider: A2uiEnabledProvider = PrivateAttr() - _a2ui_catalog_provider: A2uiCatalogProvider = PrivateAttr() - _a2ui_examples_provider: A2uiExamplesProvider = PrivateAttr() def __init__( self, - model: Any, base_url: str, - schema_manager: A2uiSchemaManager, - a2ui_enabled_provider: A2uiEnabledProvider, - a2ui_catalog_provider: A2uiCatalogProvider, - a2ui_examples_provider: A2uiExamplesProvider, - tools: list[Any], # tools passed in, including get_calculator_app + model: Any, ): - system_instructions = schema_manager.generate_system_prompt( - role_description=ROLE_DESCRIPTION, - workflow_description=WORKFLOW_DESCRIPTION, - ui_description=UI_DESCRIPTION, - include_schema=False, - include_examples=False, - validate_examples=False, - ) + self.base_url = base_url + self._model = model - super().__init__( - model=model, - name="mcp_app_proxy_agent", - description="An agent that provides access to MCP Apps.", - instruction=system_instructions, - tools=tools, - planner=BuiltInPlanner( - thinking_config=types.ThinkingConfig( - include_thoughts=True, - ) - ), - disallow_transfer_to_peers=True, - base_url=base_url, - schema_manager=schema_manager, + self._a2ui_enabled_provider = get_a2ui_enabled + self._a2ui_catalog_provider = get_a2ui_catalog + self._a2ui_examples_provider = get_a2ui_examples + + self._agent_name = "mcp_app_proxy_agent" + self._user_id = "remote_agent" + self._text_runner: Optional[Runner] = self._build_runner(self._build_llm_agent()) + + self._schema_managers: Dict[str, A2uiSchemaManager] = {} + self._ui_runners: Dict[str, Runner] = {} + + for version in [VERSION_0_8]: + schema_manager = self._build_schema_manager(version) + self._schema_managers[version] = schema_manager + agent = self._build_llm_agent(schema_manager) + self._ui_runners[version] = self._build_runner(agent) + + self._agent_card = self._build_agent_card() + + @property + def agent_card(self) -> AgentCard: + return self._agent_card + + def get_runner(self, version: Optional[str]) -> Runner: + if version is None: + return self._text_runner + return self._ui_runners[version] + + def get_schema_manager(self, version: Optional[str]) -> Optional[A2uiSchemaManager]: + if version is None: + return None + return self._schema_managers[version] + + def _build_schema_manager(self, version: str) -> A2uiSchemaManager: + return A2uiSchemaManager( + version=version, + catalogs=[ + CatalogConfig.from_path( + name="mcp_app_proxy", + catalog_path="mcp_app_catalog.json", + ), + ], + accepts_inline_catalogs=True, ) - self._a2ui_enabled_provider = a2ui_enabled_provider - self._a2ui_catalog_provider = a2ui_catalog_provider - self._a2ui_examples_provider = a2ui_examples_provider + def _build_agent_card(self) -> AgentCard: + extensions = [] + if self._schema_managers: + for version, sm in self._schema_managers.items(): + ext = get_a2ui_agent_extension( + version, + sm.accepts_inline_catalogs, + sm.supported_catalog_ids, + ) + extensions.append(ext) + + capabilities = AgentCapabilities( + streaming=True, + extensions=extensions, + ) - def get_agent_card(self) -> AgentCard: return AgentCard( name="MCP App Proxy Agent", description="Provides access to MCP Apps like Calculator.", @@ -104,15 +130,7 @@ def get_agent_card(self) -> AgentCard: version="1.0.0", default_input_modes=McpAppProxyAgent.SUPPORTED_CONTENT_TYPES, default_output_modes=McpAppProxyAgent.SUPPORTED_CONTENT_TYPES, - capabilities=AgentCapabilities( - streaming=True, - extensions=[ - get_a2ui_agent_extension( - self.schema_manager.accepts_inline_catalogs, - self.schema_manager.supported_catalog_ids, - ) - ], - ), + capabilities=capabilities, skills=[ AgentSkill( id="open_calculator", @@ -123,3 +141,45 @@ def get_agent_card(self) -> AgentCard: ) ], ) + + def _build_runner(self, agent: LlmAgent) -> Runner: + return Runner( + app_name=self._agent_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + + def _build_llm_agent( + self, schema_manager: Optional[A2uiSchemaManager] = None + ) -> LlmAgent: + """Builds the LLM agent for the contact agent.""" + LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") + + instruction = ( + schema_manager.generate_system_prompt( + role_description=ROLE_DESCRIPTION, + workflow_description=WORKFLOW_DESCRIPTION, + ui_description=UI_DESCRIPTION, + include_schema=False, + include_examples=False, + validate_examples=False, + ) + if schema_manager + else get_text_prompt() + ) + + return LlmAgent( + model=self._model, + name=self._agent_name, + description="An agent that provides access to MCP Apps.", + instruction=system_instructions, + tools=[get_calculator_app, calculate_via_mcp], + planner=BuiltInPlanner( + thinking_config=types.ThinkingConfig( + include_thoughts=True, + ) + ), + disallow_transfer_to_peers=True, + ) diff --git a/samples/agent/adk/mcp_app_proxy/agent_executor.py b/samples/agent/adk/mcp_app_proxy/agent_executor.py index 0302bf45a..16d5e9613 100644 --- a/samples/agent/adk/mcp_app_proxy/agent_executor.py +++ b/samples/agent/adk/mcp_app_proxy/agent_executor.py @@ -79,11 +79,10 @@ class McpAppProxyAgentExecutor(A2aAgentExecutor): def __init__( self, base_url: str, - runner: Runner, - schema_manager: A2uiSchemaManager, + agent: McpAppProxyAgent, ): self._base_url = base_url - self.schema_manager = schema_manager + self._agent = agent # Bypass tool check when converting tool responses to A2A parts to escape # the need to make the tool call to `SendA2uiJsonToClientTool`. By removing @@ -93,31 +92,39 @@ def __init__( config = A2aAgentExecutorConfig( event_converter=A2uiEventConverter(bypass_tool_check=True) ) - super().__init__(runner=runner, config=config) + # Use the text runner as the default runner. + super().__init__(runner=self._agent.get_runner(None), config=config) @override async def _prepare_session( self, context: RequestContext, run_request: AgentRunRequest, - runner: Runner, + _runner: Runner, ): logger.info(f"Loading session for message {context.message}") + active_ui_version = try_activate_a2ui_extension(context, self._agent.agent_card) + runner = self._agent.get_runner(active_ui_version) + schema_manager = self._agent.get_schema_manager(active_ui_version) + session = await super()._prepare_session(context, run_request, runner) if "base_url" not in session.state: session.state["base_url"] = self._base_url - use_ui = try_activate_a2ui_extension(context) - if use_ui: - capabilities = ( + if active_ui_version: + client_capabilities = ( context.message.metadata.get(A2UI_CLIENT_CAPABILITIES_KEY) if context.message and context.message.metadata else None ) - a2ui_catalog = self.schema_manager.get_selected_catalog( - client_ui_capabilities=capabilities + a2ui_catalog = ( + schema_manager.get_selected_catalog( + client_ui_capabilities=client_capabilities + ) + if schema_manager + else None ) # TODO: Load examples from files. diff --git a/samples/agent/adk/mcp_app_proxy/tools.py b/samples/agent/adk/mcp_app_proxy/tools.py new file mode 100644 index 000000000..ed392a496 --- /dev/null +++ b/samples/agent/adk/mcp_app_proxy/tools.py @@ -0,0 +1,100 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Define get_calculator_app tool in a way that the LlmAgent can use. +async def get_calculator_app(tool_context: ToolContext): + """Fetches the calculator app.""" + # Connect to the MCP server via SSE + mcp_server_host = os.getenv("MCP_SERVER_HOST", "localhost") + mcp_server_port = os.getenv("MCP_SERVER_PORT", "8000") + sse_url = f"http://{mcp_server_host}:{mcp_server_port}/sse" + + try: + async with sse_client(sse_url) as streams: + async with ClientSession(streams[0], streams[1]) as session: + await session.initialize() + + # Read the resource + result = await session.read_resource("ui://calculator/app") + + # Package the resource as an A2UI message + if result.contents and hasattr(result.contents[0], "text"): + html_content = result.contents[0].text + encoded_html = "url_encoded:" + urllib.parse.quote(html_content) + messages = [ + { + "beginRendering": { + "surfaceId": "calculator_surface", + "root": "calculator_app_root", + }, + }, + { + "surfaceUpdate": { + "surfaceId": "calculator_surface", + "components": [{ + "id": "calculator_app_root", + "component": { + "McpApp": { + "content": {"literalString": encoded_html}, + "title": {"literalString": "Calculator"}, + "allowedTools": ["calculate"], + } + }, + }], + }, + }, + ] + tool_context.actions.skip_summarization = True + return {"validated_a2ui_json": messages} + else: + logger.error("Failed to get text content from resource") + return {"error": "Could not fetch calculator app content."} + + except Exception as e: + logger.error(f"Error fetching calculator app: {e} {traceback.format_exc()}") + return {"error": f"Failed to connect to MCP server or fetch app. Details: {e}"} + + +async def calculate_via_mcp(operation: str, a: float, b: float): + """Calculates via the MCP server's Calculate tool. + + Args: + operation: The mathematical operation (e.g. 'add', 'subtract', 'multiply', 'divide'). + a: First operand. + b: Second operand. + """ + mcp_server_host = os.getenv("MCP_SERVER_HOST", "localhost") + mcp_server_port = os.getenv("MCP_SERVER_PORT", "8000") + sse_url = f"http://{mcp_server_host}:{mcp_server_port}/sse" + + try: + async with sse_client(sse_url) as streams: + async with ClientSession(streams[0], streams[1]) as session: + await session.initialize() + + result = await session.call_tool( + "calculate", arguments={"operation": operation, "a": a, "b": b} + ) + + if ( + result.content + and len(result.content) > 0 + and hasattr(result.content[0], "text") + ): + return result.content[0].text + return "No result text from MCP calculate tool." + except Exception as e: + logger.error(f"Error calling MCP calculate: {e} {traceback.format_exc()}") + return f"Error connecting to MCP server: {e}" diff --git a/samples/agent/adk/orchestrator/agent.py b/samples/agent/adk/orchestrator/agent.py index 383dbcdfb..80204824a 100644 --- a/samples/agent/adk/orchestrator/agent.py +++ b/samples/agent/adk/orchestrator/agent.py @@ -62,9 +62,12 @@ async def intercept( + json.dumps(request_payload) ) - if context and context.state and context.state.get("use_ui"): + if context and context.state and hasattr(context.state, "active_ui_version"): # Add A2UI extension header - http_kwargs["headers"] = {HTTP_EXTENSION_HEADER: A2UI_EXTENSION_URI} + a2ui_extension_uri = ( + f"{A2UI_EXTENSION_BASE_URI}:{context.state.active_ui_version}" + ) + http_kwargs["headers"] = {HTTP_EXTENSION_HEADER: a2ui_extension_uri} # Add A2UI client capabilities (supported catalogs, etc) to message metadata if (params := request_payload.get("params")) and ( @@ -153,6 +156,7 @@ async def build_agent( subagents = [] supported_catalog_ids = set() skills = [] + extensions = [] accepts_inline_catalogs = False for subagent_url in subagent_urls: async with httpx.AsyncClient() as httpx_client: @@ -163,14 +167,14 @@ async def build_agent( subagent_card = await resolver.get_agent_card() for extension in subagent_card.capabilities.extensions or []: - if extension.uri == A2UI_EXTENSION_URI and extension.params: + if extension.uri.startswith(A2UI_EXTENSION_BASE_URI) and extension.params: supported_catalog_ids.update( extension.params.get(AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY) or [] ) accepts_inline_catalogs |= bool( extension.params.get(AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY) ) - + extensions.append(extension) skills.extend(subagent_card.skills) logger.info( @@ -253,12 +257,7 @@ async def build_agent( default_output_modes=OrchestratorAgent.SUPPORTED_CONTENT_TYPES, capabilities=AgentCapabilities( streaming=True, - extensions=[ - get_a2ui_agent_extension( - accepts_inline_catalogs=accepts_inline_catalogs, - supported_catalog_ids=list(supported_catalog_ids), - ) - ], + extensions=extensions, ), skills=skills, ) diff --git a/samples/agent/adk/orchestrator/agent_executor.py b/samples/agent/adk/orchestrator/agent_executor.py index f2a3a29f2..cd673871f 100644 --- a/samples/agent/adk/orchestrator/agent_executor.py +++ b/samples/agent/adk/orchestrator/agent_executor.py @@ -31,7 +31,7 @@ A2aAgentExecutorConfig, A2aAgentExecutor, ) -from a2ui.a2a import is_a2ui_part, try_activate_a2ui_extension, A2UI_EXTENSION_URI +from a2ui.a2a import is_a2ui_part, try_activate_a2ui_extension from a2ui.core.schema.constants import A2UI_CLIENT_CAPABILITIES_KEY from google.adk.a2a.converters import event_converter from a2a.server.events import Event as A2AEvent @@ -134,7 +134,8 @@ async def _prepare_session( ): session = await super()._prepare_session(context, run_request, runner) - if try_activate_a2ui_extension(context): + active_ui_version = try_activate_a2ui_extension(context, runner.agent) + if active_ui_version: client_capabilities = ( context.message.metadata.get(A2UI_CLIENT_CAPABILITIES_KEY) if context.message and context.message.metadata @@ -149,7 +150,7 @@ async def _prepare_session( actions=EventActions( state_delta={ # These values are used to configure A2UI messages to remote agent calls - "use_ui": True, + "active_ui_version": active_ui_version, "client_capabilities": client_capabilities, } ), diff --git a/samples/agent/adk/restaurant_finder/__main__.py b/samples/agent/adk/restaurant_finder/__main__.py index f816cf6d8..530766148 100644 --- a/samples/agent/adk/restaurant_finder/__main__.py +++ b/samples/agent/adk/restaurant_finder/__main__.py @@ -50,17 +50,16 @@ def main(host, port): base_url = f"http://{host}:{port}" - ui_agent = RestaurantAgent(base_url=base_url, use_ui=True) - text_agent = RestaurantAgent(base_url=base_url, use_ui=False) + agent = RestaurantAgent(base_url=base_url) - agent_executor = RestaurantAgentExecutor(ui_agent, text_agent) + agent_executor = RestaurantAgentExecutor(agent) request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=InMemoryTaskStore(), ) server = A2AStarletteApplication( - agent_card=ui_agent.get_agent_card(), http_handler=request_handler + agent_card=agent.agent_card, http_handler=request_handler ) import uvicorn diff --git a/samples/agent/adk/restaurant_finder/agent.py b/samples/agent/adk/restaurant_finder/agent.py index e5239c5fd..b327a348c 100644 --- a/samples/agent/adk/restaurant_finder/agent.py +++ b/samples/agent/adk/restaurant_finder/agent.py @@ -16,7 +16,7 @@ import logging import os from collections.abc import AsyncIterable -from typing import Any +from typing import Any, Optional, Dict import jsonschema from a2a.types import ( @@ -55,39 +55,48 @@ class RestaurantAgent: SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] - def __init__(self, base_url: str, use_ui: bool = False): + def __init__(self, base_url: str): self.base_url = base_url - self.use_ui = use_ui - self._schema_manager = ( - A2uiSchemaManager( - VERSION_0_8, - catalogs=[ - BasicCatalog.get_config(version=VERSION_0_8, examples_path="examples") - ], - schema_modifiers=[remove_strict_validation], - ) - if use_ui - else None - ) - self._agent = self._build_agent(use_ui) + self._agent_name = "Restaurant Agent" self._user_id = "remote_agent" - self._runner = Runner( - app_name=self._agent.name, - agent=self._agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + self._text_runner: Optional[Runner] = self._build_runner(self._build_llm_agent()) + + self._schema_managers: Dict[str, A2uiSchemaManager] = {} + self._ui_runners: Dict[str, Runner] = {} + + for version in [VERSION_0_8]: + schema_manager = self._build_schema_manager(version) + self._schema_managers[version] = schema_manager + agent = self._build_llm_agent(schema_manager) + self._ui_runners[version] = self._build_runner(agent) + + self._agent_card = self._build_agent_card() + + @property + def agent_card(self) -> AgentCard: + return self._agent_card + + def _build_schema_manager(self, version: str) -> A2uiSchemaManager: + return A2uiSchemaManager( + version=version, + catalogs=[BasicCatalog.get_config(version=version, examples_path="examples")], + schema_modifiers=[remove_strict_validation], ) - def get_agent_card(self) -> AgentCard: + def _build_agent_card(self) -> AgentCard: + extensions = [] + if self._schema_managers: + for version, sm in self._schema_managers.items(): + ext = get_a2ui_agent_extension( + version, + sm.accepts_inline_catalogs, + sm.supported_catalog_ids, + ) + extensions.append(ext) + capabilities = AgentCapabilities( streaming=True, - extensions=[ - get_a2ui_agent_extension( - self._schema_manager.accepts_inline_catalogs, - self._schema_manager.supported_catalog_ids, - ) - ], + extensions=extensions, ) skill = AgentSkill( id="find_restaurants", @@ -110,22 +119,33 @@ def get_agent_card(self) -> AgentCard: skills=[skill], ) + def _build_runner(self, agent: LlmAgent) -> Runner: + return Runner( + app_name=self._agent_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + def get_processing_message(self) -> str: return "Finding restaurants that match your criteria..." - def _build_agent(self, use_ui: bool) -> LlmAgent: + def _build_llm_agent( + self, schema_manager: Optional[A2uiSchemaManager] = None + ) -> LlmAgent: """Builds the LLM agent for the restaurant agent.""" LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") instruction = ( - self._schema_manager.generate_system_prompt( + schema_manager.generate_system_prompt( role_description=ROLE_DESCRIPTION, ui_description=UI_DESCRIPTION, include_schema=True, include_examples=True, validate_examples=True, ) - if use_ui + if schema_manager else get_text_prompt() ) @@ -137,17 +157,31 @@ def _build_agent(self, use_ui: bool) -> LlmAgent: tools=[get_restaurants], ) - async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: + async def stream( + self, query, session_id, ui_version: Optional[str] = None + ) -> AsyncIterable[dict[str, Any]]: session_state = {"base_url": self.base_url} - session = await self._runner.session_service.get_session( - app_name=self._agent.name, + # Determine which runner to use based on whether the a2ui extension is active. + if ui_version: + runner = self._ui_runners[ui_version] + schema_manager = self._schema_managers[ui_version] + selected_catalog = ( + schema_manager.get_selected_catalog() if schema_manager else None + ) + else: + runner = self._text_runner + schema_manager = None + selected_catalog = None + + session = await runner.session_service.get_session( + app_name=self._agent_name, user_id=self._user_id, session_id=session_id, ) if session is None: - session = await self._runner.session_service.create_session( - app_name=self._agent.name, + session = await runner.session_service.create_session( + app_name=self._agent_name, user_id=self._user_id, state=session_state, session_id=session_id, @@ -161,8 +195,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: current_query_text = query # Ensure schema was loaded - selected_catalog = self._schema_manager.get_selected_catalog() - if self.use_ui and not selected_catalog.catalog_schema: + if ui_version and (not selected_catalog or not selected_catalog.catalog_schema): logger.error( "--- RestaurantAgent.stream: A2UI_SCHEMA is not loaded. " "Cannot perform UI validation. ---" @@ -194,7 +227,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: ) final_response_content = None - async for event in self._runner.run_async( + async for event in runner.run_async( user_id=self._user_id, session_id=session.id, new_message=current_message, @@ -235,7 +268,7 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]: is_valid = False error_message = "" - if self.use_ui: + if ui_version: logger.info( "--- RestaurantAgent.stream: Validating UI response (Attempt" f" {attempt})... ---" diff --git a/samples/agent/adk/restaurant_finder/agent_executor.py b/samples/agent/adk/restaurant_finder/agent_executor.py index f552b407d..c6e34d140 100644 --- a/samples/agent/adk/restaurant_finder/agent_executor.py +++ b/samples/agent/adk/restaurant_finder/agent_executor.py @@ -40,11 +40,8 @@ class RestaurantAgentExecutor(AgentExecutor): """Restaurant AgentExecutor Example.""" - def __init__(self, ui_agent: RestaurantAgent, text_agent: RestaurantAgent): - # Instantiate two agents: one for UI and one for text-only. - # The appropriate one will be chosen at execution time. - self.ui_agent = ui_agent - self.text_agent = text_agent + def __init__(self, agent: RestaurantAgent): + self._agent = agent async def execute( self, @@ -56,14 +53,12 @@ async def execute( action = None logger.info(f"--- Client requested extensions: {context.requested_extensions} ---") - use_ui = try_activate_a2ui_extension(context) + active_ui_version = try_activate_a2ui_extension(context, self._agent.agent_card) # Determine which agent to use based on whether the a2ui extension is active. - if use_ui: - agent = self.ui_agent + if active_ui_version: logger.info("--- AGENT_EXECUTOR: A2UI extension is active. Using UI agent. ---") else: - agent = self.text_agent logger.info( "--- AGENT_EXECUTOR: A2UI extension is not active. Using text agent. ---" ) @@ -126,7 +121,7 @@ async def execute( await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) - async for item in agent.stream(query, task.context_id): + async for item in self._agent.stream(query, task.context_id, active_ui_version): is_task_complete = item["is_task_complete"] if not is_task_complete: await updater.update_status( diff --git a/samples/agent/adk/restaurant_finder/examples/confirmation.json b/samples/agent/adk/restaurant_finder/examples/confirmation.json index d91dca3e0..eabc3fac1 100644 --- a/samples/agent/adk/restaurant_finder/examples/confirmation.json +++ b/samples/agent/adk/restaurant_finder/examples/confirmation.json @@ -21,6 +21,25 @@ } } }, + { + "id": "confirmation-column", + "component": { + "Column": { + "children": { + "explicitList": [ + "confirm-title", + "confirm-image", + "divider1", + "confirm-details", + "divider2", + "confirm-dietary", + "divider3", + "confirm-text" + ] + } + } + } + }, { "id": "confirm-title", "component": { diff --git a/samples/agent/adk/rizzcharts/__main__.py b/samples/agent/adk/rizzcharts/__main__.py index 9c988f102..6945df7d9 100644 --- a/samples/agent/adk/rizzcharts/__main__.py +++ b/samples/agent/adk/rizzcharts/__main__.py @@ -21,16 +21,9 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from a2ui.core.schema.constants import VERSION_0_8 -from a2ui.core.schema.manager import A2uiSchemaManager, CatalogConfig -from a2ui.basic_catalog.provider import BasicCatalog -from agent_executor import RizzchartsAgentExecutor, get_a2ui_enabled, get_a2ui_catalog, get_a2ui_examples +from agent_executor import RizzchartsAgentExecutor from agent import RizzchartsAgent -from google.adk.artifacts import InMemoryArtifactService -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.lite_llm import LiteLlm -from google.adk.runners import Runner -from google.adk.sessions.in_memory_session_service import InMemorySessionService from dotenv import load_dotenv from starlette.middleware.cors import CORSMiddleware @@ -61,42 +54,14 @@ def main(host, port): base_url = f"http://{host}:{port}" - schema_manager = A2uiSchemaManager( - VERSION_0_8, - catalogs=[ - CatalogConfig.from_path( - name="rizzcharts", - catalog_path="rizzcharts_catalog_definition.json", - examples_path="examples/rizzcharts_catalog", - ), - BasicCatalog.get_config( - version=VERSION_0_8, - examples_path="examples/standard_catalog", - ), - ], - accepts_inline_catalogs=True, - ) - agent = RizzchartsAgent( base_url=base_url, model=LiteLlm(model=lite_llm_model), - schema_manager=schema_manager, - a2ui_enabled_provider=get_a2ui_enabled, - a2ui_catalog_provider=get_a2ui_catalog, - a2ui_examples_provider=get_a2ui_examples, - ) - runner = Runner( - app_name=agent.name, - agent=agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), ) agent_executor = RizzchartsAgentExecutor( base_url=base_url, - runner=runner, - schema_manager=schema_manager, + agent=agent, ) request_handler = DefaultRequestHandler( @@ -104,7 +69,7 @@ def main(host, port): task_store=InMemoryTaskStore(), ) server = A2AStarletteApplication( - agent_card=agent.get_agent_card(), http_handler=request_handler + agent_card=agent.agent_card, http_handler=request_handler ) import uvicorn diff --git a/samples/agent/adk/rizzcharts/agent.py b/samples/agent/adk/rizzcharts/agent.py index 850b6cc12..44e8a070d 100644 --- a/samples/agent/adk/rizzcharts/agent.py +++ b/samples/agent/adk/rizzcharts/agent.py @@ -16,19 +16,26 @@ import logging from pathlib import Path import pkgutil -from typing import Any, ClassVar +from typing import Any, ClassVar, Dict, Optional from a2a.types import AgentCapabilities, AgentCard, AgentSkill from a2ui.a2a import get_a2ui_agent_extension from a2ui.adk.a2a_extension.send_a2ui_to_client_toolset import SendA2uiToClientToolset, A2uiEnabledProvider, A2uiCatalogProvider, A2uiExamplesProvider -from a2ui.core.schema.manager import A2uiSchemaManager +from a2ui.core.schema.manager import A2uiSchemaManager, CatalogConfig +from a2ui.basic_catalog.provider import BasicCatalog +from a2ui.core.schema.constants import VERSION_0_8 from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext from google.adk.planners.built_in_planner import BuiltInPlanner from google.genai import types from pydantic import PrivateAttr +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from agent_executor import get_a2ui_enabled, get_a2ui_catalog, get_a2ui_examples +from google.adk.runners import Runner try: - from .tools import get_sales_data, get_store_sales + from tools import get_sales_data, get_store_sales except ImportError: from tools import get_sales_data, get_store_sales @@ -82,78 +89,99 @@ """ -class RizzchartsAgent(LlmAgent): +class RizzchartsAgent: """An agent that runs an ecommerce dashboard""" SUPPORTED_CONTENT_TYPES: ClassVar[list[str]] = ["text", "text/plain"] - base_url: str = "" - schema_manager: A2uiSchemaManager = None - _a2ui_enabled_provider: A2uiEnabledProvider = PrivateAttr() - _a2ui_catalog_provider: A2uiCatalogProvider = PrivateAttr() - _a2ui_examples_provider: A2uiExamplesProvider = PrivateAttr() def __init__( self, - model: Any, base_url: str, - schema_manager: A2uiSchemaManager, - a2ui_enabled_provider: A2uiEnabledProvider, - a2ui_catalog_provider: A2uiCatalogProvider, - a2ui_examples_provider: A2uiExamplesProvider, + model: Any, ): - """Initializes the RizzchartsAgent. - - Args: - model: The LLM model to use. - base_url: The base URL for the agent. - schema_manager: The A2UI schema manager. - a2ui_enabled_provider: A provider to check if A2UI is enabled. - a2ui_catalog_provider: A provider to retrieve the A2UI catalog (A2uiCatalog object). - a2ui_examples_provider: A provider to retrieve the A2UI examples (str). - """ + self.base_url = base_url + self._model = model - system_instructions = schema_manager.generate_system_prompt( - role_description=ROLE_DESCRIPTION, - workflow_description=WORKFLOW_DESCRIPTION, - ui_description=UI_DESCRIPTION, - include_schema=False, - include_examples=False, - validate_examples=False, - ) - super().__init__( - model=model, - name="rizzcharts_agent", - description="An agent that lets sales managers request sales data.", - instruction=system_instructions, - tools=[ - get_store_sales, - get_sales_data, - SendA2uiToClientToolset( - a2ui_catalog=a2ui_catalog_provider, - a2ui_enabled=a2ui_enabled_provider, - a2ui_examples=a2ui_examples_provider, + self._a2ui_enabled_provider = get_a2ui_enabled + self._a2ui_catalog_provider = get_a2ui_catalog + self._a2ui_examples_provider = get_a2ui_examples + + self._agent_name = "mcp_app_proxy_agent" + self._user_id = "remote_agent" + + self._session_service = InMemorySessionService() + self._memory_service = InMemoryMemoryService() + self._artifact_service = InMemoryArtifactService() + + self._text_runner: Optional[Runner] = self._build_runner(self._build_llm_agent()) + + self._schema_managers: Dict[str, A2uiSchemaManager] = {} + self._ui_runners: Dict[str, Runner] = {} + + for version in [VERSION_0_8]: + schema_manager = self._build_schema_manager(version) + self._schema_managers[version] = schema_manager + agent = self._build_llm_agent(schema_manager) + self._ui_runners[version] = self._build_runner(agent) + + self._agent_card = self._build_agent_card() + + @property + def agent_card(self) -> AgentCard: + return self._agent_card + + def get_runner(self, version: Optional[str]) -> Runner: + if version is None: + return self._text_runner + return self._ui_runners[version] + + def get_schema_manager(self, version: Optional[str]) -> Optional[A2uiSchemaManager]: + if version is None: + return None + return self._schema_managers[version] + + def _build_schema_manager(self, version: str) -> A2uiSchemaManager: + return A2uiSchemaManager( + version=version, + catalogs=[ + CatalogConfig.from_path( + name="rizzcharts", + catalog_path="rizzcharts_catalog_definition.json", + examples_path="examples/rizzcharts_catalog", + ), + BasicCatalog.get_config( + version=VERSION_0_8, + examples_path="examples/standard_catalog", ), ], - planner=BuiltInPlanner( - thinking_config=types.ThinkingConfig( - include_thoughts=True, - ) - ), - disallow_transfer_to_peers=True, - base_url=base_url, - schema_manager=schema_manager, + accepts_inline_catalogs=True, ) self._a2ui_enabled_provider = a2ui_enabled_provider self._a2ui_catalog_provider = a2ui_catalog_provider self._a2ui_examples_provider = a2ui_examples_provider - def get_agent_card(self) -> AgentCard: + def _build_agent_card(self) -> AgentCard: """Returns the AgentCard defining this agent's metadata and skills. Returns: An AgentCard object. """ + extensions = [] + if self._schema_managers: + for version, sm in self._schema_managers.items(): + ext = get_a2ui_agent_extension( + version, + sm.accepts_inline_catalogs, + sm.supported_catalog_ids, + ) + extensions.append(ext) + + capabilities = AgentCapabilities( + streaming=True, + extensions=extensions, + ) + return AgentCard( name="Ecommerce Dashboard Agent", description=( @@ -164,15 +192,7 @@ def get_agent_card(self) -> AgentCard: version="1.0.0", default_input_modes=RizzchartsAgent.SUPPORTED_CONTENT_TYPES, default_output_modes=RizzchartsAgent.SUPPORTED_CONTENT_TYPES, - capabilities=AgentCapabilities( - streaming=True, - extensions=[ - get_a2ui_agent_extension( - self.schema_manager.accepts_inline_catalogs, - self.schema_manager.supported_catalog_ids, - ) - ], - ), + capabilities=capabilities, skills=[ AgentSkill( id="view_sales_by_category", @@ -202,3 +222,51 @@ def get_agent_card(self) -> AgentCard: ), ], ) + + def _build_runner(self, agent: LlmAgent) -> Runner: + return Runner( + app_name=self._agent_name, + agent=agent, + artifact_service=self._artifact_service, + session_service=self._session_service, + memory_service=self._memory_service, + ) + + def _build_llm_agent( + self, schema_manager: Optional[A2uiSchemaManager] = None + ) -> LlmAgent: + """Builds the LLM agent for the contact agent.""" + instruction = ( + schema_manager.generate_system_prompt( + role_description=ROLE_DESCRIPTION, + workflow_description=WORKFLOW_DESCRIPTION, + ui_description=UI_DESCRIPTION, + include_schema=False, + include_examples=False, + validate_examples=False, + ) + if schema_manager + else "" + ) + + return LlmAgent( + model=self._model, + name=self._agent_name, + description="An agent that lets sales managers request sales data.", + instruction=instruction, + tools=[ + get_store_sales, + get_sales_data, + SendA2uiToClientToolset( + a2ui_catalog=self._a2ui_catalog_provider, + a2ui_enabled=self._a2ui_enabled_provider, + a2ui_examples=self._a2ui_examples_provider, + ), + ], + planner=BuiltInPlanner( + thinking_config=types.ThinkingConfig( + include_thoughts=True, + ) + ), + disallow_transfer_to_peers=True, + ) diff --git a/samples/agent/adk/rizzcharts/agent_executor.py b/samples/agent/adk/rizzcharts/agent_executor.py index 5892eea94..eaf09983e 100644 --- a/samples/agent/adk/rizzcharts/agent_executor.py +++ b/samples/agent/adk/rizzcharts/agent_executor.py @@ -79,47 +79,59 @@ def get_a2ui_enabled(ctx: ReadonlyContext): return ctx.state.get(_A2UI_ENABLED_KEY, False) +from agent import RizzchartsAgent + + class RizzchartsAgentExecutor(A2aAgentExecutor): """Executor for the Rizzcharts agent that handles A2UI session setup.""" def __init__( self, base_url: str, - runner: Runner, - schema_manager: A2uiSchemaManager, + agent: RizzchartsAgent, ): self._base_url = base_url - self.schema_manager = schema_manager + self._agent = agent config = A2aAgentExecutorConfig(event_converter=A2uiEventConverter()) - super().__init__(runner=runner, config=config) + # Use the text runner as the default runner. + super().__init__(runner=self._agent.get_runner(None), config=config) @override async def _prepare_session( self, context: RequestContext, run_request: AgentRunRequest, - runner: Runner, + _runner: Runner, ): logger.info(f"Loading session for message {context.message}") + active_ui_version = try_activate_a2ui_extension(context, self._agent.agent_card) + runner = self._agent.get_runner(active_ui_version) + schema_manager = self._agent.get_schema_manager(active_ui_version) + session = await super()._prepare_session(context, run_request, runner) if "base_url" not in session.state: session.state["base_url"] = self._base_url - use_ui = try_activate_a2ui_extension(context) - if use_ui: + if active_ui_version: capabilities = ( context.message.metadata.get(A2UI_CLIENT_CAPABILITIES_KEY) if context.message and context.message.metadata else None ) - a2ui_catalog = self.schema_manager.get_selected_catalog( - client_ui_capabilities=capabilities + a2ui_catalog = ( + schema_manager.get_selected_catalog(client_ui_capabilities=capabilities) + if schema_manager + else None ) - examples = self.schema_manager.load_examples(a2ui_catalog, validate=True) + examples = ( + schema_manager.load_examples(a2ui_catalog, validate=True) + if schema_manager + else None + ) await runner.session_service.append_event( session, diff --git a/samples/agent/adk/rizzcharts/examples/standard_catalog/map.json b/samples/agent/adk/rizzcharts/examples/standard_catalog/map.json index f38396c00..a0b4300f7 100644 --- a/samples/agent/adk/rizzcharts/examples/standard_catalog/map.json +++ b/samples/agent/adk/rizzcharts/examples/standard_catalog/map.json @@ -39,6 +39,16 @@ } } }, + { + "id": "map-image", + "component": { + "Image": { + "url": { + "literalString": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d7/Los_Angeles_from_Mount_Hollywood.jpg/1280px-Los_Angeles_from_Mount_Hollywood.jpg" + } + } + } + }, { "id": "location-list", "component": {