Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ public void chatOptionsTest() {
"spring.ai.openai.chat.options.topP=0.56",

// "spring.ai.openai.chat.options.toolChoice.functionName=toolChoiceFunctionName",
"spring.ai.openai.chat.options.toolChoice=" + ModelOptionsUtils.toJsonString(ToolChoiceBuilder.FUNCTION("toolChoiceFunctionName")),
"spring.ai.openai.chat.options.toolChoice=" + ModelOptionsUtils.toJsonString(ToolChoiceBuilder.function("toolChoiceFunctionName")),

"spring.ai.openai.chat.options.tools[0].function.name=myFunction1",
"spring.ai.openai.chat.options.tools[0].function.description=function description",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package org.springframework.ai.openai.api;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;
Expand Down Expand Up @@ -85,6 +87,12 @@ public static Builder builder() {

private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;

private static final String REQUEST_BODY_NULL_MESSAGE = "The request body can not be null.";

private static final String STREAM_FALSE_MESSAGE = "Request must set the stream property to false.";

private static final String ADDITIONAL_HEADERS_NULL_MESSAGE = "The additional HTTP headers can not be null.";

// Store config fields for mutate/copy
private final String baseUrl;

Expand Down Expand Up @@ -183,9 +191,9 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest,
MultiValueMap<String, String> additionalHttpHeader) {

Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
Assert.notNull(chatRequest, REQUEST_BODY_NULL_MESSAGE);
Assert.isTrue(!chatRequest.stream(), STREAM_FALSE_MESSAGE);
Assert.notNull(additionalHttpHeader, ADDITIONAL_HEADERS_NULL_MESSAGE);

// @formatter:off
return this.restClient.post()
Expand Down Expand Up @@ -221,7 +229,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest,
MultiValueMap<String, String> additionalHttpHeader) {

Assert.notNull(chatRequest, "The request body can not be null.");
Assert.notNull(chatRequest, REQUEST_BODY_NULL_MESSAGE);
Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");

AtomicBoolean isInsideTool = new AtomicBoolean(false);
Expand Down Expand Up @@ -284,7 +292,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
*/
public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<T> embeddingRequest) {

Assert.notNull(embeddingRequest, "The request body can not be null.");
Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_MESSAGE);

// Input text to embed, encoded as a string or array of tokens. To embed multiple
// inputs in a single
Expand All @@ -296,7 +304,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
// The input must not exceed the max input tokens for the model (8192 tokens for
// text-embedding-ada-002), cannot
// be an empty string, and any array must be 2048 dimensions or less.
if (embeddingRequest.input() instanceof List list) {
if (embeddingRequest.input() instanceof List<?> list) {
Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty.");
Assert.isTrue(list.size() <= 2048, "The list must be 2048 dimensions or less");
Assert.isTrue(
Expand Down Expand Up @@ -1218,7 +1226,7 @@ public static class ToolChoiceBuilder {
/**
* Specifying a particular function forces the model to call that function.
*/
public static Object FUNCTION(String functionName) {
public static Object function(String functionName) {
return Map.of("type", "function", "function", Map.of("name", functionName));
}
}
Expand Down Expand Up @@ -1877,7 +1885,9 @@ public record ChunkChoice(// @formatter:off
@JsonIgnoreProperties(ignoreUnknown = true)
public record Embedding(// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("embedding") @JsonDeserialize(using = OpenAiEmbeddingDeserializer.class) float[] embedding,
@JsonProperty("embedding")
@JsonDeserialize(using = OpenAiEmbeddingDeserializer.class)
float[] embedding,
@JsonProperty("object") String object) { // @formatter:on

/**
Expand All @@ -1891,6 +1901,25 @@ public Embedding(Integer index, float[] embedding) {
this(index, embedding, "embedding");
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Embedding other)) {
return false;
}
return Objects.equals(this.index, other.index) && Arrays.equals(this.embedding, other.embedding)
&& Objects.equals(this.object, other.object);
}

@Override
public int hashCode() {
int result = Objects.hash(this.index, this.object);
result = 31 * result + Arrays.hashCode(this.embedding);
return result;
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package org.springframework.ai.openai.api;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -201,6 +203,32 @@ public static Builder builder() {
return new Builder();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof UploadFileRequest that)) {
return false;
}
return Arrays.equals(this.file, that.file) && Objects.equals(this.fileName, that.fileName)
&& Objects.equals(this.purpose, that.purpose);
}

@Override
public int hashCode() {
int result = Arrays.hashCode(this.file);
result = 31 * result + Objects.hashCode(this.fileName);
result = 31 * result + Objects.hashCode(this.purpose);
return result;
}

@Override
public String toString() {
return "UploadFileRequest{file=" + Arrays.toString(this.file) + ", fileName="
+ Objects.toString(this.fileName) + ", purpose=" + Objects.toString(this.purpose) + "}";
}

public static final class Builder {

private byte[] file;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiIT {
class OpenAiApiIT {

OpenAiApi openAiApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();

Expand Down Expand Up @@ -117,7 +117,7 @@ void inputAudio() throws IOException {
assertThat(response.getBody()).isNotNull();

assertThat(response.getBody().usage().promptTokensDetails().audioTokens()).isGreaterThan(0);
assertThat(response.getBody().usage().completionTokenDetails().audioTokens()).isEqualTo(0);
assertThat(response.getBody().usage().completionTokenDetails().audioTokens()).isZero();

assertThat(response.getBody().choices().get(0).message().content()).containsIgnoringCase("hobbits");
}
Expand All @@ -135,7 +135,7 @@ void outputAudio() {
assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();

assertThat(response.getBody().usage().promptTokensDetails().audioTokens()).isEqualTo(0);
assertThat(response.getBody().usage().promptTokensDetails().audioTokens()).isZero();
assertThat(response.getBody().usage().completionTokenDetails().audioTokens()).isGreaterThan(0);

assertThat(response.getBody().choices().get(0).message().audioOutput().data()).isNotNull();
Expand All @@ -153,8 +153,9 @@ void streamOutputAudio() {
ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(List.of(chatCompletionMessage),
OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), audioParameters, true);

assertThatThrownBy(() -> this.openAiApi.chatCompletionStream(chatCompletionRequest).collectList().block())
.isInstanceOf(RuntimeException.class)
Flux<ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(chatCompletionRequest);

assertThatThrownBy(response::blockLast).isInstanceOf(RuntimeException.class)
.hasMessageContaining("400 Bad Request from POST https://api.openai.com/v1/chat/completions");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -52,8 +51,6 @@ void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
.openAiApi(gpt4Api)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
.build();
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();

// Mutate for Llama
OpenAiApi llamaApi = this.baseApi.mutate()
.baseUrl("https://your-custom-endpoint.com")
Expand All @@ -63,8 +60,6 @@ void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
.openAiApi(llamaApi)
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
.build();
ChatClient llamaClient = ChatClient.builder(llamaModel).build();

// Assert endpoints and models are different
assertThat(gpt4Model).isNotSameAs(llamaModel);
assertThat(gpt4Api).isNotSameAs(llamaApi);
Expand All @@ -78,7 +73,7 @@ void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
void testCloneCreatesDeepCopy() {
OpenAiChatModel clone = this.baseModel.clone();
assertThat(clone).isNotSameAs(this.baseModel);
assertThat(clone.toString()).isEqualTo(this.baseModel.toString());
assertThat(clone).hasToString(this.baseModel.toString());
}

@Test
Expand Down