diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index c4462f7d4..55c622ff7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -39,7 +39,7 @@ public enum Kind implements JsonEnum { COHERE("generative-cohere"), DATABRICKS("generative-databricks"), FRIENDLIAI("generative-friendliai"), - GOOGLE("generative-palm"), + GOOGLE("generative-google"), MISTRAL("generative-mistral"), NVIDIA("generative-nvidia"), OLLAMA("generative-ollama"), @@ -185,13 +185,13 @@ public static Generative friendliai(Function { private String model; private Integer maxTokens; private Float temperature; + private String baseUrl; private final List stopSequences = new ArrayList<>(); + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + /** Top K value for sampling. */ public Builder topK(int topK) { this.topK = topK; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index a5593bf26..0e3dd747f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -18,7 +18,15 @@ public record AwsGenerative( @SerializedName("region") String region, @SerializedName("service") Service service, @SerializedName("endpoint") String endpoint, - @SerializedName("model") String model) implements Generative { + @SerializedName("model") String model, + @SerializedName("targetModel") String targetModel, + @SerializedName("targetVariant") String targetVariant, + @SerializedName("temperature") Float temperature, + @SerializedName("maxTokenCount") Integer maxTokenCount, + @SerializedName("maxTokensToSample") Integer maxTokensToSample, + @SerializedName("topP") Float topP, + @SerializedName("topK") Integer topK, + @SerializedName("stopSequences") List stopSequences) implements Generative { @Override public Generative.Kind _kind() { @@ -53,7 +61,15 @@ public AwsGenerative(Builder builder) { builder.region, builder.service, builder.endpoint, - builder.model); + builder.model, + builder.targetModel, + builder.targetVariant, + builder.temperature, + builder.maxTokenCount, + builder.maxTokensToSample, + builder.topP, + builder.topK, + builder.stopSequences); } public static class Builder implements ObjectBuilder { @@ -67,6 +83,14 @@ public Builder(Service service, String region) { private String endpoint; private String model; + private String targetModel; + private String targetVariant; + private Float temperature; + private Integer maxTokenCount; + private Integer maxTokensToSample; + private Float topP; + private Integer topK; + private final List stopSequences = new ArrayList<>(); /** Base URL of the generative provider. */ protected Builder endpoint(String endpoint) { @@ -80,6 +104,59 @@ protected Builder model(String model) { return this; } + /** Target model for Sagemaker. */ + public Builder targetModel(String targetModel) { + this.targetModel = targetModel; + return this; + } + + /** Target variant for Sagemaker. */ + public Builder targetVariant(String targetVariant) { + this.targetVariant = targetVariant; + return this; + } + + /** Control the randomness of the model's output. */ + public Builder temperature(Float temperature) { + this.temperature = temperature; + return this; + } + + /** Maximum number of tokens to generate. */ + public Builder maxTokenCount(Integer maxTokenCount) { + this.maxTokenCount = maxTokenCount; + return this; + } + + /** Maximum number of tokens to sample (for Anthropic models). */ + public Builder maxTokensToSample(Integer maxTokensToSample) { + this.maxTokensToSample = maxTokensToSample; + return this; + } + + /** Top-p sampling parameter. */ + public Builder topP(Float topP) { + this.topP = topP; + return this; + } + + /** Top-k sampling parameter. */ + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + /** Stop sequences for the model. */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** Stop sequences for the model. */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + @Override public AwsGenerative build() { return new AwsGenerative(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java index 78c47c75f..cce9ad86c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java @@ -14,12 +14,14 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record AzureOpenAiGenerative( + @SerializedName("apiVersion") String apiVersion, @SerializedName("baseURL") String baseUrl, - @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, - @SerializedName("presencePenaltyProperty") Float presencePenalty, - @SerializedName("maxTokensProperty") Integer maxTokens, - @SerializedName("temperatureProperty") Float temperature, - @SerializedName("topPProperty") Float topP, + @SerializedName("frequencyPenalty") Float frequencyPenalty, + @SerializedName("presencePenalty") Float presencePenalty, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature, + @SerializedName("topP") Float topP, + @SerializedName("model") String model, @SerializedName("resourceName") String resourceName, @SerializedName("deploymentId") String deploymentId) implements Generative { @@ -45,12 +47,14 @@ public static AzureOpenAiGenerative of(String resourceName, String deploymentId, public AzureOpenAiGenerative(Builder builder) { this( + builder.apiVersion, builder.baseUrl, builder.frequencyPenalty, builder.presencePenalty, builder.maxTokens, builder.temperature, builder.topP, + builder.model, builder.resourceName, builder.deploymentId); } @@ -59,24 +63,38 @@ public static class Builder implements ObjectBuilder { private final String resourceName; private final String deploymentId; + private String apiVersion; private String baseUrl; private Float frequencyPenalty; private Float presencePenalty; private Integer maxTokens; private Float temperature; private Float topP; + private String model; public Builder(String resourceName, String deploymentId) { this.resourceName = resourceName; this.deploymentId = deploymentId; } + /** API version for the generative provider. */ + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { this.baseUrl = baseUrl; return this; } + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + /** Limit the number of tokens to generate in the response. */ public Builder maxTokens(int maxTokens) { this.maxTokens = maxTokens; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index cc803b20f..2811a61b6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -15,12 +15,15 @@ public record CohereGenerative( @SerializedName("baseURL") String baseUrl, - @SerializedName("kProperty") Integer topK, + @SerializedName("k") Integer topK, @SerializedName("model") String model, - @SerializedName("maxTokensProperty") Integer maxTokens, - @SerializedName("temperatureProperty") Float temperature, - @SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty, - @SerializedName("stopSequencesProperty") List stopSequences) implements Generative { + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature, + @SerializedName("returnLikelihoods") String returnLikelihoodsProperty, + @SerializedName("stopSequences") List stopSequences, + @SerializedName("P") Float topP, + @SerializedName("presencePenalty") Float presencePenalty, + @SerializedName("frequencyPenalty") Float frequencyPenalty) implements Generative { @Override public Kind _kind() { @@ -48,7 +51,10 @@ public CohereGenerative(Builder builder) { builder.maxTokens, builder.temperature, builder.returnLikelihoodsProperty, - builder.stopSequences); + builder.stopSequences, + builder.topP, + builder.presencePenalty, + builder.frequencyPenalty); } public static class Builder implements ObjectBuilder { @@ -58,7 +64,10 @@ public static class Builder implements ObjectBuilder { private Integer maxTokens; private Float temperature; private String returnLikelihoodsProperty; - private List stopSequences = new ArrayList<>(); + private final List stopSequences = new ArrayList<>(); + private Float topP; + private Float presencePenalty; + private Float frequencyPenalty; /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { @@ -72,6 +81,12 @@ public Builder topK(int topK) { return this; } + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + /** Select generative model. */ public Builder model(String model) { this.model = model; @@ -100,7 +115,7 @@ public Builder stopSequences(String... stopSequences) { * Set tokens which should signal the model to stop generating further output. */ public Builder stopSequences(List stopSequences) { - this.stopSequences = stopSequences; + this.stopSequences.addAll(stopSequences); return this; } @@ -113,6 +128,16 @@ public Builder temperature(float temperature) { return this; } + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + @Override public CohereGenerative build() { return new CohereGenerative(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java index f41d359bd..8a34921d3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -18,6 +18,9 @@ public record GoogleGenerative( @SerializedName("apiEndpoint") String apiEndpoint, @SerializedName("modelId") String modelId, @SerializedName("projectId") String projectId, + @SerializedName("endpointId") String endpointId, + @SerializedName("region") String region, + @SerializedName("model") String model, @SerializedName("maxOutputTokens") Integer maxTokens, @SerializedName("topK") Integer topK, @SerializedName("topP") Float topP, @@ -54,6 +57,9 @@ public GoogleGenerative(Builder builder) { builder.apiEndpoint, builder.modelId, builder.projectId, + builder.endpointId, + builder.region, + builder.model, builder.maxTokens, builder.topK, builder.topP, @@ -65,6 +71,9 @@ public abstract static class Builder implements ObjectBuilder private final String projectId; private String modelId; + private String endpointId; + private String region; + private String model; private Integer maxTokens; private Integer topK; private Float topP; @@ -87,6 +96,24 @@ public Builder modelId(String modelId) { return this; } + /** Endpoint ID for Vertex AI. */ + public Builder endpointId(String endpointId) { + this.endpointId = endpointId; + return this; + } + + /** Google region. */ + public Builder region(String region) { + this.region = region; + return this; + } + + /** Generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + /** Limit the number of tokens to generate in the response. */ public Builder maxTokens(int maxTokens) { this.maxTokens = maxTokens; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java index 81f414641..f88b5bb70 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -13,7 +13,8 @@ public record NvidiaGenerative( @SerializedName("baseURL") String baseUrl, @SerializedName("model") String model, @SerializedName("maxTokens") Integer maxTokens, - @SerializedName("temperature") Float temperature) implements Generative { + @SerializedName("temperature") Float temperature, + @SerializedName("topP") Float topP) implements Generative { @Override public Kind _kind() { @@ -38,7 +39,8 @@ public NvidiaGenerative(Builder builder) { builder.baseUrl, builder.model, builder.maxTokens, - builder.temperature); + builder.temperature, + builder.topP); } public static class Builder implements ObjectBuilder { @@ -46,6 +48,7 @@ public static class Builder implements ObjectBuilder { private String model; private Integer maxTokens; private Float temperature; + private Float topP; /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { @@ -74,6 +77,12 @@ public Builder temperature(float temperature) { return this; } + /** Top P value for sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + @Override public NvidiaGenerative build() { return new NvidiaGenerative(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java index 5e8c40db8..670abefca 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -14,14 +14,16 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record OpenAiGenerative( + @SerializedName("apiVersion") String apiVersion, @SerializedName("baseURL") String baseUrl, - @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, - @SerializedName("presencePenaltyProperty") Float presencePenalty, - @SerializedName("maxTokensProperty") Integer maxTokens, - @SerializedName("temperatureProperty") Float temperature, - @SerializedName("topPProperty") Float topP, - - @SerializedName("model") String model) implements Generative { + @SerializedName("frequencyPenalty") Float frequencyPenalty, + @SerializedName("presencePenalty") Float presencePenalty, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature, + @SerializedName("topP") Float topP, + @SerializedName("model") String model, + @SerializedName("reasoningEffort") ReasoningEffort reasoningEffort, + @SerializedName("verbosity") Verbosity verbosity) implements Generative { @Override public Kind _kind() { @@ -43,16 +45,20 @@ public static OpenAiGenerative of(Function { + private String apiVersion; private String baseUrl; private Float frequencyPenalty; private Float presencePenalty; @@ -60,6 +66,14 @@ public static class Builder implements ObjectBuilder { private Float temperature; private Float topP; private String model; + private ReasoningEffort reasoningEffort; + private Verbosity verbosity; + + /** API version for the generative provider. */ + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { @@ -73,6 +87,18 @@ public Builder model(String model) { return this; } + /** Set the reasoning effort level. */ + public Builder reasoningEffort(ReasoningEffort reasoningEffort) { + this.reasoningEffort = reasoningEffort; + return this; + } + + /** Set the verbosity level. */ + public Builder verbosity(Verbosity verbosity) { + this.verbosity = verbosity; + return this; + } + /** Limit the number of tokens to generate in the response. */ public Builder maxTokens(int maxTokens) { this.maxTokens = maxTokens; @@ -110,6 +136,26 @@ public OpenAiGenerative build() { } } + public enum ReasoningEffort { + @SerializedName("minimal") + MINIMAL, + @SerializedName("low") + LOW, + @SerializedName("medium") + MEDIUM, + @SerializedName("high") + HIGH; + } + + public enum Verbosity { + @SerializedName("low") + LOW, + @SerializedName("medium") + MEDIUM, + @SerializedName("high") + HIGH; + } + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java index 687d82dbc..c1d271f5a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -17,7 +17,8 @@ public record XaiGenerative( @SerializedName("baseURL") String baseUrl, @SerializedName("model") String model, @SerializedName("maxTokens") Integer maxTokens, - @SerializedName("temperature") Float temperature) implements Generative { + @SerializedName("temperature") Float temperature, + @SerializedName("topP") Float topP) implements Generative { @Override public Kind _kind() { @@ -42,7 +43,8 @@ public XaiGenerative(Builder builder) { builder.baseUrl, builder.model, builder.maxTokens, - builder.temperature); + builder.temperature, + builder.topP); } public static class Builder implements ObjectBuilder { @@ -50,6 +52,7 @@ public static class Builder implements ObjectBuilder { private String model; private Integer maxTokens; private Float temperature; + private Float topP; /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { @@ -78,6 +81,12 @@ public Builder temperature(float temperature) { return this; } + /** Top P value for sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + @Override public XaiGenerative build() { return new XaiGenerative(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java index ba435e572..5cdf60c16 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java @@ -11,7 +11,7 @@ public record CohereReranker( @SerializedName("model") String model) implements Reranker { public static final String RERANK_ENGLISH_V2 = "rerank-english-v2.0"; - public static final String RERANK_MULTILINGUAL_V2 = "rerank-mulilingual-v2.0"; + public static final String RERANK_MULTILINGUAL_V2 = "rerank-multilingual-v2.0"; @Override public Kind _kind() { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecGoogleVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecGoogleVectorizer.java index 9b4530e6a..449a7e756 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecGoogleVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecGoogleVectorizer.java @@ -18,6 +18,7 @@ public record Multi2VecGoogleVectorizer( @SerializedName("model") String model, @SerializedName("dimensions") Integer dimensions, @SerializedName("location") String location, + @SerializedName("videoIntervalSeconds") Integer videoIntervalSeconds, /** BLOB image properties included in the embedding. */ @SerializedName("imageFields") List imageFields, /** BLOB video properties included in the embedding. */ @@ -80,6 +81,7 @@ public Multi2VecGoogleVectorizer( String model, Integer dimensions, String location, + Integer videoIntervalSeconds, List imageFields, List videoFields, List textFields, @@ -93,6 +95,7 @@ public Multi2VecGoogleVectorizer( this.model = model; this.dimensions = dimensions; this.location = location; + this.videoIntervalSeconds = videoIntervalSeconds; this.imageFields = imageFields; this.videoFields = videoFields; this.textFields = textFields; @@ -107,6 +110,7 @@ public Multi2VecGoogleVectorizer(Builder builder) { builder.model, builder.dimensions, builder.location, + builder.videoIntervalSeconds, builder.imageFields.keySet().stream().toList(), builder.videoFields.keySet().stream().toList(), builder.textFields.keySet().stream().toList(), @@ -132,6 +136,7 @@ public static class Builder implements ObjectBuilder private String model; private String location; private Integer dimensions; + private Integer videoIntervalSeconds; public Builder(String projectId) { this.projectId = projectId; @@ -152,6 +157,11 @@ public Builder dimensions(int dimensions) { return this; } + public Builder videoIntervalSeconds(int videoIntervalSeconds) { + this.videoIntervalSeconds = videoIntervalSeconds; + return this; + } + /** Add BLOB image properties to include in the embedding. */ public Builder imageFields(List fields) { fields.forEach(field -> imageFields.put(field, null)); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecNvidiaVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecNvidiaVectorizer.java index 29dcc9599..42502fc13 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecNvidiaVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecNvidiaVectorizer.java @@ -20,7 +20,6 @@ public record Multi2VecNvidiaVectorizer( @SerializedName("model") String model, /** Whether to apply truncation. */ @SerializedName("truncate") Boolean truncate, - @SerializedName("output_encoding") String outputEncoding, /** BLOB properties included in the embedding. */ @SerializedName("imageFields") List imageFields, /** TEXT properties included in the embedding. */ @@ -68,7 +67,6 @@ public Multi2VecNvidiaVectorizer(Builder builder) { builder.baseUrl, builder.model, builder.truncate, - builder.outputEncoding, builder.imageFields.keySet().stream().toList(), builder.textFields.keySet().stream().toList(), new Weights( @@ -88,7 +86,6 @@ public static class Builder implements ObjectBuilder private String baseUrl; private String model; private Boolean truncate; - private String outputEncoding; /** Set base URL of the embedding service. */ public Builder baseUrl(String baseUrl) { @@ -106,11 +103,6 @@ public Builder truncate(Boolean truncate) { return this; } - public Builder outputEncoding(String outputEncoding) { - this.outputEncoding = outputEncoding; - return this; - } - /** Add BLOB properties to include in the embedding. */ public Builder imageFields(List fields) { fields.forEach(field -> imageFields.put(field, null)); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecVoyageAiVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecVoyageAiVectorizer.java index d32440dcd..e13075ce7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecVoyageAiVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecVoyageAiVectorizer.java @@ -18,7 +18,6 @@ public record Multi2VecVoyageAiVectorizer( @SerializedName("baseURL") String baseUrl, /** Inference model to use. */ @SerializedName("model") String model, - @SerializedName("outputEncoding") String outputEncoding, @SerializedName("truncate") Boolean truncate, /** BLOB properties included in the embedding. */ @SerializedName("imageFields") List imageFields, @@ -71,7 +70,6 @@ public static Multi2VecVoyageAiVectorizer of(Function imageFields, List textFields, @@ -82,7 +80,6 @@ public Multi2VecVoyageAiVectorizer( this.vectorizeCollectionName = false; this.baseUrl = baseUrl; this.model = model; - this.outputEncoding = outputEncoding; this.truncate = truncate; this.imageFields = imageFields; this.textFields = textFields; @@ -95,7 +92,6 @@ public Multi2VecVoyageAiVectorizer(Builder builder) { this( builder.baseUrl, builder.model, - builder.outputEncoding, builder.truncate, builder.imageFields.keySet().stream().toList(), builder.textFields.keySet().stream().toList(), @@ -117,7 +113,6 @@ public static class Builder implements ObjectBuilder sourceProperties, @@ -61,6 +63,7 @@ public Text2MultiVecJinaAiVectorizer( Quantization quantization) { this.model = model; this.dimensions = dimensions; + this.baseUrl = baseUrl; this.vectorizeCollectionName = false; this.sourceProperties = sourceProperties; @@ -72,6 +75,7 @@ public Text2MultiVecJinaAiVectorizer(Builder builder) { this( builder.model, builder.dimensions, + builder.baseUrl, builder.vectorizeCollectionName, builder.sourceProperties, @@ -87,6 +91,7 @@ public static class Builder implements ObjectBuilder sourceProperties, @@ -89,6 +93,8 @@ public Text2VecAwsVectorizer( this.model = model; this.region = region; this.service = service; + this.targetModel = targetModel; + this.targetVariant = targetVariant; this.vectorizeCollectionName = false; this.sourceProperties = sourceProperties; @@ -102,6 +108,8 @@ public Text2VecAwsVectorizer(Builder builder) { builder.model, builder.region, builder.service, + builder.targetModel, + builder.targetVariant, builder.vectorizeCollectionName, builder.sourceProperties, @@ -119,6 +127,8 @@ public abstract static class Builder implements ObjectBuilder sourceProperties, VectorIndex vectorIndex, Quantization quantization) { this.model = model; + this.baseUrl = baseUrl; + this.dimensions = dimensions; this.vectorizeCollectionName = false; this.sourceProperties = sourceProperties; @@ -71,6 +77,8 @@ public Text2VecJinaAiVectorizer( public Text2VecJinaAiVectorizer(Builder builder) { this( builder.model, + builder.baseUrl, + builder.dimensions, builder.vectorizeCollectionName, builder.sourceProperties, @@ -85,12 +93,24 @@ public static class Builder implements ObjectBuilder { private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX; private String model; + private Integer dimensions; + private String baseUrl; public Builder model(String model) { this.model = model; return this; } + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder dimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + /** Add properties to include in the embedding. */ public Builder sourceProperties(String... properties) { return sourceProperties(Arrays.asList(properties)); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecOpenAiVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecOpenAiVectorizer.java index 12892d897..f2eed9719 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecOpenAiVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecOpenAiVectorizer.java @@ -37,9 +37,9 @@ public VectorConfig.Kind _kind() { return VectorConfig.Kind.TEXT2VEC_OPENAI; } - public static String TEXT_EMBEDDING_3_SMALL = "text-embeding-3-small"; - public static String TEXT_EMBEDDING_3_LARGE = "text-embeding-3-large"; - public static String TEXT_EMBEDDING_ADA_002 = "text-embeding-ada-002"; + public static String TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"; + public static String TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"; + public static String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"; @Override public Object _self() { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecTransformersVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecTransformersVectorizer.java index ce05b956d..3f2a3c683 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecTransformersVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecTransformersVectorizer.java @@ -17,6 +17,7 @@ public record Text2VecTransformersVectorizer( @SerializedName("passageInferenceUrl") String passageInferenceUrl, @SerializedName("queryInferenceUrl") String queryInferenceUrl, @SerializedName("poolingStrategy") PoolingStrategy poolingStrategy, + @SerializedName("dimensions") Integer dimensions, /** Properties included in the embedding. */ @SerializedName("sourceProperties") List sourceProperties, @@ -57,6 +58,7 @@ public Text2VecTransformersVectorizer(Builder builder) { builder.passageInferenceUrl, builder.queryInferenceUrl, builder.poolingStrategy, + builder.dimensions, builder.sourceProperties, builder.vectorIndex, builder.quantization); @@ -71,6 +73,7 @@ public static class Builder implements ObjectBuilder cfg + .baseUrl("https://example.com") .topK(1) .maxTokens(2) .temperature(3f) @@ -949,6 +950,7 @@ public static Object[][] testCases() { """ { "generative-anthropic": { + "baseURL": "https://example.com", "topK": 1, "maxTokens": 2, "temperature": 3.0, @@ -964,13 +966,23 @@ public static Object[][] testCases() { "aws-region", "example-model", cfg -> cfg - .model("example-model")), + .model("example-model") + .temperature(0.7f) + .maxTokenCount(100) + .topK(50) + .topP(0.9f) + .stopSequences("STOP", "END")), """ { "generative-aws": { "model": "example-model", "region": "aws-region", - "service": "bedrock" + "service": "bedrock", + "temperature": 0.7, + "maxTokenCount": 100, + "topK": 50, + "topP": 0.9, + "stopSequences": ["STOP", "END"] } } """, @@ -981,13 +993,21 @@ public static Object[][] testCases() { "aws-region", "https://example.com", cfg -> cfg - .endpoint("https://example.com")), + .endpoint("https://example.com") + .targetModel("custom-model") + .targetVariant("variant-1") + .maxTokensToSample(200) + .stopSequences("STOP")), """ { "generative-aws": { "endpoint": "https://example.com", "region": "aws-region", - "service": "sagemaker" + "service": "sagemaker", + "targetModel": "custom-model", + "targetVariant": "variant-1", + "maxTokensToSample": 200, + "stopSequences": ["STOP"] } } """, @@ -1004,12 +1024,12 @@ public static Object[][] testCases() { """ { "generative-cohere": { - "kProperty": 1, - "maxTokensProperty": 2, - "temperatureProperty": 3.0, + "k": 1, + "maxTokens": 2, + "temperature": 3.0, "model": "example-model", - "returnLikelihoodsProperty": "likelihood", - "stopSequencesProperty": ["stop", "halt"] + "returnLikelihoods": "likelihood", + "stopSequences": ["stop", "halt"] } } """, @@ -1077,6 +1097,7 @@ public static Object[][] testCases() { .baseUrl("https://example.com") .maxTokens(2) .temperature(3f) + .topP(0.95f) .model("example-model")), """ { @@ -1084,6 +1105,7 @@ public static Object[][] testCases() { "baseURL": "https://example.com", "maxTokens": 2, "temperature": 3.0, + "topP": 0.95, "model": "example-model" } } @@ -1099,17 +1121,23 @@ public static Object[][] testCases() { .temperature(3f) .topK(4) .topP(5f) - .modelId("example-model")), + .modelId("example-model") + .endpointId("endpoint-123") + .region("us-central1") + .model("gemini-pro")), """ { - "generative-palm": { + "generative-google": { "apiEndpoint": "https://example.com", "maxOutputTokens": 2, "temperature": 3.0, "topK": 4, "topP": 5, "projectId": "google-project", - "modelId": "example-model" + "modelId": "example-model", + "endpointId": "endpoint-123", + "region": "us-central1", + "model": "gemini-pro" } } """, @@ -1134,6 +1162,7 @@ public static Object[][] testCases() { .baseUrl("https://example.com") .maxTokens(2) .temperature(3f) + .topP(0.9f) .model("example-model")), """ { @@ -1141,6 +1170,7 @@ public static Object[][] testCases() { "baseURL": "https://example.com", "maxTokens": 2, "temperature": 3.0, + "topP": 0.9, "model": "example-model" } } @@ -1160,11 +1190,11 @@ public static Object[][] testCases() { { "generative-openai": { "baseURL": "https://example.com", - "frequencyPenaltyProperty": 1.0, - "presencePenaltyProperty": 2.0, - "temperatureProperty": 3.0, - "topPProperty": 4.0, - "maxTokensProperty": 5, + "frequencyPenalty": 1.0, + "presencePenalty": 2.0, + "temperature": 3.0, + "topP": 4.0, + "maxTokens": 5, "model": "o3-mini" } } @@ -1186,11 +1216,11 @@ public static Object[][] testCases() { { "generative-openai": { "baseURL": "https://example.com", - "frequencyPenaltyProperty": 1.0, - "presencePenaltyProperty": 2.0, - "temperatureProperty": 3.0, - "topPProperty": 4.0, - "maxTokensProperty": 5, + "frequencyPenalty": 1.0, + "presencePenalty": 2.0, + "temperature": 3.0, + "topP": 4.0, + "maxTokens": 5, "resourceName": "azure-resource", "deploymentId": "azure-deployment" }