diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index af49629ff3..987828c21f 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -33,126 +33,135 @@ from .llm_agent import LlmAgent from .sequential_agent_config import SequentialAgentConfig -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) @experimental class SequentialAgentState(BaseAgentState): - """State for SequentialAgent.""" + """State for SequentialAgent.""" - current_sub_agent: str = '' - """The name of the current sub-agent to run.""" + current_sub_agent: str = "" + """The name of the current sub-agent to run.""" class SequentialAgent(BaseAgent): - """A shell agent that runs its sub-agents in sequence.""" - - config_type: ClassVar[Type[BaseAgentConfig]] = SequentialAgentConfig - """The config type for this agent.""" - - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - if not self.sub_agents: - return - - # Initialize or resume the execution state from the agent state. - agent_state = self._load_agent_state(ctx, SequentialAgentState) - start_index = self._get_start_index(agent_state) - - pause_invocation = False - resuming_sub_agent = agent_state is not None - for i in range(start_index, len(self.sub_agents)): - sub_agent = self.sub_agents[i] - if not resuming_sub_agent: - # If we are resuming from the current event, it means the same event has - # already been logged, so we should avoid yielding it again. + """A shell agent that runs its sub-agents in sequence.""" + + config_type: ClassVar[Type[BaseAgentConfig]] = SequentialAgentConfig + """The config type for this agent.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + if not self.sub_agents: + return + + # Initialize or resume the execution state from the agent state. + agent_state = self._load_agent_state(ctx, SequentialAgentState) + start_index = self._get_start_index(agent_state) + + pause_invocation = False + resuming_sub_agent = agent_state is not None + for i in range(start_index, len(self.sub_agents)): + sub_agent = self.sub_agents[i] + if not resuming_sub_agent: + # If we are resuming from the current event, it means the same event has + # already been logged, so we should avoid yielding it again. + if ctx.is_resumable: + agent_state = SequentialAgentState(current_sub_agent=sub_agent.name) + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + + async with Aclosing(sub_agent.run_async(ctx)) as agen: + async for event in agen: + yield event + if ctx.should_pause_invocation(event): + pause_invocation = True + # Check for escalate action to enable early exit from the sequence. + # When escalate is set, we terminate immediately, stopping both + # subsequent events in the current agent and all remaining agents. + # Note: escalate takes precedence over pause_invocation. + if event.actions and event.actions.escalate: + return + + # Skip the rest of the sub-agents if the invocation is paused. + if pause_invocation: + return + + # Reset the flag for the next sub-agent. + resuming_sub_agent = False + if ctx.is_resumable: - agent_state = SequentialAgentState(current_sub_agent=sub_agent.name) - ctx.set_agent_state(self.name, agent_state=agent_state) - yield self._create_agent_state_event(ctx) - - async with Aclosing(sub_agent.run_async(ctx)) as agen: - async for event in agen: - yield event - if ctx.should_pause_invocation(event): - pause_invocation = True - - # Skip the rest of the sub-agents if the invocation is paused. - if pause_invocation: - return - - # Reset the flag for the next sub-agent. - resuming_sub_agent = False - - if ctx.is_resumable: - ctx.set_agent_state(self.name, end_of_agent=True) - yield self._create_agent_state_event(ctx) - - def _get_start_index( - self, - agent_state: SequentialAgentState, - ) -> int: - """Calculates the start index for the sub-agent loop.""" - if not agent_state: - return 0 - - if not agent_state.current_sub_agent: - # This means the process was finished. - return len(self.sub_agents) - - try: - sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents] - return sub_agent_names.index(agent_state.current_sub_agent) - except ValueError: - # A sub-agent was removed so the agent name is not found. - # For now, we restart from the beginning. - logger.warning( - 'Sub-agent %s was removed so the agent name is not found. Restarting' - ' from the beginning.', - agent_state.current_sub_agent, - ) - return 0 - - @override - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Implementation for live SequentialAgent. - - Compared to the non-live case, live agents process a continuous stream of audio - or video, so there is no way to tell if it's finished and should pass - to the next agent or not. So we introduce a task_completed() function so the - model can call this function to signal that it's finished the task and we - can move on to the next agent. - - Args: - ctx: The invocation context of the agent. - """ - if not self.sub_agents: - return - - # There is no way to know if it's using live during init phase so we have to init it here - for sub_agent in self.sub_agents: - # add tool - def task_completed(): - """ - Signals that the agent has successfully completed the user's question - or task. + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + + def _get_start_index( + self, + agent_state: SequentialAgentState, + ) -> int: + """Calculates the start index for the sub-agent loop.""" + if not agent_state: + return 0 + + if not agent_state.current_sub_agent: + # This means the process was finished. + return len(self.sub_agents) + + try: + sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents] + return sub_agent_names.index(agent_state.current_sub_agent) + except ValueError: + # A sub-agent was removed so the agent name is not found. + # For now, we restart from the beginning. + logger.warning( + "Sub-agent %s was removed so the agent name is not found. Restarting" + " from the beginning.", + agent_state.current_sub_agent, + ) + return 0 + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Implementation for live SequentialAgent. + + Compared to the non-live case, live agents process a continuous stream of audio + or video, so there is no way to tell if it's finished and should pass + to the next agent or not. So we introduce a task_completed() function so the + model can call this function to signal that it's finished the task and we + can move on to the next agent. + + Args: + ctx: The invocation context of the agent. """ - return 'Task completion signaled.' - - if isinstance(sub_agent, LlmAgent): - # Use function name to dedupe. - if task_completed.__name__ not in sub_agent.tools: - sub_agent.tools.append(task_completed) - sub_agent.instruction += f"""If you finished the user's request + if not self.sub_agents: + return + + # There is no way to know if it's using live during init phase so we have to init it here + for sub_agent in self.sub_agents: + # add tool + def task_completed(): + """ + Signals that the agent has successfully completed the user's question + or task. + """ + return "Task completion signaled." + + if isinstance(sub_agent, LlmAgent): + # Use function name to dedupe. + if task_completed.__name__ not in sub_agent.tools: + sub_agent.tools.append(task_completed) + sub_agent.instruction += f"""If you finished the user's request according to its description, call the {task_completed.__name__} function to exit so the next agents can take over. When calling this function, do not generate any text other than the function call.""" - for sub_agent in self.sub_agents: - async with Aclosing(sub_agent.run_live(ctx)) as agen: - async for event in agen: - yield event + for sub_agent in self.sub_agents: + async with Aclosing(sub_agent.run_live(ctx)) as agen: + async for event in agen: + yield event + # Check for escalate action to enable early exit in live mode. + if event.actions and event.actions.escalate: + return diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index f5250d0a17..c620ccb479 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -22,6 +22,7 @@ from .enterprise_search_tool import enterprise_web_search_tool as enterprise_web_search from .example_tool import ExampleTool from .exit_loop_tool import exit_loop +from .exit_sequence_tool import exit_sequence from .function_tool import FunctionTool from .get_user_choice_tool import get_user_choice_tool as get_user_choice from .google_maps_grounding_tool import google_maps_grounding @@ -48,6 +49,7 @@ 'VertexAiSearchTool', 'ExampleTool', 'exit_loop', + 'exit_sequence', 'FunctionTool', 'get_user_choice', 'load_artifacts', diff --git a/src/google/adk/tools/exit_sequence_tool.py b/src/google/adk/tools/exit_sequence_tool.py new file mode 100644 index 0000000000..ea7e06d3f7 --- /dev/null +++ b/src/google/adk/tools/exit_sequence_tool.py @@ -0,0 +1,48 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .tool_context import ToolContext + + +def exit_sequence(tool_context: ToolContext): + """Exits the sequential execution of agents immediately. + + Call this function when you encounter a terminal condition and want to + prevent subsequent agents in the sequence from executing. This will also + stop any remaining events from the current agent. + + This tool is specifically designed for use within SequentialAgent contexts. + When called, it sets the escalate flag, which causes the SequentialAgent + to terminate the sequence immediately, preventing both: + - Subsequent events from the current sub-agent + - All remaining sub-agents in the sequence + + Use cases: + - A blocking error is encountered that makes further processing impossible + - A definitive answer is found early, making subsequent agents unnecessary + - A security or validation check fails and the workflow must stop + - Resource limits are reached and safe termination is required + + Example: + If you're in a sequence of [validator, processor, finalizer] agents, + and the validator finds invalid data, it can call exit_sequence() to + prevent the processor and finalizer from running on bad data. + + Args: + tool_context: The context of the current tool invocation. + """ + tool_context.actions.escalate = True + tool_context.actions.skip_summarization = True diff --git a/tests/unittests/agents/test_sequential_agent.py b/tests/unittests/agents/test_sequential_agent.py index 9703e0ca29..f6aeffd672 100644 --- a/tests/unittests/agents/test_sequential_agent.py +++ b/tests/unittests/agents/test_sequential_agent.py @@ -22,6 +22,7 @@ from google.adk.agents.sequential_agent import SequentialAgentState from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.genai import types import pytest @@ -29,175 +30,317 @@ class _TestingAgent(BaseAgent): + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f"Hello, async {self.name}!")] + ), + ) - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - yield Event( - author=self.name, - invocation_id=ctx.invocation_id, - content=types.Content( - parts=[types.Part(text=f'Hello, async {self.name}!')] - ), - ) + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text=f"Hello, live {self.name}!")]), + ) - @override - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - yield Event( - author=self.name, - invocation_id=ctx.invocation_id, - content=types.Content( - parts=[types.Part(text=f'Hello, live {self.name}!')] - ), - ) + +class _TestingAgentWithEscalateAction(BaseAgent): + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f"Hello, async {self.name}!")] + ), + actions=EventActions(escalate=True), + ) + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f"I should not be seen after escalation!")] + ), + ) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text=f"Hello, live {self.name}!")]), + actions=EventActions(escalate=True), + ) async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, resumable: bool = False ) -> InvocationContext: - session_service = InMemorySessionService() - session = await session_service.create_session( - app_name='test_app', user_id='test_user' - ) - return InvocationContext( - invocation_id=f'{test_name}_invocation_id', - agent=agent, - session=session, - session_service=session_service, - resumability_config=ResumabilityConfig(is_resumable=resumable), - ) + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app", user_id="test_user" + ) + return InvocationContext( + invocation_id=f"{test_name}_invocation_id", + agent=agent, + session=session, + session_service=session_service, + resumability_config=ResumabilityConfig(is_resumable=resumable), + ) @pytest.mark.asyncio async def test_run_async(request: pytest.FixtureRequest): - agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') - agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') - sequential_agent = SequentialAgent( - name=f'{request.function.__name__}_test_agent', - sub_agents=[ - agent_1, - agent_2, - ], - ) - parent_ctx = await _create_parent_invocation_context( - request.function.__name__, sequential_agent - ) - events = [e async for e in sequential_agent.run_async(parent_ctx)] - - assert len(events) == 2 - assert events[0].author == agent_1.name - assert events[1].author == agent_2.name - assert events[0].content.parts[0].text == f'Hello, async {agent_1.name}!' - assert events[1].content.parts[0].text == f'Hello, async {agent_2.name}!' + agent_1 = _TestingAgent(name=f"{request.function.__name__}_test_agent_1") + agent_2 = _TestingAgent(name=f"{request.function.__name__}_test_agent_2") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + assert len(events) == 2 + assert events[0].author == agent_1.name + assert events[1].author == agent_2.name + assert events[0].content.parts[0].text == f"Hello, async {agent_1.name}!" + assert events[1].content.parts[0].text == f"Hello, async {agent_2.name}!" @pytest.mark.asyncio async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest): - sequential_agent = SequentialAgent( - name=f'{request.function.__name__}_test_agent', - sub_agents=[], - ) - parent_ctx = await _create_parent_invocation_context( - request.function.__name__, sequential_agent - ) - events = [e async for e in sequential_agent.run_async(parent_ctx)] + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] - assert not events + assert not events @pytest.mark.asyncio async def test_run_async_with_resumability(request: pytest.FixtureRequest): - agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') - agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') - sequential_agent = SequentialAgent( - name=f'{request.function.__name__}_test_agent', - sub_agents=[ - agent_1, - agent_2, - ], - ) - parent_ctx = await _create_parent_invocation_context( - request.function.__name__, sequential_agent, resumable=True - ) - events = [e async for e in sequential_agent.run_async(parent_ctx)] - - # 5 events: - # 1. SequentialAgent checkpoint event for agent 1 - # 2. Agent 1 event - # 3. SequentialAgent checkpoint event for agent 2 - # 4. Agent 2 event - # 5. SequentialAgent final checkpoint event - assert len(events) == 5 - assert events[0].author == sequential_agent.name - assert not events[0].actions.end_of_agent - assert events[0].actions.agent_state['current_sub_agent'] == agent_1.name - - assert events[1].author == agent_1.name - assert events[1].content.parts[0].text == f'Hello, async {agent_1.name}!' - - assert events[2].author == sequential_agent.name - assert not events[2].actions.end_of_agent - assert events[2].actions.agent_state['current_sub_agent'] == agent_2.name - - assert events[3].author == agent_2.name - assert events[3].content.parts[0].text == f'Hello, async {agent_2.name}!' - - assert events[4].author == sequential_agent.name - assert events[4].actions.end_of_agent + agent_1 = _TestingAgent(name=f"{request.function.__name__}_test_agent_1") + agent_2 = _TestingAgent(name=f"{request.function.__name__}_test_agent_2") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent, resumable=True + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # 5 events: + # 1. SequentialAgent checkpoint event for agent 1 + # 2. Agent 1 event + # 3. SequentialAgent checkpoint event for agent 2 + # 4. Agent 2 event + # 5. SequentialAgent final checkpoint event + assert len(events) == 5 + assert events[0].author == sequential_agent.name + assert not events[0].actions.end_of_agent + assert events[0].actions.agent_state["current_sub_agent"] == agent_1.name + + assert events[1].author == agent_1.name + assert events[1].content.parts[0].text == f"Hello, async {agent_1.name}!" + + assert events[2].author == sequential_agent.name + assert not events[2].actions.end_of_agent + assert events[2].actions.agent_state["current_sub_agent"] == agent_2.name + + assert events[3].author == agent_2.name + assert events[3].content.parts[0].text == f"Hello, async {agent_2.name}!" + + assert events[4].author == sequential_agent.name + assert events[4].actions.end_of_agent @pytest.mark.asyncio async def test_resume_async(request: pytest.FixtureRequest): - agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') - agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') - sequential_agent = SequentialAgent( - name=f'{request.function.__name__}_test_agent', - sub_agents=[ - agent_1, - agent_2, - ], - ) - parent_ctx = await _create_parent_invocation_context( - request.function.__name__, sequential_agent, resumable=True - ) - parent_ctx.agent_states[sequential_agent.name] = SequentialAgentState( - current_sub_agent=agent_2.name - ).model_dump(mode='json') - - events = [e async for e in sequential_agent.run_async(parent_ctx)] - - # 2 events: - # 1. Agent 2 event - # 2. SequentialAgent final checkpoint event - assert len(events) == 2 - assert events[0].author == agent_2.name - assert events[0].content.parts[0].text == f'Hello, async {agent_2.name}!' - - assert events[1].author == sequential_agent.name - assert events[1].actions.end_of_agent + agent_1 = _TestingAgent(name=f"{request.function.__name__}_test_agent_1") + agent_2 = _TestingAgent(name=f"{request.function.__name__}_test_agent_2") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent, resumable=True + ) + parent_ctx.agent_states[sequential_agent.name] = SequentialAgentState( + current_sub_agent=agent_2.name + ).model_dump(mode="json") + + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # 2 events: + # 1. Agent 2 event + # 2. SequentialAgent final checkpoint event + assert len(events) == 2 + assert events[0].author == agent_2.name + assert events[0].content.parts[0].text == f"Hello, async {agent_2.name}!" + + assert events[1].author == sequential_agent.name + assert events[1].actions.end_of_agent @pytest.mark.asyncio async def test_run_live(request: pytest.FixtureRequest): - agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') - agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') - sequential_agent = SequentialAgent( - name=f'{request.function.__name__}_test_agent', - sub_agents=[ - agent_1, - agent_2, - ], - ) - parent_ctx = await _create_parent_invocation_context( - request.function.__name__, sequential_agent - ) - events = [e async for e in sequential_agent.run_live(parent_ctx)] - - assert len(events) == 2 - assert events[0].author == agent_1.name - assert events[1].author == agent_2.name - assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!' - assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!' + agent_1 = _TestingAgent(name=f"{request.function.__name__}_test_agent_1") + agent_2 = _TestingAgent(name=f"{request.function.__name__}_test_agent_2") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_live(parent_ctx)] + + assert len(events) == 2 + assert events[0].author == agent_1.name + assert events[1].author == agent_2.name + assert events[0].content.parts[0].text == f"Hello, live {agent_1.name}!" + assert events[1].content.parts[0].text == f"Hello, live {agent_2.name}!" + + +@pytest.mark.asyncio +async def test_run_async_with_escalate_action(request: pytest.FixtureRequest): + """Test that SequentialAgent exits early when escalate action is triggered.""" + escalating_agent = _TestingAgentWithEscalateAction( + name=f"{request.function.__name__}_escalating_agent" + ) + normal_agent = _TestingAgent(name=f"{request.function.__name__}_normal_agent") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + escalating_agent, + normal_agent, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # Should only have 1 event from the escalating agent, normal agent should not run + assert len(events) == 1 + assert events[0].author == escalating_agent.name + assert events[0].content.parts[0].text == f"Hello, async {escalating_agent.name}!" + assert events[0].actions.escalate is True + + +@pytest.mark.asyncio +async def test_run_async_escalate_action_in_middle( + request: pytest.FixtureRequest, +): + """Test that SequentialAgent exits when escalation happens in middle of sequence.""" + first_agent = _TestingAgent(name=f"{request.function.__name__}_first_agent") + escalating_agent = _TestingAgentWithEscalateAction( + name=f"{request.function.__name__}_escalating_agent" + ) + third_agent = _TestingAgent(name=f"{request.function.__name__}_third_agent") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + first_agent, + escalating_agent, + third_agent, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # Should have 2 events: one from first agent, one from escalating agent + assert len(events) == 2 + assert events[0].author == first_agent.name + assert events[1].author == escalating_agent.name + assert events[1].actions.escalate is True + + # Verify third agent did not run + third_agent_events = [e for e in events if e.author == third_agent.name] + assert len(third_agent_events) == 0 + + +@pytest.mark.asyncio +async def test_run_async_no_escalate_action(request: pytest.FixtureRequest): + """Test that SequentialAgent continues normally when no escalate action.""" + agent_1 = _TestingAgent(name=f"{request.function.__name__}_test_agent_1") + agent_2 = _TestingAgent(name=f"{request.function.__name__}_test_agent_2") + agent_3 = _TestingAgent(name=f"{request.function.__name__}_test_agent_3") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + agent_1, + agent_2, + agent_3, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # All agents should execute + assert len(events) == 3 + assert events[0].author == agent_1.name + assert events[1].author == agent_2.name + assert events[2].author == agent_3.name + + +@pytest.mark.asyncio +async def test_run_live_with_escalate_action(request: pytest.FixtureRequest): + """Test that SequentialAgent exits early in live mode when escalate is triggered.""" + escalating_agent = _TestingAgentWithEscalateAction( + name=f"{request.function.__name__}_escalating_agent" + ) + normal_agent = _TestingAgent(name=f"{request.function.__name__}_normal_agent") + sequential_agent = SequentialAgent( + name=f"{request.function.__name__}_test_agent", + sub_agents=[ + escalating_agent, + normal_agent, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_live(parent_ctx)] + + # Should only have 1 event from the escalating agent, normal agent should not run + assert len(events) == 1 + assert events[0].author == escalating_agent.name + assert events[0].content.parts[0].text == f"Hello, live {escalating_agent.name}!" + assert events[0].actions.escalate is True