Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public enum Kind implements JsonEnum<Kind> {
COHERE("generative-cohere"),
DATABRICKS("generative-databricks"),
FRIENDLIAI("generative-friendliai"),
GOOGLE("generative-palm"),
GOOGLE("generative-google"),
MISTRAL("generative-mistral"),
NVIDIA("generative-nvidia"),
OLLAMA("generative-ollama"),
Expand Down Expand Up @@ -185,13 +185,13 @@ public static Generative friendliai(Function<FriendliaiGenerative.Builder, Objec
return FriendliaiGenerative.of(fn);
}

/** Configure a default {@code generative-palm} module. */
/** Configure a default {@code generative-google} module. */
public static Generative googleVertex(String projectId) {
return GoogleGenerative.vertex(projectId);
}

/**
* Configure a {@code generative-palm} module.
* Configure a {@code generative-google} module.
*
* @param projectId Project ID.
* @param fn Lambda expression for optional parameters.
Expand All @@ -201,13 +201,13 @@ public static Generative googleVertex(String projectId,
return GoogleGenerative.vertex(projectId, fn);
}

/** Configure a default {@code generative-palm} module. */
/** Configure a default {@code generative-google} module. */
public static Generative googleAiStudio() {
return GoogleGenerative.aiStudio();
}

/**
* Configure a {@code generative-palm} module.
* Configure a {@code generative-google} module.
*
* @param fn Lambda expression for optional parameters.
*/
Expand Down Expand Up @@ -399,7 +399,7 @@ default FriendliaiGenerative asFriendliai() {
return _as(Generative.Kind.FRIENDLIAI);
}

/** Is this a {@code generative-palm} provider? */
/** Is this a {@code generative-google} provider? */
default boolean isGoogle() {
return _is(Generative.Kind.GOOGLE);
}
Expand All @@ -408,7 +408,7 @@ default boolean isGoogle() {
* Get as {@link GoogleGenerative} instance.
*
* @throws IllegalStateException if the current kind is not
* {@code generative-palm}.
* {@code generative-google}.
*/
default GoogleGenerative asGoogle() {
return _as(Generative.Kind.GOOGLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static GenerativeProvider friendliai(
}

/**
* Configure {@code generative-palm} as a dynamic provider.
* Configure {@code generative-google} as a dynamic provider.
*
* @param fn Lambda expression for optional parameters.
*/
Expand All @@ -110,7 +110,7 @@ public static GenerativeProvider googleAiStudio(
}

/**
* Configure {@code generative-palm} as a dynamic provider.
* Configure {@code generative-google} as a dynamic provider.
*
* @param projectId Google project ID.
* @param fn Lambda expression for optional parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative;

public record AnthropicGenerative(
@SerializedName("baseURL") String baseUrl,
@SerializedName("model") String model,
@SerializedName("maxTokens") Integer maxTokens,
@SerializedName("temperature") Float temperature,
Expand Down Expand Up @@ -41,6 +42,7 @@ public static AnthropicGenerative of(Function<Builder, ObjectBuilder<AnthropicGe

public AnthropicGenerative(Builder builder) {
this(
builder.baseUrl,
builder.model,
builder.maxTokens,
builder.temperature,
Expand All @@ -55,8 +57,15 @@ public static class Builder implements ObjectBuilder<AnthropicGenerative> {
private String model;
private Integer maxTokens;
private Float temperature;
private String baseUrl;
private final List<String> stopSequences = new ArrayList<>();

/** Base URL of the generative provider. */
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

/** Top K value for sampling. */
public Builder topK(int topK) {
this.topK = topK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ public record AwsGenerative(
@SerializedName("region") String region,
@SerializedName("service") Service service,
@SerializedName("endpoint") String endpoint,
@SerializedName("model") String model) implements Generative {
@SerializedName("model") String model,
@SerializedName("targetModel") String targetModel,
@SerializedName("targetVariant") String targetVariant,
@SerializedName("temperature") Float temperature,
@SerializedName("maxTokenCount") Integer maxTokenCount,
@SerializedName("maxTokensToSample") Integer maxTokensToSample,
@SerializedName("topP") Float topP,
@SerializedName("topK") Integer topK,
@SerializedName("stopSequences") List<String> stopSequences) implements Generative {

@Override
public Generative.Kind _kind() {
Expand Down Expand Up @@ -53,7 +61,15 @@ public AwsGenerative(Builder builder) {
builder.region,
builder.service,
builder.endpoint,
builder.model);
builder.model,
builder.targetModel,
builder.targetVariant,
builder.temperature,
builder.maxTokenCount,
builder.maxTokensToSample,
builder.topP,
builder.topK,
builder.stopSequences);
}

public static class Builder implements ObjectBuilder<AwsGenerative> {
Expand All @@ -67,6 +83,14 @@ public Builder(Service service, String region) {

private String endpoint;
private String model;
private String targetModel;
private String targetVariant;
private Float temperature;
private Integer maxTokenCount;
private Integer maxTokensToSample;
private Float topP;
private Integer topK;
private final List<String> stopSequences = new ArrayList<>();

/** Base URL of the generative provider. */
protected Builder endpoint(String endpoint) {
Expand All @@ -80,6 +104,59 @@ protected Builder model(String model) {
return this;
}

/** Target model for Sagemaker. */
public Builder targetModel(String targetModel) {
this.targetModel = targetModel;
return this;
}

/** Target variant for Sagemaker. */
public Builder targetVariant(String targetVariant) {
this.targetVariant = targetVariant;
return this;
}

/** Control the randomness of the model's output. */
public Builder temperature(Float temperature) {
this.temperature = temperature;
return this;
}

/** Maximum number of tokens to generate. */
public Builder maxTokenCount(Integer maxTokenCount) {
this.maxTokenCount = maxTokenCount;
return this;
}

/** Maximum number of tokens to sample (for Anthropic models). */
public Builder maxTokensToSample(Integer maxTokensToSample) {
this.maxTokensToSample = maxTokensToSample;
return this;
}

/** Top-p sampling parameter. */
public Builder topP(Float topP) {
this.topP = topP;
return this;
}

/** Top-k sampling parameter. */
public Builder topK(Integer topK) {
this.topK = topK;
return this;
}

/** Stop sequences for the model. */
public Builder stopSequences(String... stopSequences) {
return stopSequences(Arrays.asList(stopSequences));
}

/** Stop sequences for the model. */
public Builder stopSequences(List<String> stopSequences) {
this.stopSequences.addAll(stopSequences);
return this;
}

@Override
public AwsGenerative build() {
return new AwsGenerative(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative;

public record AzureOpenAiGenerative(
@SerializedName("apiVersion") String apiVersion,
@SerializedName("baseURL") String baseUrl,
@SerializedName("frequencyPenaltyProperty") Float frequencyPenalty,
@SerializedName("presencePenaltyProperty") Float presencePenalty,
@SerializedName("maxTokensProperty") Integer maxTokens,
@SerializedName("temperatureProperty") Float temperature,
@SerializedName("topPProperty") Float topP,
@SerializedName("frequencyPenalty") Float frequencyPenalty,
@SerializedName("presencePenalty") Float presencePenalty,
@SerializedName("maxTokens") Integer maxTokens,
@SerializedName("temperature") Float temperature,
@SerializedName("topP") Float topP,
@SerializedName("model") String model,

@SerializedName("resourceName") String resourceName,
@SerializedName("deploymentId") String deploymentId) implements Generative {
Expand All @@ -45,12 +47,14 @@ public static AzureOpenAiGenerative of(String resourceName, String deploymentId,

public AzureOpenAiGenerative(Builder builder) {
this(
builder.apiVersion,
builder.baseUrl,
builder.frequencyPenalty,
builder.presencePenalty,
builder.maxTokens,
builder.temperature,
builder.topP,
builder.model,
builder.resourceName,
builder.deploymentId);
}
Expand All @@ -59,24 +63,38 @@ public static class Builder implements ObjectBuilder<AzureOpenAiGenerative> {
private final String resourceName;
private final String deploymentId;

private String apiVersion;
private String baseUrl;
private Float frequencyPenalty;
private Float presencePenalty;
private Integer maxTokens;
private Float temperature;
private Float topP;
private String model;

public Builder(String resourceName, String deploymentId) {
this.resourceName = resourceName;
this.deploymentId = deploymentId;
}

/** API version for the generative provider. */
public Builder apiVersion(String apiVersion) {
this.apiVersion = apiVersion;
return this;
}

/** Base URL of the generative provider. */
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

/** Select generative model. */
public Builder model(String model) {
this.model = model;
return this;
}

/** Limit the number of tokens to generate in the response. */
public Builder maxTokens(int maxTokens) {
this.maxTokens = maxTokens;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

public record CohereGenerative(
@SerializedName("baseURL") String baseUrl,
@SerializedName("kProperty") Integer topK,
@SerializedName("k") Integer topK,
@SerializedName("model") String model,
@SerializedName("maxTokensProperty") Integer maxTokens,
@SerializedName("temperatureProperty") Float temperature,
@SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty,
@SerializedName("stopSequencesProperty") List<String> stopSequences) implements Generative {
@SerializedName("maxTokens") Integer maxTokens,
@SerializedName("temperature") Float temperature,
@SerializedName("returnLikelihoods") String returnLikelihoodsProperty,
@SerializedName("stopSequences") List<String> stopSequences,
@SerializedName("P") Float topP,
@SerializedName("presencePenalty") Float presencePenalty,
@SerializedName("frequencyPenalty") Float frequencyPenalty) implements Generative {

@Override
public Kind _kind() {
Expand Down Expand Up @@ -48,7 +51,10 @@ public CohereGenerative(Builder builder) {
builder.maxTokens,
builder.temperature,
builder.returnLikelihoodsProperty,
builder.stopSequences);
builder.stopSequences,
builder.topP,
builder.presencePenalty,
builder.frequencyPenalty);
}

public static class Builder implements ObjectBuilder<CohereGenerative> {
Expand All @@ -58,7 +64,10 @@ public static class Builder implements ObjectBuilder<CohereGenerative> {
private Integer maxTokens;
private Float temperature;
private String returnLikelihoodsProperty;
private List<String> stopSequences = new ArrayList<>();
private final List<String> stopSequences = new ArrayList<>();
private Float topP;
private Float presencePenalty;
private Float frequencyPenalty;

/** Base URL of the generative provider. */
public Builder baseUrl(String baseUrl) {
Expand All @@ -72,6 +81,12 @@ public Builder topK(int topK) {
return this;
}

/** Top P value for nucleus sampling. */
public Builder topP(float topP) {
this.topP = topP;
return this;
}

/** Select generative model. */
public Builder model(String model) {
this.model = model;
Expand Down Expand Up @@ -100,7 +115,7 @@ public Builder stopSequences(String... stopSequences) {
* Set tokens which should signal the model to stop generating further output.
*/
public Builder stopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
this.stopSequences.addAll(stopSequences);
return this;
}

Expand All @@ -113,6 +128,16 @@ public Builder temperature(float temperature) {
return this;
}

public Builder presencePenalty(float presencePenalty) {
this.presencePenalty = presencePenalty;
return this;
}

public Builder frequencyPenalty(float frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
return this;
}

@Override
public CohereGenerative build() {
return new CohereGenerative(this);
Expand Down
Loading