Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
* ChatCompletionChunk in case of function calling message.
*
* @author Geng Rong
* @author Sun Yuhan
*/
public class DeepSeekStreamFunctionCallingHelper {

Expand Down Expand Up @@ -76,22 +77,23 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
}

private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
String content = (current.content() != null ? current.content()
: "" + ((previous.content() != null) ? previous.content() : ""));
String content = (current.content() != null
? (previous.content() != null ? previous.content() + current.content() : current.content())
: (previous.content() != null ? previous.content() : ""));
Role role = (current.role() != null ? current.role() : previous.role());
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
String name = (current.name() != null ? current.name() : previous.name());
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());

List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
if (previous.toolCalls() != null) {
if (!CollectionUtils.isEmpty(previous.toolCalls())) {
lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
if (previous.toolCalls().size() > 1) {
toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1));
}
}
if (current.toolCalls() != null) {
if (!CollectionUtils.isEmpty(current.toolCalls())) {
if (current.toolCalls().size() > 1) {
throw new IllegalStateException("Currently only one tool call is supported per message!");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.deepseek.api;

import java.util.List;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Unit test for {@link DeepSeekStreamFunctionCallingHelper}.
*
* @author Sun Yuhan
*/
class DeepSeekStreamFunctionCallingHelperTest {

private DeepSeekStreamFunctionCallingHelper helper;

@BeforeEach
void setUp() {
this.helper = new DeepSeekStreamFunctionCallingHelper();
}

@Test
void mergeWhenPreviousIsNullShouldReturnCurrent() {
// Given
ChatCompletionChunk current = new ChatCompletionChunk("id1", List.of(), 123L, "model1", null, null, null, null);

// When
ChatCompletionChunk result = this.helper.merge(null, current);

// Then
assertThat(result).isEqualTo(current);
}

@Test
void mergeShouldMergeBasicFieldsFromCurrentAndPrevious() {
// Given
ChatCompletionChunk previous = new ChatCompletionChunk("id1", List.of(), 123L, "model1", null, null, null,
null);
ChatCompletionChunk current = new ChatCompletionChunk("id2", List.of(), null, null, null, null, null, null);

// When
ChatCompletionChunk result = this.helper.merge(previous, current);

// Then
assertThat(result.id()).isEqualTo("id2"); // from current
assertThat(result.created()).isEqualTo(123L); // from previous
assertThat(result.model()).isEqualTo("model1"); // from previous
}

@Test
void mergeShouldMergeMessagesContent() {
// Given
ChatCompletionMessage previousMsg = new ChatCompletionMessage("Hello ", Role.ASSISTANT, null, null, null);
ChatCompletionMessage currentMsg = new ChatCompletionMessage("World!", Role.ASSISTANT, null, null, null);

ChatCompletionChunk previous = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
null, null);

ChatCompletionChunk current = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
null, null);

// When
ChatCompletionChunk result = this.helper.merge(previous, current);

// Then
assertThat(result.choices().get(0).delta().content()).isEqualTo("Hello World!");
}

@Test
void mergeShouldHandleToolCallsMerging() {
// Given
ChatCompletionFunction func1 = new ChatCompletionFunction("func1", "{\"arg1\":");
ToolCall toolCall1 = new ToolCall("call_123", "function", func1);
ChatCompletionMessage previousMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
List.of(toolCall1));

ChatCompletionFunction func2 = new ChatCompletionFunction("func1", "\"value1\"}");
ToolCall toolCall2 = new ToolCall(null, "function", func2); // No ID -
// continuation
ChatCompletionMessage currentMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
List.of(toolCall2));

ChatCompletionChunk previous = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
null, null);

ChatCompletionChunk current = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
null, null);

// When
ChatCompletionChunk result = this.helper.merge(previous, current);

// Then
assertThat(result.choices()).hasSize(1);
assertThat(result.choices().get(0).delta().toolCalls()).hasSize(1);
ToolCall mergedToolCall = result.choices().get(0).delta().toolCalls().get(0);
assertThat(mergedToolCall.id()).isEqualTo("call_123");
assertThat(mergedToolCall.function().name()).isEqualTo("func1");
assertThat(mergedToolCall.function().arguments()).isEqualTo("{\"arg1\":\"value1\"}");
}

@Test
void mergeWithSingleToolCallShouldWork() {
// Given
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func1", "{}"));
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of(toolCall));

ChatCompletionChunk previous = new ChatCompletionChunk("id", List.of(), 123L, "model", null, null, null, null);
ChatCompletionChunk current = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, msg, null)), 123L, "model", null, null, null,
null);

// When
ChatCompletionChunk result = this.helper.merge(previous, current);

// Then
assertThat(result).isNotNull();
assertThat(result.choices().get(0).delta().toolCalls()).hasSize(1);
}

@Test
void isStreamingToolFunctionCallWhenNullChunkShouldReturnFalse() {
// When & Then
assertThat(this.helper.isStreamingToolFunctionCall(null)).isFalse();
}

@Test
void isStreamingToolFunctionCallWhenEmptyChoicesShouldReturnFalse() {
// Given
ChatCompletionChunk chunk = new ChatCompletionChunk("id", List.of(), 123L, "model", null, null, null, null);

// When & Then
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse();
}

@Test
void isStreamingToolFunctionCallWhenHasToolCallsShouldReturnTrue() {
// Given
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func", "{}"));
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of(toolCall));
ChatCompletionChunk chunk = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, msg, null)), 123L, "model", null, null, null,
null);

// When & Then
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isTrue();
}

@Test
void isStreamingToolFunctionCallFinishWhenFinishReasonIsToolCallsShouldReturnTrue() {
// Given
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, null);
ChatCompletionChunk.ChunkChoice choice = new ChatCompletionChunk.ChunkChoice(
DeepSeekApi.ChatCompletionFinishReason.TOOL_CALLS, 0, msg, null);
ChatCompletionChunk chunk = new ChatCompletionChunk("id", List.of(choice), 123L, "model", null, null, null,
null);

// When & Then
assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue();
}

@Test
void mergeWhenCurrentToolCallsIsEmptyListShouldNotThrowException() {
// Given
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func1", "{}"));
ChatCompletionMessage previousMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
List.of(toolCall));

// Empty list instead of null
ChatCompletionMessage currentMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of());

ChatCompletionChunk previous = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
null, null);

ChatCompletionChunk current = new ChatCompletionChunk("id",
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
null, null);

// When
ChatCompletionChunk result = this.helper.merge(previous, current);

// Then
assertThat(result).isNotNull();
}

}