diff --git a/Directory.Packages.props b/Directory.Packages.props index dea329a59..0f30fc7fa 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -60,10 +60,10 @@ - - - - + + + + diff --git a/extensions/Chunkers/Chunkers/PlainTextChunker.cs b/extensions/Chunkers/Chunkers/PlainTextChunker.cs index 9eee711c5..f645f547d 100644 --- a/extensions/Chunkers/Chunkers/PlainTextChunker.cs +++ b/extensions/Chunkers/Chunkers/PlainTextChunker.cs @@ -62,7 +62,7 @@ internal enum SeparatorTypes ".", "?", "!", "⁉", "⁈", "⁇", "…", // Chinese punctuation "。", "?", "!", ";", ":" -]); + ]); // Prioritized list of characters to split inside a sentence. private static readonly SeparatorTrie s_potentialSeparators = new([ diff --git a/extensions/OpenAI/OpenAI/DependencyInjection.cs b/extensions/OpenAI/OpenAI/DependencyInjection.cs index 11ee2a92e..79e9f37ea 100644 --- a/extensions/OpenAI/OpenAI/DependencyInjection.cs +++ b/extensions/OpenAI/OpenAI/DependencyInjection.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Context; using OpenAI; #pragma warning disable IDE0130 // reduce number of "using" statements @@ -226,12 +227,11 @@ public static IServiceCollection AddOpenAITextEmbeddingGeneration( { config.Validate(); return services - .AddSingleton( - serviceProvider => new OpenAITextEmbeddingGenerator( - config: config, - textTokenizer: textTokenizer, - loggerFactory: serviceProvider.GetService(), - httpClient)); + .AddSingleton(serviceProvider => new OpenAITextEmbeddingGenerator( + config: config, + textTokenizer: textTokenizer, + loggerFactory: serviceProvider.GetService(), + httpClient)); } public static IServiceCollection AddOpenAITextEmbeddingGeneration( @@ -242,12 +242,11 @@ public static IServiceCollection AddOpenAITextEmbeddingGeneration( { config.Validate(); return services - .AddSingleton( - serviceProvider => new OpenAITextEmbeddingGenerator( - config: config, - openAIClient: openAIClient, - textTokenizer: textTokenizer, - loggerFactory: serviceProvider.GetService())); + .AddSingleton(serviceProvider => new OpenAITextEmbeddingGenerator( + config: config, + openAIClient: openAIClient, + textTokenizer: textTokenizer, + loggerFactory: serviceProvider.GetService())); } public static IServiceCollection AddOpenAITextGeneration( @@ -261,6 +260,7 @@ public static IServiceCollection AddOpenAITextGeneration( .AddSingleton(serviceProvider => new OpenAITextGenerator( config: config, textTokenizer: textTokenizer, + contextProvider: serviceProvider.GetService(), loggerFactory: serviceProvider.GetService(), httpClient)); } @@ -276,6 +276,7 @@ public static IServiceCollection AddOpenAITextGeneration( .AddSingleton(serviceProvider => new OpenAITextGenerator( config: config, openAIClient: openAIClient, + contextProvider: serviceProvider.GetService(), textTokenizer: textTokenizer, loggerFactory: serviceProvider.GetService())); } diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index f36812dde..41ecddc92 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.OpenAI.Internals; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -30,8 +31,9 @@ public sealed class OpenAITextGenerator : ITextGenerator private readonly OpenAIChatCompletionService _client; private readonly ITextTokenizer _textTokenizer; private readonly ILogger _log; + private readonly IContextProvider _contextProvider; - private readonly string _textModel; + private readonly string _modelName; /// public int MaxTokenTotal { get; } @@ -41,17 +43,20 @@ public sealed class OpenAITextGenerator : ITextGenerator /// /// Client and model configuration /// Text tokenizer, possibly matching the model used + /// Request context provider with runtime configuration overrides /// App logger factory /// Optional HTTP client with custom settings public OpenAITextGenerator( OpenAIConfig config, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null, HttpClient? httpClient = null) : this( config, OpenAIClientBuilder.Build(config, httpClient, loggerFactory), textTokenizer, + contextProvider, loggerFactory) { } @@ -62,16 +67,19 @@ public OpenAITextGenerator( /// Model configuration /// Custom OpenAI client, already configured /// Text tokenizer, possibly matching the model used + /// Request context provider with runtime configuration overrides /// App logger factory public OpenAITextGenerator( OpenAIConfig config, OpenAIClient openAIClient, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) : this( config, SkClientBuilder.BuildChatClient(config.TextModel, openAIClient, loggerFactory), textTokenizer, + contextProvider, loggerFactory) { } @@ -81,17 +89,20 @@ public OpenAITextGenerator( /// /// Model configuration /// Custom Semantic Kernel client, already configured + /// Request context provider with runtime configuration overrides /// Text tokenizer, possibly matching the model used /// App logger factory public OpenAITextGenerator( OpenAIConfig config, OpenAIChatCompletionService skClient, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) { this._client = skClient; + this._contextProvider = contextProvider ?? new RequestContextProvider(); this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); - this._textModel = config.TextModel; + this._modelName = config.TextModel; this.MaxTokenTotal = config.TextModelMaxTokenTotal; if (textTokenizer == null && !string.IsNullOrEmpty(config.TextModelTokenizer)) @@ -129,13 +140,15 @@ public async IAsyncEnumerable GenerateTextAsync( TextGenerationOptions options, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + var modelName = this._contextProvider.GetContext().GetCustomTextGenerationModelNameOrDefault(this._modelName); var skOptions = new OpenAIPromptExecutionSettings { + ModelId = modelName, MaxTokens = options.MaxTokens, Temperature = options.Temperature, FrequencyPenalty = options.FrequencyPenalty, PresencePenalty = options.PresencePenalty, - TopP = options.NucleusSampling + TopP = options.NucleusSampling, }; if (options.StopSequences is { Count: > 0 }) @@ -178,7 +191,7 @@ public async IAsyncEnumerable GenerateTextAsync( Timestamp = (DateTimeOffset?)x.Metadata["CreatedAt"] ?? DateTimeOffset.UtcNow, ServiceType = "OpenAI", ModelType = Constants.ModelType.TextGeneration, - ModelName = this._textModel, + ModelName = modelName, ServiceTokensIn = usage!.InputTokenCount, ServiceTokensOut = usage.OutputTokenCount, ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs index 8deea5766..2d84a90ca 100644 --- a/service/Abstractions/Context/IContext.cs +++ b/service/Abstractions/Context/IContext.cs @@ -244,10 +244,10 @@ public static int GetCustomEmbeddingGenerationBatchSizeOrDefault(this IContext? /// Extensions supported: /// - Ollama /// - Anthropic + /// - OpenAI /// Extensions not supported: /// - Azure OpenAI /// - ONNX - /// - OpenAI /// public static string GetCustomTextGenerationModelNameOrDefault(this IContext? context, string defaultValue) {