diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py
index 91b03ebb..14e20bf3 100644
--- a/src/fast_agent/agents/mcp_agent.py
+++ b/src/fast_agent/agents/mcp_agent.py
@@ -260,32 +260,6 @@ async def __call__(
     ) -> str:
         return await self.send(message)
 
-    # async def send(
-    #     self,
-    #     message: Union[
-    #         str,
-    #         PromptMessage,
-    #         PromptMessageExtended,
-    #         Sequence[Union[str, PromptMessage, PromptMessageExtended]],
-    #     ],
-    #     request_params: RequestParams | None = None,
-    # ) -> str:
-    #     """
-    #     Send a message to the agent and get a response.
-
-    #     Args:
-    #         message: Message content in various formats:
-    #             - String: Converted to a user PromptMessageExtended
-    #             - PromptMessage: Converted to PromptMessageExtended
-    #             - PromptMessageExtended: Used directly
-    #         request_params: Optional request parameters
-
-    #     Returns:
-    #         The agent's response as a string
-    #     """
-    #     response = await self.generate(message, request_params)
-    #     return response.last_text() or ""
-
     def _matches_pattern(self, name: str, pattern: str, server_name: str) -> bool:
         """
         Check if a name matches a pattern for a specific server.
@@ -712,37 +686,6 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name
         with self._tracer.start_as_current_span(f"Agent: '{self._name}' apply_prompt_template"):
             return await self._llm.apply_prompt_template(prompt_result, prompt_name)
 
-    # async def structured(
-    #     self,
-    #     messages: Union[
-    #         str,
-    #         PromptMessage,
-    #         PromptMessageExtended,
-    #         Sequence[Union[str, PromptMessage, PromptMessageExtended]],
-    #     ],
-    #     model: Type[ModelT],
-    #     request_params: RequestParams | None = None,
-    # ) -> Tuple[ModelT | None, PromptMessageExtended]:
-    #     """
-    #     Apply the prompt and return the result as a Pydantic model.
-    #     Normalizes input messages and delegates to the attached LLM.
-
-    #     Args:
-    #         messages: Message(s) in various formats:
-    #             - String: Converted to a user PromptMessageExtended
-    #             - PromptMessage: Converted to PromptMessageExtended
-    #             - PromptMessageExtended: Used directly
-    #             - List of any combination of the above
-    #         model: The Pydantic model class to parse the result into
-    #         request_params: Optional parameters to configure the LLM request
-
-    #     Returns:
-    #         An instance of the specified model, or None if coercion fails
-    #     """
-
-    #     with self._tracer.start_as_current_span(f"Agent: '{self._name}' structured"):
-    #         return await super().structured(messages, model, request_params)
-
     async def apply_prompt_messages(
         self, prompts: List[PromptMessageExtended], request_params: RequestParams | None = None
     ) -> str:
diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py
index 9fb7d815..93c40069 100644
--- a/src/fast_agent/agents/tool_agent.py
+++ b/src/fast_agent/agents/tool_agent.py
@@ -96,7 +96,9 @@ async def generate_impl(
 
             if LlmStopReason.TOOL_USE == result.stop_reason:
                 tool_message = await self.run_tools(result)
+
+                # the error channel is populated when a tool call fails
                 error_channel_messages = (tool_message.channels or {}).get(FAST_AGENT_ERROR_CHANNEL)
                 if error_channel_messages:
                     tool_result_contents = [
                         content
@@ -107,8 +109,13 @@
                     if result.content is None:
                         result.content = []
                     result.content.extend(tool_result_contents)
                     result.stop_reason = LlmStopReason.ERROR
+
+                if not tool_message.tool_results:
+                    # No tool results were returned at all: treat this as
+                    # unrecoverable and stop the tool loop.
+                    break
 
                 if self.config.use_history:
                     messages = [tool_message]
                 else:
diff --git a/tests/e2e/smoke/base/test_e2e_smoke.py b/tests/e2e/smoke/base/test_e2e_smoke.py
index 86c69a15..544662eb 100644
--- a/tests/e2e/smoke/base/test_e2e_smoke.py
+++ b/tests/e2e/smoke/base/test_e2e_smoke.py
@@ -180,6 +180,35 @@ class WeatherForecast(BaseModel):
     summary: str = Field(..., description="Brief summary of the overall forecast")
 
 
+@pytest.mark.integration
+@pytest.mark.asyncio
+@pytest.mark.e2e
+@pytest.mark.parametrize(
+    "model_name",
+    ["haiku", "kimi"],
+)
+async def test_error_handling_e2e(fast_agent, model_name):
+    """Call a tool that raises and verify the tool loop surfaces the error as expected."""
+    fast = fast_agent
+
+    # Define the agent
+    @fast.agent(
+        "agent",
+        instruction="SYSTEM PROMPT",
+        model=model_name,
+        servers=["test_server"],
+    )
+    async def agent_function():
+        async with fast.run() as agent:
+            await agent.agent.generate("fail please")
+
+            assert len(agent.agent.message_history) == 4
+            # the user message delivering the tool result must be flagged as an error
+            assert next(iter(agent.agent.message_history[-2].tool_results.values())).isError is True
+
+    await agent_function()
+
+
 @pytest.mark.integration
 @pytest.mark.asyncio
 @pytest.mark.e2e
diff --git a/tests/e2e/smoke/base/test_server.py b/tests/e2e/smoke/base/test_server.py
index 42c9672e..c4e03541 100644
--- a/tests/e2e/smoke/base/test_server.py
+++ b/tests/e2e/smoke/base/test_server.py
@@ -32,6 +32,11 @@ def shirt_colour() -> str:
     return "blue polka dots"
 
 
+@app.tool(name="fail_please", description="call when asked to fail")
+def fail_please() -> str:
+    raise ValueError("Intentional failure")
+
+
 if __name__ == "__main__":
     # Run the server using stdio transport
     app.run(transport="stdio")
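Not part of the patch itself, but since the `tool_agent.py` hunk carries the behavioural core of this change, here is a self-contained sketch of the decision it adds. The `ToolMessage` dataclass and the channel constant below are simplified stand-ins for fast_agent's real `PromptMessageExtended` and `FAST_AGENT_ERROR_CHANNEL`, not the library's actual definitions:

```python
# Standalone sketch of the new tool-loop decision, under the assumptions above.
from dataclasses import dataclass, field

FAST_AGENT_ERROR_CHANNEL = "fast-agent-error"  # placeholder value, not the real constant


@dataclass
class ToolMessage:
    """Simplified stand-in for the message returned by run_tools()."""

    tool_results: dict = field(default_factory=dict)
    channels: dict | None = None


def should_stop_tool_loop(tool_message: ToolMessage) -> bool:
    # Error-channel content is folded back into the assistant message (omitted
    # here); the loop only hard-stops when no tool results came back at all.
    return not tool_message.tool_results


# A tool that raises still yields an isError result, so the loop continues and
# the model gets a chance to respond to the failure, which is exactly the
# four-message history the new e2e test asserts; an empty result set stops it.
assert should_stop_tool_loop(ToolMessage()) is True
assert should_stop_tool_loop(ToolMessage(tool_results={"call_1": "error result"})) is False
```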