Skip to content

Commit 21def66

Browse files
authored
Merge pull request #516 from weaviate/v6-provider-updates
Model provider code updates
2 parents 33c1956 + d8b68dc commit 21def66

24 files changed

+414
-82
lines changed

src/main/java/io/weaviate/client6/v1/api/collections/Generative.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public enum Kind implements JsonEnum<Kind> {
3939
COHERE("generative-cohere"),
4040
DATABRICKS("generative-databricks"),
4141
FRIENDLIAI("generative-friendliai"),
42-
GOOGLE("generative-palm"),
42+
GOOGLE("generative-google"),
4343
MISTRAL("generative-mistral"),
4444
NVIDIA("generative-nvidia"),
4545
OLLAMA("generative-ollama"),
@@ -185,13 +185,13 @@ public static Generative friendliai(Function<FriendliaiGenerative.Builder, Objec
185185
return FriendliaiGenerative.of(fn);
186186
}
187187

188-
/** Configure a default {@code generative-palm} module. */
188+
/** Configure a default {@code generative-google} module. */
189189
public static Generative googleVertex(String projectId) {
190190
return GoogleGenerative.vertex(projectId);
191191
}
192192

193193
/**
194-
* Configure a {@code generative-palm} module.
194+
* Configure a {@code generative-google} module.
195195
*
196196
* @param projectId Project ID.
197197
* @param fn Lambda expression for optional parameters.
@@ -201,13 +201,13 @@ public static Generative googleVertex(String projectId,
201201
return GoogleGenerative.vertex(projectId, fn);
202202
}
203203

204-
/** Configure a default {@code generative-palm} module. */
204+
/** Configure a default {@code generative-google} module. */
205205
public static Generative googleAiStudio() {
206206
return GoogleGenerative.aiStudio();
207207
}
208208

209209
/**
210-
* Configure a {@code generative-palm} module.
210+
* Configure a {@code generative-google} module.
211211
*
212212
* @param fn Lambda expression for optional parameters.
213213
*/
@@ -399,7 +399,7 @@ default FriendliaiGenerative asFriendliai() {
399399
return _as(Generative.Kind.FRIENDLIAI);
400400
}
401401

402-
/** Is this a {@code generative-palm} provider? */
402+
/** Is this a {@code generative-google} provider? */
403403
default boolean isGoogle() {
404404
return _is(Generative.Kind.GOOGLE);
405405
}
@@ -408,7 +408,7 @@ default boolean isGoogle() {
408408
* Get as {@link GoogleGenerative} instance.
409409
*
410410
* @throws IllegalStateException if the current kind is not
411-
* {@code generative-palm}.
411+
* {@code generative-google}.
412412
*/
413413
default GoogleGenerative asGoogle() {
414414
return _as(Generative.Kind.GOOGLE);

src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public static GenerativeProvider friendliai(
100100
}
101101

102102
/**
103-
* Configure {@code generative-palm} as a dynamic provider.
103+
* Configure {@code generative-google} as a dynamic provider.
104104
*
105105
* @param fn Lambda expression for optional parameters.
106106
*/
@@ -110,7 +110,7 @@ public static GenerativeProvider googleAiStudio(
110110
}
111111

112112
/**
113-
* Configure {@code generative-palm} as a dynamic provider.
113+
* Configure {@code generative-google} as a dynamic provider.
114114
*
115115
* @param projectId Google project ID.
116116
* @param fn Lambda expression for optional parameters.

src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative;
1515

1616
public record AnthropicGenerative(
17+
@SerializedName("baseURL") String baseUrl,
1718
@SerializedName("model") String model,
1819
@SerializedName("maxTokens") Integer maxTokens,
1920
@SerializedName("temperature") Float temperature,
@@ -41,6 +42,7 @@ public static AnthropicGenerative of(Function<Builder, ObjectBuilder<AnthropicGe
4142

4243
public AnthropicGenerative(Builder builder) {
4344
this(
45+
builder.baseUrl,
4446
builder.model,
4547
builder.maxTokens,
4648
builder.temperature,
@@ -55,8 +57,15 @@ public static class Builder implements ObjectBuilder<AnthropicGenerative> {
5557
private String model;
5658
private Integer maxTokens;
5759
private Float temperature;
60+
private String baseUrl;
5861
private final List<String> stopSequences = new ArrayList<>();
5962

63+
/** Base URL of the generative provider. */
64+
public Builder baseUrl(String baseUrl) {
65+
this.baseUrl = baseUrl;
66+
return this;
67+
}
68+
6069
/** Top K value for sampling. */
6170
public Builder topK(int topK) {
6271
this.topK = topK;

src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@ public record AwsGenerative(
1818
@SerializedName("region") String region,
1919
@SerializedName("service") Service service,
2020
@SerializedName("endpoint") String endpoint,
21-
@SerializedName("model") String model) implements Generative {
21+
@SerializedName("model") String model,
22+
@SerializedName("targetModel") String targetModel,
23+
@SerializedName("targetVariant") String targetVariant,
24+
@SerializedName("temperature") Float temperature,
25+
@SerializedName("maxTokenCount") Integer maxTokenCount,
26+
@SerializedName("maxTokensToSample") Integer maxTokensToSample,
27+
@SerializedName("topP") Float topP,
28+
@SerializedName("topK") Integer topK,
29+
@SerializedName("stopSequences") List<String> stopSequences) implements Generative {
2230

2331
@Override
2432
public Generative.Kind _kind() {
@@ -53,7 +61,15 @@ public AwsGenerative(Builder builder) {
5361
builder.region,
5462
builder.service,
5563
builder.endpoint,
56-
builder.model);
64+
builder.model,
65+
builder.targetModel,
66+
builder.targetVariant,
67+
builder.temperature,
68+
builder.maxTokenCount,
69+
builder.maxTokensToSample,
70+
builder.topP,
71+
builder.topK,
72+
builder.stopSequences);
5773
}
5874

5975
public static class Builder implements ObjectBuilder<AwsGenerative> {
@@ -67,6 +83,14 @@ public Builder(Service service, String region) {
6783

6884
private String endpoint;
6985
private String model;
86+
private String targetModel;
87+
private String targetVariant;
88+
private Float temperature;
89+
private Integer maxTokenCount;
90+
private Integer maxTokensToSample;
91+
private Float topP;
92+
private Integer topK;
93+
private final List<String> stopSequences = new ArrayList<>();
7094

7195
/** Base URL of the generative provider. */
7296
protected Builder endpoint(String endpoint) {
@@ -80,6 +104,59 @@ protected Builder model(String model) {
80104
return this;
81105
}
82106

107+
/** Target model for Sagemaker. */
108+
public Builder targetModel(String targetModel) {
109+
this.targetModel = targetModel;
110+
return this;
111+
}
112+
113+
/** Target variant for Sagemaker. */
114+
public Builder targetVariant(String targetVariant) {
115+
this.targetVariant = targetVariant;
116+
return this;
117+
}
118+
119+
/** Control the randomness of the model's output. */
120+
public Builder temperature(Float temperature) {
121+
this.temperature = temperature;
122+
return this;
123+
}
124+
125+
/** Maximum number of tokens to generate. */
126+
public Builder maxTokenCount(Integer maxTokenCount) {
127+
this.maxTokenCount = maxTokenCount;
128+
return this;
129+
}
130+
131+
/** Maximum number of tokens to sample (for Anthropic models). */
132+
public Builder maxTokensToSample(Integer maxTokensToSample) {
133+
this.maxTokensToSample = maxTokensToSample;
134+
return this;
135+
}
136+
137+
/** Top-p sampling parameter. */
138+
public Builder topP(Float topP) {
139+
this.topP = topP;
140+
return this;
141+
}
142+
143+
/** Top-k sampling parameter. */
144+
public Builder topK(Integer topK) {
145+
this.topK = topK;
146+
return this;
147+
}
148+
149+
/** Stop sequences for the model. */
150+
public Builder stopSequences(String... stopSequences) {
151+
return stopSequences(Arrays.asList(stopSequences));
152+
}
153+
154+
/** Stop sequences for the model. */
155+
public Builder stopSequences(List<String> stopSequences) {
156+
this.stopSequences.addAll(stopSequences);
157+
return this;
158+
}
159+
83160
@Override
84161
public AwsGenerative build() {
85162
return new AwsGenerative(this);

src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative;
1515

1616
public record AzureOpenAiGenerative(
17+
@SerializedName("apiVersion") String apiVersion,
1718
@SerializedName("baseURL") String baseUrl,
18-
@SerializedName("frequencyPenaltyProperty") Float frequencyPenalty,
19-
@SerializedName("presencePenaltyProperty") Float presencePenalty,
20-
@SerializedName("maxTokensProperty") Integer maxTokens,
21-
@SerializedName("temperatureProperty") Float temperature,
22-
@SerializedName("topPProperty") Float topP,
19+
@SerializedName("frequencyPenalty") Float frequencyPenalty,
20+
@SerializedName("presencePenalty") Float presencePenalty,
21+
@SerializedName("maxTokens") Integer maxTokens,
22+
@SerializedName("temperature") Float temperature,
23+
@SerializedName("topP") Float topP,
24+
@SerializedName("model") String model,
2325

2426
@SerializedName("resourceName") String resourceName,
2527
@SerializedName("deploymentId") String deploymentId) implements Generative {
@@ -45,12 +47,14 @@ public static AzureOpenAiGenerative of(String resourceName, String deploymentId,
4547

4648
public AzureOpenAiGenerative(Builder builder) {
4749
this(
50+
builder.apiVersion,
4851
builder.baseUrl,
4952
builder.frequencyPenalty,
5053
builder.presencePenalty,
5154
builder.maxTokens,
5255
builder.temperature,
5356
builder.topP,
57+
builder.model,
5458
builder.resourceName,
5559
builder.deploymentId);
5660
}
@@ -59,24 +63,38 @@ public static class Builder implements ObjectBuilder<AzureOpenAiGenerative> {
5963
private final String resourceName;
6064
private final String deploymentId;
6165

66+
private String apiVersion;
6267
private String baseUrl;
6368
private Float frequencyPenalty;
6469
private Float presencePenalty;
6570
private Integer maxTokens;
6671
private Float temperature;
6772
private Float topP;
73+
private String model;
6874

6975
public Builder(String resourceName, String deploymentId) {
7076
this.resourceName = resourceName;
7177
this.deploymentId = deploymentId;
7278
}
7379

80+
/** API version for the generative provider. */
81+
public Builder apiVersion(String apiVersion) {
82+
this.apiVersion = apiVersion;
83+
return this;
84+
}
85+
7486
/** Base URL of the generative provider. */
7587
public Builder baseUrl(String baseUrl) {
7688
this.baseUrl = baseUrl;
7789
return this;
7890
}
7991

92+
/** Select generative model. */
93+
public Builder model(String model) {
94+
this.model = model;
95+
return this;
96+
}
97+
8098
/** Limit the number of tokens to generate in the response. */
8199
public Builder maxTokens(int maxTokens) {
82100
this.maxTokens = maxTokens;

src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515

1616
public record CohereGenerative(
1717
@SerializedName("baseURL") String baseUrl,
18-
@SerializedName("kProperty") Integer topK,
18+
@SerializedName("k") Integer topK,
1919
@SerializedName("model") String model,
20-
@SerializedName("maxTokensProperty") Integer maxTokens,
21-
@SerializedName("temperatureProperty") Float temperature,
22-
@SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty,
23-
@SerializedName("stopSequencesProperty") List<String> stopSequences) implements Generative {
20+
@SerializedName("maxTokens") Integer maxTokens,
21+
@SerializedName("temperature") Float temperature,
22+
@SerializedName("returnLikelihoods") String returnLikelihoodsProperty,
23+
@SerializedName("stopSequences") List<String> stopSequences,
24+
@SerializedName("P") Float topP,
25+
@SerializedName("presencePenalty") Float presencePenalty,
26+
@SerializedName("frequencyPenalty") Float frequencyPenalty) implements Generative {
2427

2528
@Override
2629
public Kind _kind() {
@@ -48,7 +51,10 @@ public CohereGenerative(Builder builder) {
4851
builder.maxTokens,
4952
builder.temperature,
5053
builder.returnLikelihoodsProperty,
51-
builder.stopSequences);
54+
builder.stopSequences,
55+
builder.topP,
56+
builder.presencePenalty,
57+
builder.frequencyPenalty);
5258
}
5359

5460
public static class Builder implements ObjectBuilder<CohereGenerative> {
@@ -58,7 +64,10 @@ public static class Builder implements ObjectBuilder<CohereGenerative> {
5864
private Integer maxTokens;
5965
private Float temperature;
6066
private String returnLikelihoodsProperty;
61-
private List<String> stopSequences = new ArrayList<>();
67+
private final List<String> stopSequences = new ArrayList<>();
68+
private Float topP;
69+
private Float presencePenalty;
70+
private Float frequencyPenalty;
6271

6372
/** Base URL of the generative provider. */
6473
public Builder baseUrl(String baseUrl) {
@@ -72,6 +81,12 @@ public Builder topK(int topK) {
7281
return this;
7382
}
7483

84+
/** Top P value for nucleus sampling. */
85+
public Builder topP(float topP) {
86+
this.topP = topP;
87+
return this;
88+
}
89+
7590
/** Select generative model. */
7691
public Builder model(String model) {
7792
this.model = model;
@@ -100,7 +115,7 @@ public Builder stopSequences(String... stopSequences) {
100115
* Set tokens which should signal the model to stop generating further output.
101116
*/
102117
public Builder stopSequences(List<String> stopSequences) {
103-
this.stopSequences = stopSequences;
118+
this.stopSequences.addAll(stopSequences);
104119
return this;
105120
}
106121

@@ -113,6 +128,16 @@ public Builder temperature(float temperature) {
113128
return this;
114129
}
115130

131+
public Builder presencePenalty(float presencePenalty) {
132+
this.presencePenalty = presencePenalty;
133+
return this;
134+
}
135+
136+
public Builder frequencyPenalty(float frequencyPenalty) {
137+
this.frequencyPenalty = frequencyPenalty;
138+
return this;
139+
}
140+
116141
@Override
117142
public CohereGenerative build() {
118143
return new CohereGenerative(this);

0 commit comments

Comments
 (0)