diff --git a/Directory.Packages.props b/Directory.Packages.props
index dea329a59..0f30fc7fa 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -60,10 +60,10 @@
-
-
-
-
+
+
+
+
diff --git a/extensions/Chunkers/Chunkers/PlainTextChunker.cs b/extensions/Chunkers/Chunkers/PlainTextChunker.cs
index 9eee711c5..f645f547d 100644
--- a/extensions/Chunkers/Chunkers/PlainTextChunker.cs
+++ b/extensions/Chunkers/Chunkers/PlainTextChunker.cs
@@ -62,7 +62,7 @@ internal enum SeparatorTypes
".", "?", "!", "⁉", "⁈", "⁇", "…",
// Chinese punctuation
"。", "?", "!", ";", ":"
-]);
+ ]);
// Prioritized list of characters to split inside a sentence.
private static readonly SeparatorTrie s_potentialSeparators = new([
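
Context for reviewers skimming this hunk: the chunker walks separator lists in priority order, preferring full sentence endings and using weaker characters such as ";" and ":" only as fallback. A minimal sketch of that fallback idea, assuming nothing about the real SeparatorTrie internals (the class and method below are illustrative only):

    using System;

    // Hypothetical sketch, not KM's SeparatorTrie: it only illustrates the
    // priority-ordered fallback between strong and weak separators.
    internal static class ChunkingSketch
    {
        public static int FindSplitPoint(string text)
        {
            string[][] priorityGroups =
            {
                new[] { ".", "?", "!", "。", "?", "!" }, // full sentence endings first
                new[] { ";", ":", ";", ":" }            // weaker separators as fallback
            };

            foreach (var group in priorityGroups)
            {
                int best = -1;
                foreach (var sep in group)
                {
                    // right-most occurrence wins within a priority group
                    best = Math.Max(best, text.LastIndexOf(sep, StringComparison.Ordinal));
                }

                // all separators here are single chars, so split right after it
                if (best >= 0) { return best + 1; }
            }

            return -1; // no separator found: caller falls back to a hard split
        }
    }
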
diff --git a/extensions/OpenAI/OpenAI/DependencyInjection.cs b/extensions/OpenAI/OpenAI/DependencyInjection.cs
index 11ee2a92e..79e9f37ea 100644
--- a/extensions/OpenAI/OpenAI/DependencyInjection.cs
+++ b/extensions/OpenAI/OpenAI/DependencyInjection.cs
@@ -5,6 +5,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.OpenAI;
+using Microsoft.KernelMemory.Context;
using OpenAI;
#pragma warning disable IDE0130 // reduce number of "using" statements
@@ -226,12 +227,11 @@ public static IServiceCollection AddOpenAITextEmbeddingGeneration(
{
config.Validate();
return services
- .AddSingleton&lt;ITextEmbeddingGenerator&gt;(
- serviceProvider => new OpenAITextEmbeddingGenerator(
- config: config,
- textTokenizer: textTokenizer,
- loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;(),
- httpClient));
+ .AddSingleton&lt;ITextEmbeddingGenerator&gt;(serviceProvider => new OpenAITextEmbeddingGenerator(
+ config: config,
+ textTokenizer: textTokenizer,
+ loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;(),
+ httpClient));
}
public static IServiceCollection AddOpenAITextEmbeddingGeneration(
@@ -242,12 +242,11 @@ public static IServiceCollection AddOpenAITextEmbeddingGeneration(
{
config.Validate();
return services
- .AddSingleton&lt;ITextEmbeddingGenerator&gt;(
- serviceProvider => new OpenAITextEmbeddingGenerator(
- config: config,
- openAIClient: openAIClient,
- textTokenizer: textTokenizer,
- loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;()));
+ .AddSingleton&lt;ITextEmbeddingGenerator&gt;(serviceProvider => new OpenAITextEmbeddingGenerator(
+ config: config,
+ openAIClient: openAIClient,
+ textTokenizer: textTokenizer,
+ loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;()));
}
public static IServiceCollection AddOpenAITextGeneration(
@@ -261,6 +260,7 @@ public static IServiceCollection AddOpenAITextGeneration(
.AddSingleton&lt;ITextGenerator&gt;(serviceProvider => new OpenAITextGenerator(
config: config,
textTokenizer: textTokenizer,
+ contextProvider: serviceProvider.GetService&lt;IContextProvider&gt;(),
loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;(),
httpClient));
}
@@ -276,6 +276,7 @@ public static IServiceCollection AddOpenAITextGeneration(
.AddSingleton&lt;ITextGenerator&gt;(serviceProvider => new OpenAITextGenerator(
config: config,
openAIClient: openAIClient,
+ contextProvider: serviceProvider.GetService&lt;IContextProvider&gt;(),
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService&lt;ILoggerFactory&gt;()));
}
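
With the two text-generation overloads above, the generator receives an IContextProvider from DI when one is registered and defaults to RequestContextProvider otherwise. A sketch of the consumer-side wiring, assuming the standard service-collection flow (API key handling and the model name are illustrative):

    using System;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.KernelMemory;
    using Microsoft.KernelMemory.Context;

    var services = new ServiceCollection();

    // Register a context provider first so GetService<IContextProvider>() finds it
    services.AddSingleton<IContextProvider, RequestContextProvider>();

    services.AddOpenAITextGeneration(new OpenAIConfig
    {
        APIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY")!,
        TextModel = "gpt-4o-mini" // default model, overridable per request
    });
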
diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
index f36812dde..41ecddc92 100644
--- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
+++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
@@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI.Internals;
+using Microsoft.KernelMemory.Context;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
@@ -30,8 +31,9 @@ public sealed class OpenAITextGenerator : ITextGenerator
private readonly OpenAIChatCompletionService _client;
private readonly ITextTokenizer _textTokenizer;
private readonly ILogger _log;
+ private readonly IContextProvider _contextProvider;
- private readonly string _textModel;
+ private readonly string _modelName;
/// &lt;inheritdoc /&gt;
public int MaxTokenTotal { get; }
@@ -41,17 +43,20 @@ public sealed class OpenAITextGenerator : ITextGenerator
/// &lt;/summary&gt;
/// &lt;param name="config"&gt;Client and model configuration&lt;/param&gt;
/// &lt;param name="textTokenizer"&gt;Text tokenizer, possibly matching the model used&lt;/param&gt;
+ /// &lt;param name="contextProvider"&gt;Request context provider with runtime configuration overrides&lt;/param&gt;
/// &lt;param name="loggerFactory"&gt;App logger factory&lt;/param&gt;
/// &lt;param name="httpClient"&gt;Optional HTTP client with custom settings&lt;/param&gt;
public OpenAITextGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
+ IContextProvider? contextProvider = null,
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(
config,
OpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
+ contextProvider,
loggerFactory)
{
}
@@ -62,16 +67,19 @@ public OpenAITextGenerator(
/// &lt;param name="config"&gt;Model configuration&lt;/param&gt;
/// &lt;param name="openAIClient"&gt;Custom OpenAI client, already configured&lt;/param&gt;
/// &lt;param name="textTokenizer"&gt;Text tokenizer, possibly matching the model used&lt;/param&gt;
+ /// &lt;param name="contextProvider"&gt;Request context provider with runtime configuration overrides&lt;/param&gt;
/// &lt;param name="loggerFactory"&gt;App logger factory&lt;/param&gt;
public OpenAITextGenerator(
OpenAIConfig config,
OpenAIClient openAIClient,
ITextTokenizer? textTokenizer = null,
+ IContextProvider? contextProvider = null,
ILoggerFactory? loggerFactory = null)
: this(
config,
SkClientBuilder.BuildChatClient(config.TextModel, openAIClient, loggerFactory),
textTokenizer,
+ contextProvider,
loggerFactory)
{
}
@@ -81,17 +89,20 @@ public OpenAITextGenerator(
/// &lt;/summary&gt;
/// &lt;param name="config"&gt;Model configuration&lt;/param&gt;
/// &lt;param name="skClient"&gt;Custom Semantic Kernel client, already configured&lt;/param&gt;
+ /// &lt;param name="contextProvider"&gt;Request context provider with runtime configuration overrides&lt;/param&gt;
/// &lt;param name="textTokenizer"&gt;Text tokenizer, possibly matching the model used&lt;/param&gt;
/// &lt;param name="loggerFactory"&gt;App logger factory&lt;/param&gt;
public OpenAITextGenerator(
OpenAIConfig config,
OpenAIChatCompletionService skClient,
ITextTokenizer? textTokenizer = null,
+ IContextProvider? contextProvider = null,
ILoggerFactory? loggerFactory = null)
{
this._client = skClient;
+ this._contextProvider = contextProvider ?? new RequestContextProvider();
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger&lt;OpenAITextGenerator&gt;();
- this._textModel = config.TextModel;
+ this._modelName = config.TextModel;
this.MaxTokenTotal = config.TextModelMaxTokenTotal;
if (textTokenizer == null && !string.IsNullOrEmpty(config.TextModelTokenizer))
@@ -129,13 +140,15 @@ public async IAsyncEnumerable&lt;GeneratedTextContent&gt; GenerateTextAsync(
TextGenerationOptions options,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
+ var modelName = this._contextProvider.GetContext().GetCustomTextGenerationModelNameOrDefault(this._modelName);
var skOptions = new OpenAIPromptExecutionSettings
{
+ ModelId = modelName,
MaxTokens = options.MaxTokens,
Temperature = options.Temperature,
FrequencyPenalty = options.FrequencyPenalty,
PresencePenalty = options.PresencePenalty,
- TopP = options.NucleusSampling
+ TopP = options.NucleusSampling,
};
if (options.StopSequences is { Count: > 0 })
@@ -178,7 +191,7 @@ public async IAsyncEnumerable&lt;GeneratedTextContent&gt; GenerateTextAsync(
Timestamp = (DateTimeOffset?)x.Metadata["CreatedAt"] ?? DateTimeOffset.UtcNow,
ServiceType = "OpenAI",
ModelType = Constants.ModelType.TextGeneration,
- ModelName = this._textModel,
+ ModelName = modelName,
ServiceTokensIn = usage!.InputTokenCount,
ServiceTokensOut = usage.OutputTokenCount,
ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
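
Net effect of this file: a per-request context can now swap the OpenAI model. GenerateTextAsync resolves the override, passes it to Semantic Kernel via OpenAIPromptExecutionSettings.ModelId, and reports the same name in the token-usage metadata. A hedged caller-side sketch; the argument key is inferred from the GetCustomTextGenerationModelNameOrDefault extension and should be verified against the IContext constants, and "memory" stands for any IKernelMemory instance:

    using Microsoft.KernelMemory.Context;

    // Assumed argument key; verify against the IContext extension implementation
    var context = new RequestContext();
    context.SetArg("custom_text_generation_model_name", "gpt-4o");

    // This request now runs on gpt-4o while other requests keep the
    // configured default model.
    var answer = await memory.AskAsync("What is Kernel Memory?", context: context);
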
diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs
index 8deea5766..2d84a90ca 100644
--- a/service/Abstractions/Context/IContext.cs
+++ b/service/Abstractions/Context/IContext.cs
@@ -244,10 +244,10 @@ public static int GetCustomEmbeddingGenerationBatchSizeOrDefault(this IContext?
/// Extensions supported:
/// - Ollama
/// - Anthropic
+ /// - OpenAI
/// Extensions not supported:
/// - Azure OpenAI
/// - ONNX
- /// - OpenAI
/// &lt;/summary&gt;
public static string GetCustomTextGenerationModelNameOrDefault(this IContext? context, string defaultValue)
{
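
Note the nullable receiver (this IContext? context): call sites can pass the result of GetContext() straight through without null checks. A tiny illustration of the fallback behavior (values are examples):

    using Microsoft.KernelMemory.Context;

    IContext? ctx = null;

    // With no context (or no override set), the configured default comes back
    var model = ctx.GetCustomTextGenerationModelNameOrDefault("gpt-4o-mini");
    // model == "gpt-4o-mini"
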