diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 02c35462857..b6575ec1def 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -17,6 +17,7 @@ package org.springframework.ai.model.tool; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -136,7 +137,7 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp throw new IllegalStateException("No tool call requested by the chat model"); } - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + AssistantMessage assistantMessage = safelyMergeAssistantMessageIfEmptyToolCallPresent(toolCallGeneration); ToolContext toolContext = buildToolContext(prompt, assistantMessage); @@ -152,6 +153,35 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp .build(); } + private AssistantMessage safelyMergeAssistantMessageIfEmptyToolCallPresent( + Optional toolCallGeneration) { + if (toolCallGeneration.isEmpty()) { + throw new IllegalStateException("No tool call requested by the chat model"); + } + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + List toolCalls = assistantMessage.getToolCalls(); + List reversedToolCalls = new ArrayList<>(toolCalls); + Collections.reverse(reversedToolCalls); + List newToolCalls = new ArrayList<>(); + StringBuilder args = new StringBuilder(); + for (AssistantMessage.ToolCall toolCall : reversedToolCalls) { + args.append(toolCall.arguments()); + if (StringUtils.hasText(toolCall.name())) { + AssistantMessage.ToolCall newToolCall = new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), + toolCall.name(), args.toString()); + newToolCalls.add(newToolCall); + args = new StringBuilder(); + } + } + Collections.reverse(newToolCalls); + return AssistantMessage.builder() + .content(assistantMessage.getText()) + .toolCalls(newToolCalls) + .media(assistantMessage.getMedia()) + .properties(assistantMessage.getMetadata()) + .build(); + } + private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) { Map toolContextMap = Map.of(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index bd60639c323..7f4d928ea58 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -423,4 +423,59 @@ public String call(String toolInput) { assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); } + @Test + void shouldHandleMultipleGenerationsWithToolCallsWhenNameIsEmpty() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("multiGenTool") + .description("Tool for multiple generations") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"success\"}"; + } + }; + + // Create multiple generations with tool calls + AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "multiGenTool", "{}"); + AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "multiGenTool", "{}"); + AssistantMessage.ToolCall toolCall3 = new AssistantMessage.ToolCall("3", "function", "", "{}"); + + AssistantMessage assistantMessage1 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall1)) + .build(); + + AssistantMessage assistantMessage2 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall2, toolCall3)) + .build(); + + Generation generation1 = new Generation(assistantMessage1); + Generation generation2 = new Generation(assistantMessage2); + + ChatResponse chatResponse = new ChatResponse(List.of(generation1, generation2)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test multiple generations"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "multiGenTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + }