From 4861c6ad19b4a42af1ca7a2ddf53adba91fad2d9 Mon Sep 17 00:00:00 2001 From: Russ Cam Date: Thu, 22 May 2025 15:36:42 +1000 Subject: [PATCH 1/2] Use Qdrant.Client for QdrantMemory internals This commit updates QdrantMemory to use the official Qdrant client for QdrantMemory internals. The official client uses gRPC rather than JSON over REST, with gRPC typically offering a significant overall performance improvement. --- Directory.Packages.props | 1 + .../Qdrant.FunctionalTests/appsettings.json | 2 +- .../Qdrant/Qdrant/Internals/QdrantFilter.cs | 49 ++++++++++ .../Qdrant/Internals/QdrantPointStruct.cs | 81 ++++++++++++++++ extensions/Qdrant/Qdrant/Qdrant.csproj | 1 + extensions/Qdrant/Qdrant/QdrantMemory.cs | 96 +++++++++++-------- 6 files changed, 187 insertions(+), 43 deletions(-) create mode 100644 extensions/Qdrant/Qdrant/Internals/QdrantFilter.cs create mode 100644 extensions/Qdrant/Qdrant/Internals/QdrantPointStruct.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index dea329a59..adf43a563 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -34,6 +34,7 @@ + diff --git a/extensions/Qdrant/Qdrant.FunctionalTests/appsettings.json b/extensions/Qdrant/Qdrant.FunctionalTests/appsettings.json index 2a01687d7..d0ee528e0 100644 --- a/extensions/Qdrant/Qdrant.FunctionalTests/appsettings.json +++ b/extensions/Qdrant/Qdrant.FunctionalTests/appsettings.json @@ -6,7 +6,7 @@ }, "Services": { "Qdrant": { - "Endpoint": "http://127.0.0.1:6333", + "Endpoint": "http://127.0.0.1:6334", "APIKey": "" }, "OpenAI": { diff --git a/extensions/Qdrant/Qdrant/Internals/QdrantFilter.cs b/extensions/Qdrant/Qdrant/Internals/QdrantFilter.cs new file mode 100644 index 000000000..8fb511dbb --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/QdrantFilter.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using Qdrant.Client.Grpc; + +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client; + +internal static class QdrantFilter +{ + public static Filter? BuildFilter(IEnumerable?>? tagGroups) + { + if (tagGroups == null) + { + return null; + } + + var list = tagGroups.ToList(); + var filter = new Filter(); + + if (list.Count < 2) + { + var tags = list.FirstOrDefault(); + if (tags == null) + { + return null; + } + + filter.Must.AddRange(tags.Where(t => !string.IsNullOrEmpty(t)).Select(t => Conditions.MatchText("tags", t))); + return filter; + } + + var orFilter = new Filter(); + foreach (var tags in list) + { + if (tags == null) + { + continue; + } + + var andFilter = new Filter(); + andFilter.Must.AddRange(tags.Where(t => !string.IsNullOrEmpty(t)).Select(t => Conditions.MatchText("tags", t))); + orFilter.Should.Add(Conditions.Filter(andFilter)); + } + + filter.Must.Add(Conditions.Filter(orFilter)); + return filter; + } +} diff --git a/extensions/Qdrant/Qdrant/Internals/QdrantPointStruct.cs b/extensions/Qdrant/Qdrant/Internals/QdrantPointStruct.cs new file mode 100644 index 000000000..dd506aacc --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/QdrantPointStruct.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using Google.Protobuf; +using Microsoft.KernelMemory.MemoryStorage; +using Qdrant.Client.Grpc; + +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client; + +internal static class QdrantPointStruct +{ + private const string Id = "id"; + private const string Tags = "tags"; + private const string Payload = "payload"; + + public static PointStruct FromMemoryRecord(MemoryRecord record) + { + return new PointStruct + { + Vectors = new Vectors { Vector = new Vector { Data = { record.Vector.Data.ToArray() } } }, + Payload = + { + [Id] = record.Id, + [Tags] = record.Tags.Pairs.Select(tag => $"{tag.Key}{Constants.ReservedEqualsChar}{tag.Value}").ToArray(), + [Payload] = Value.Parser.ParseJson(JsonSerializer.Serialize(record.Payload, QdrantConfig.JSONOptions)), + } + }; + } + + public static MemoryRecord ToMemoryRecord(ScoredPoint scoredPoint, bool withEmbedding = true) + { + MemoryRecord result = new() + { + Id = scoredPoint.Id.Uuid, + Payload = scoredPoint.Payload.TryGetValue(Payload, out var payload) + ? JsonSerializer.Deserialize>(JsonFormatter.Default.Format(payload.StructValue), QdrantConfig.JSONOptions) ?? [] + : [] + }; + + if (withEmbedding) + { + result.Vector = new Embedding(scoredPoint.Vectors.Vector.Data.ToArray()); + } + + foreach (string[] keyValue in scoredPoint.Payload[Tags].ListValue.Values.Select(tag => tag.StringValue.Split(Constants.ReservedEqualsChar, 2))) + { + string key = keyValue[0]; + string? value = keyValue.Length == 1 ? null : keyValue[1]; + result.Tags.Add(key, value); + } + + return result; + } + + public static MemoryRecord ToMemoryRecord(RetrievedPoint retrievedPoint, bool withEmbedding = true) + { + MemoryRecord result = new() + { + Id = retrievedPoint.Id.Uuid, + Payload = retrievedPoint.Payload.TryGetValue(Payload, out var payload) + ? JsonSerializer.Deserialize>(JsonFormatter.Default.Format(payload.StructValue), QdrantConfig.JSONOptions) ?? [] + : [] + }; + + if (withEmbedding) + { + result.Vector = new Embedding(retrievedPoint.Vectors.Vector.Data.ToArray()); + } + + foreach (string[] keyValue in retrievedPoint.Payload[Tags].ListValue.Values.Select(tag => tag.StringValue.Split(Constants.ReservedEqualsChar, 2))) + { + string key = keyValue[0]; + string? value = keyValue.Length == 1 ? null : keyValue[1]; + result.Tags.Add(key, value); + } + + return result; + } +} diff --git a/extensions/Qdrant/Qdrant/Qdrant.csproj b/extensions/Qdrant/Qdrant/Qdrant.csproj index ad7b54ea7..becca1db8 100644 --- a/extensions/Qdrant/Qdrant/Qdrant.csproj +++ b/extensions/Qdrant/Qdrant/Qdrant.csproj @@ -14,6 +14,7 @@ + diff --git a/extensions/Qdrant/Qdrant/QdrantMemory.cs b/extensions/Qdrant/Qdrant/QdrantMemory.cs index 64607f690..c371663e2 100644 --- a/extensions/Qdrant/Qdrant/QdrantMemory.cs +++ b/extensions/Qdrant/Qdrant/QdrantMemory.cs @@ -8,11 +8,14 @@ using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; +using Grpc.Core; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client; using Microsoft.KernelMemory.MemoryStorage; +using Qdrant.Client; +using Qdrant.Client.Grpc; namespace Microsoft.KernelMemory.MemoryDb.Qdrant; @@ -25,7 +28,7 @@ namespace Microsoft.KernelMemory.MemoryDb.Qdrant; public sealed class QdrantMemory : IMemoryDb, IMemoryDbUpsertBatch { private readonly ITextEmbeddingGenerator _embeddingGenerator; - private readonly QdrantClient _qdrantClient; + private readonly QdrantClient _qdrantClient; private readonly ILogger _log; /// @@ -47,43 +50,51 @@ public QdrantMemory( } this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); - this._qdrantClient = new QdrantClient(endpoint: config.Endpoint, apiKey: config.APIKey, loggerFactory: loggerFactory); + this._qdrantClient = new QdrantClient(new Uri(config.Endpoint), apiKey: config.APIKey, loggerFactory: loggerFactory); } /// - public Task CreateIndexAsync( + public async Task CreateIndexAsync( string index, int vectorSize, CancellationToken cancellationToken = default) { index = NormalizeIndexName(index); - return this._qdrantClient.CreateCollectionAsync(index, vectorSize, cancellationToken); + try + { + await this._qdrantClient.CreateCollectionAsync(index, new VectorParams + { + Distance = Distance.Cosine, + Size = (ulong)vectorSize + }, cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException e) when (e.Status.StatusCode == StatusCode.AlreadyExists) + { + this._log.LogInformation("Index already exists"); + } } /// public async Task> GetIndexesAsync(CancellationToken cancellationToken = default) { return await this._qdrantClient - .GetCollectionsAsync(cancellationToken) - .ToListAsync(cancellationToken: cancellationToken) + .ListCollectionsAsync(cancellationToken) .ConfigureAwait(false); } /// - public Task DeleteIndexAsync( + public async Task DeleteIndexAsync( string index, CancellationToken cancellationToken = default) { try { index = NormalizeIndexName(index); - return this._qdrantClient.DeleteCollectionAsync(index, cancellationToken); + await this._qdrantClient.DeleteCollectionAsync(index, cancellationToken: cancellationToken).ConfigureAwait(false); } - catch (IndexNotFoundException) + catch (QdrantException) { this._log.LogInformation("Index not found, nothing to delete"); } - - return Task.CompletedTask; } /// @@ -105,33 +116,33 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable // Call ToList to avoid multiple enumerations (CA1851: Possible multiple enumerations of 'IEnumerable' collection. Consider using an implementation that avoids multiple enumerations). var localRecords = records.ToList(); - var qdrantPoints = new List>(); + var qdrantPoints = new List(); foreach (var record in localRecords) { - QdrantPoint qdrantPoint; + PointStruct qdrantPoint; if (string.IsNullOrEmpty(record.Id)) { record.Id = Guid.NewGuid().ToString("N"); - qdrantPoint = QdrantPoint.FromMemoryRecord(record); + qdrantPoint = QdrantPointStruct.FromMemoryRecord(record); qdrantPoint.Id = Guid.NewGuid(); this._log.LogTrace("Generate new Qdrant point ID {0} and record ID {1}", qdrantPoint.Id, record.Id); } else { - qdrantPoint = QdrantPoint.FromMemoryRecord(record); - QdrantPoint? existingPoint = await this._qdrantClient - .GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken) + qdrantPoint = QdrantPointStruct.FromMemoryRecord(record); + IReadOnlyList retrievedPoints = await this._qdrantClient + .RetrieveAsync(index, new Guid(record.Id), cancellationToken: cancellationToken) .ConfigureAwait(false); - if (existingPoint == null) + if (retrievedPoints.Count == 0) { qdrantPoint.Id = Guid.NewGuid(); this._log.LogTrace("No record with ID {0} found, generated a new point ID {1}", record.Id, qdrantPoint.Id); } else { - qdrantPoint.Id = existingPoint.Id; + qdrantPoint.Id = retrievedPoints[0].Id; this._log.LogTrace("Point ID {0} found, updating...", qdrantPoint.Id); } } @@ -139,7 +150,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable qdrantPoints.Add(qdrantPoint); } - await this._qdrantClient.UpsertVectorsAsync(index, qdrantPoints, cancellationToken).ConfigureAwait(false); + await this._qdrantClient.UpsertAsync(index, qdrantPoints, cancellationToken: cancellationToken).ConfigureAwait(false); foreach (var record in localRecords) { @@ -171,19 +182,19 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false); - List<(QdrantPoint, double)> results; + IReadOnlyList results; try { - results = await this._qdrantClient.GetSimilarListAsync( + results = await this._qdrantClient.SearchAsync( collectionName: index, - target: textEmbedding, - scoreThreshold: minRelevance, - requiredTags: requiredTags, - limit: limit, - withVectors: withEmbeddings, + vector: textEmbedding.Data, + scoreThreshold: Convert.ToSingle(minRelevance), + filter: QdrantFilter.BuildFilter(requiredTags), + limit: Convert.ToUInt64(limit), + vectorsSelector: withEmbeddings, cancellationToken: cancellationToken).ConfigureAwait(false); } - catch (IndexNotFoundException e) + catch (RpcException e) when (e.Status.StatusCode == StatusCode.NotFound) { this._log.LogWarning(e, "Index not found"); // Nothing to return @@ -192,7 +203,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable foreach (var point in results) { - yield return (point.Item1.ToMemoryRecord(), point.Item2); + yield return (QdrantPointStruct.ToMemoryRecord(point, withEmbeddings), point.Score); } } @@ -216,27 +227,27 @@ public async IAsyncEnumerable GetListAsync( requiredTags.AddRange(filters.Select(filter => filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}"))); } - List> results; + ScrollResponse results; try { - results = await this._qdrantClient.GetListAsync( + results = await this._qdrantClient.ScrollAsync( collectionName: index, - requiredTags: requiredTags, + filter: QdrantFilter.BuildFilter(requiredTags), offset: 0, - limit: limit, - withVectors: withEmbeddings, + limit: Convert.ToUInt32(limit), + vectorsSelector: withEmbeddings, cancellationToken: cancellationToken).ConfigureAwait(false); } - catch (IndexNotFoundException e) + catch (RpcException e) when (e.Status.StatusCode == StatusCode.NotFound) { this._log.LogWarning(e, "Index not found"); // Nothing to return yield break; } - foreach (var point in results) + foreach (var point in results.Result) { - yield return point.ToMemoryRecord(); + yield return QdrantPointStruct.ToMemoryRecord(point, withEmbeddings); } } @@ -250,19 +261,20 @@ public async Task DeleteAsync( try { - QdrantPoint? existingPoint = await this._qdrantClient - .GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken) + IReadOnlyList existingPoints = await this._qdrantClient + .RetrieveAsync(index, new Guid(record.Id), cancellationToken: cancellationToken) .ConfigureAwait(false); - if (existingPoint == null) + if (existingPoints.Count == 0) { this._log.LogTrace("No record with ID {0} found, nothing to delete", record.Id); return; } + RetrievedPoint existingPoint = existingPoints[0]; this._log.LogTrace("Point ID {0} found, deleting...", existingPoint.Id); - await this._qdrantClient.DeleteVectorsAsync(index, [existingPoint.Id], cancellationToken).ConfigureAwait(false); + await this._qdrantClient.DeleteAsync(index, new Guid(existingPoint.Id.Uuid), cancellationToken: cancellationToken).ConfigureAwait(false); } - catch (IndexNotFoundException e) + catch (RpcException e) when (e.Status.StatusCode == StatusCode.NotFound) { this._log.LogInformation(e, "Index not found, nothing to delete"); } From 6944ea2aca1e12a2783df2a8e490c9f380351af8 Mon Sep 17 00:00:00 2001 From: Russ Cam Date: Thu, 22 May 2025 15:54:22 +1000 Subject: [PATCH 2/2] Implement IDisposable on QdrantMemory --- extensions/Qdrant/Qdrant/QdrantMemory.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extensions/Qdrant/Qdrant/QdrantMemory.cs b/extensions/Qdrant/Qdrant/QdrantMemory.cs index c371663e2..cdbd19f3f 100644 --- a/extensions/Qdrant/Qdrant/QdrantMemory.cs +++ b/extensions/Qdrant/Qdrant/QdrantMemory.cs @@ -25,7 +25,7 @@ namespace Microsoft.KernelMemory.MemoryDb.Qdrant; /// * allow using more Qdrant specific filtering logic /// [Experimental("KMEXP03")] -public sealed class QdrantMemory : IMemoryDb, IMemoryDbUpsertBatch +public sealed class QdrantMemory : IMemoryDb, IMemoryDbUpsertBatch, IDisposable { private readonly ITextEmbeddingGenerator _embeddingGenerator; private readonly QdrantClient _qdrantClient; @@ -280,6 +280,11 @@ public async Task DeleteAsync( } } + public void Dispose() + { + this._qdrantClient.Dispose(); + } + #region private ================================================================================ // Note: "_" is allowed in Qdrant, but we normalize it to "-" for consistency with other DBs