Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,41 @@ public DefaultQueryProvider(SqlServerConfig config)
/// <inheritdoc/>
public string PrepareCreateIndexQuery(int sqlServerVersion, string index, int vectorSize)
{
// Cache table names
var collectionsTable = this.GetFullTableName(this._config.MemoryCollectionTableName);
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);
var tagsIndexTable = this.GetFullTableName($"{this._config.TagsTableName}_{index}");
var embeddingsIndexTable = this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}");
var schema = this._config.Schema;
var embeddingsIndexName = $"IXC_{this._config.EmbeddingsTableName}_{index}";

var sql = $"""
BEGIN TRANSACTION;

INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id])
INSERT INTO {collectionsTable}([id])
VALUES (@index);

IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.TagsTableName}_{index}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")}
IF OBJECT_ID(N'{tagsIndexTable}', N'U') IS NULL
CREATE TABLE {tagsIndexTable}
(
[memory_id] UNIQUEIDENTIFIER NOT NULL,
[name] NVARCHAR(256) NOT NULL,
[value] NVARCHAR(256) NOT NULL,
FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
FOREIGN KEY ([memory_id]) REFERENCES {memoryTable}([id])
);

IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
IF OBJECT_ID(N'{embeddingsIndexTable}', N'U') IS NULL
CREATE TABLE {embeddingsIndexTable}
(
[memory_id] UNIQUEIDENTIFIER NOT NULL,
[vector_value_id] [int] NOT NULL,
[vector_value] [float] NOT NULL,
FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
FOREIGN KEY ([memory_id]) REFERENCES {memoryTable}([id])
);

IF OBJECT_ID(N'[{this._config.Schema}.IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]', N'U') IS NULL
CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]
ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
IF OBJECT_ID(N'[{schema}.{embeddingsIndexName}]', N'U') IS NULL
CREATE CLUSTERED COLUMNSTORE INDEX [{embeddingsIndexName}]
ON {embeddingsIndexTable}
{(sqlServerVersion >= 16 ? "ORDER ([memory_id])" : "")};

COMMIT;
Expand All @@ -54,24 +62,28 @@ CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{i
/// <inheritdoc/>
public string PrepareDeleteRecordQuery(string index)
{
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);
var tagsIndexTable = this.GetFullTableName($"{this._config.TagsTableName}_{index}");
var embeddingsIndexTable = this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}");

var sql = $"""
BEGIN TRANSACTION;

DELETE [tags]
FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} [tags]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tags].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
FROM {tagsIndexTable} [tags]
INNER JOIN {memoryTable} ON [tags].[memory_id] = {memoryTable}.[id]
WHERE
{this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
AND {this.GetFullTableName(this._config.MemoryTableName)}.[key]=@key;
{memoryTable}.[collection] = @index
AND {memoryTable}.[key]=@key;

DELETE [embeddings]
FROM {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} [embeddings]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [embeddings].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
FROM {embeddingsIndexTable} [embeddings]
INNER JOIN {memoryTable} ON [embeddings].[memory_id] = {memoryTable}.[id]
WHERE
{this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
AND {this.GetFullTableName(this._config.MemoryTableName)}.[key]=@key;
{memoryTable}.[collection] = @index
AND {memoryTable}.[key]=@key;

DELETE FROM {this.GetFullTableName(this._config.MemoryTableName)}
DELETE FROM {memoryTable}
WHERE [collection] = @index AND [key]=@key;

COMMIT;
Expand All @@ -83,13 +95,17 @@ DELETE [embeddings]
/// <inheritdoc/>
public string PrepareDeleteIndexQuery(string index)
{
var embeddingsIndexTable = this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}");
var tagsIndexTable = this.GetFullTableName($"{this._config.TagsTableName}_{index}");
var collectionsTable = this.GetFullTableName(this._config.MemoryCollectionTableName);

var sql = $"""
BEGIN TRANSACTION;

DROP TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")};
DROP TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")};
DROP TABLE {embeddingsIndexTable};
DROP TABLE {tagsIndexTable};

DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}
DELETE FROM {collectionsTable}
WHERE [id] = @index;

COMMIT;
Expand All @@ -101,7 +117,8 @@ public string PrepareDeleteIndexQuery(string index)
/// <inheritdoc/>
public string PrepareGetIndexesQuery()
{
var sql = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}";
var collectionsTable = this.GetFullTableName(this._config.MemoryCollectionTableName);
var sql = $"SELECT [id] FROM {collectionsTable}";
return sql;
}

Expand All @@ -111,6 +128,8 @@ public string PrepareGetRecordsListQuery(string index,
bool withEmbeddings,
SqlParameterCollection parameters)
{
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);

var queryColumns = "[key], [payload], [tags]";
if (withEmbeddings) { queryColumns += ", [embedding]"; }

Expand All @@ -125,9 +144,9 @@ FROM openjson(@filters) [filters]
SELECT TOP (@limit)
{queryColumns}
FROM
{this.GetFullTableName(this._config.MemoryTableName)}
{memoryTable}
WHERE
{this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
{memoryTable}.[collection] = @index
{this.GenerateFilters(index, parameters, filters)};
""";

Expand All @@ -140,15 +159,17 @@ public string PrepareGetSimilarRecordsListQuery(string index,
bool withEmbedding,
SqlParameterCollection parameters)
{
var queryColumns = $"{this.GetFullTableName(this._config.MemoryTableName)}.[id]," +
$"{this.GetFullTableName(this._config.MemoryTableName)}.[key]," +
$"{this.GetFullTableName(this._config.MemoryTableName)}.[payload]," +
$"{this.GetFullTableName(this._config.MemoryTableName)}.[tags]";
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);
var embeddingsIndexTable = this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}");

var queryColumns = $"{memoryTable}.[id]," +
$"{memoryTable}.[key]," +
$"{memoryTable}.[payload]," +
$"{memoryTable}.[tags]";

if (withEmbedding)
{
queryColumns += $"," +
$"{this.GetFullTableName(this._config.MemoryTableName)}.[embedding]";
queryColumns += $",{memoryTable}.[embedding]";
}

var generatedFilters = this.GenerateFilters(index, parameters, filters);
Expand All @@ -166,24 +187,23 @@ [embedding] as
[similarity] AS
(
SELECT TOP (@limit)
{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[memory_id],
SUM([embedding].[vector_value] * {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[vector_value]) /
{embeddingsIndexTable}.[memory_id],
SUM([embedding].[vector_value] * {embeddingsIndexTable}.[vector_value]) /
(
SQRT(SUM([embedding].[vector_value] * [embedding].[vector_value]))
*
SQRT(SUM({this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[vector_value] * {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[vector_value]))
SQRT(SUM({embeddingsIndexTable}.[vector_value] * {embeddingsIndexTable}.[vector_value]))
) AS cosine_similarity
-- sum([embedding].[vector_value] * {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[vector_value]) as cosine_distance -- Optimized as per https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use
FROM
[embedding]
INNER JOIN
{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} ON [embedding].vector_value_id = {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.vector_value_id
{embeddingsIndexTable} ON [embedding].vector_value_id = {embeddingsIndexTable}.vector_value_id
INNER JOIN
{this.GetFullTableName(this._config.MemoryTableName)} ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
{memoryTable} ON {embeddingsIndexTable}.[memory_id] = {memoryTable}.[id]
WHERE 1=1
{generatedFilters}
GROUP BY
{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}.[memory_id]
{embeddingsIndexTable}.[memory_id]
ORDER BY
cosine_similarity DESC
)
Expand All @@ -193,7 +213,7 @@ SELECT DISTINCT
FROM
[similarity]
INNER JOIN
{this.GetFullTableName(this._config.MemoryTableName)} ON [similarity].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
{memoryTable} ON [similarity].[memory_id] = {memoryTable}.[id]
WHERE
[cosine_similarity] >= @min_relevance_score
{generatedFilters}
Expand All @@ -206,29 +226,33 @@ ORDER BY [cosine_similarity] desc
/// <inheritdoc/>
public string PrepareUpsertRecordsBatchQuery(string index)
{
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);
var embeddingsIndexTable = this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}");
var tagsIndexTable = this.GetFullTableName($"{this._config.TagsTableName}_{index}");

var sql = $"""
BEGIN TRANSACTION;

MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
MERGE INTO {memoryTable}
USING (SELECT @key) as [src]([key])
ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
ON {memoryTable}.[key] = [src].[key]
WHEN MATCHED THEN
UPDATE SET payload=@payload, embedding=@embedding, tags=@tags
WHEN NOT MATCHED THEN
INSERT ([id], [key], [collection], [payload], [tags], [embedding])
VALUES (NEWID(), @key, @index, @payload, @tags, @embedding);

MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt]
MERGE {embeddingsIndexTable} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
{memoryTable}.[id],
cast([vector].[key] AS INT) AS [vector_value_id],
cast([vector].[value] AS FLOAT) AS [vector_value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
FROM {memoryTable}
CROSS APPLY
openjson(@embedding) [vector]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
WHERE {memoryTable}.[key] = @key
AND {memoryTable}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
WHEN MATCHED THEN
Expand All @@ -240,22 +264,22 @@ WHEN NOT MATCHED THEN
[src].[vector_value] );

DELETE FROM [tgt]
FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index;
FROM {tagsIndexTable} AS [tgt]
INNER JOIN {memoryTable} ON [tgt].[memory_id] = {memoryTable}.[id]
WHERE {memoryTable}.[key] = @key
AND {memoryTable}.[collection] = @index;

MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
MERGE {tagsIndexTable} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
{memoryTable}.[id],
cast([tags].[key] AS NVARCHAR(MAX)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
[tag_value].[value] AS [value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
FROM {memoryTable}
CROSS APPLY openjson(@tags) [tags]
CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(MAX)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
WHERE {memoryTable}.[key] = @key
AND {memoryTable}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
WHEN MATCHED THEN
Expand All @@ -275,29 +299,34 @@ WHEN NOT MATCHED THEN
/// <inheritdoc/>
public string PrepareCreateAllSupportingTablesQuery()
{
var collectionsTable = this.GetFullTableName(this._config.MemoryCollectionTableName);
var memoryTable = this.GetFullTableName(this._config.MemoryTableName);
var schema = this._config.Schema;
var memoryTableName = this._config.MemoryTableName; // used for constraint name

var sql = $"""
IF NOT EXISTS (SELECT *
FROM sys.schemas
WHERE name = N'{this._config.Schema}' )
EXEC('CREATE SCHEMA [{this._config.Schema}]');
WHERE name = N'{schema}' )
EXEC('CREATE SCHEMA [{schema}]');

IF OBJECT_ID(N'{this.GetFullTableName(this._config.MemoryCollectionTableName)}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName(this._config.MemoryCollectionTableName)}
IF OBJECT_ID(N'{collectionsTable}', N'U') IS NULL
CREATE TABLE {collectionsTable}
( [id] NVARCHAR(256) NOT NULL,
PRIMARY KEY ([id])
);

IF OBJECT_ID(N'{this.GetFullTableName(this._config.MemoryTableName)}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName(this._config.MemoryTableName)}
IF OBJECT_ID(N'{memoryTable}', N'U') IS NULL
CREATE TABLE {memoryTable}
( [id] UNIQUEIDENTIFIER NOT NULL,
[key] NVARCHAR(256) NOT NULL,
[collection] NVARCHAR(256) NOT NULL,
[payload] NVARCHAR(MAX),
[tags] NVARCHAR(MAX),
[embedding] NVARCHAR(MAX),
PRIMARY KEY ([id]),
FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id]) ON DELETE CASCADE,
CONSTRAINT UK_{this._config.MemoryTableName} UNIQUE([collection], [key])
FOREIGN KEY ([collection]) REFERENCES {collectionsTable}([id]) ON DELETE CASCADE,
CONSTRAINT UK_{memoryTableName} UNIQUE([collection], [key])
);
""";

Expand Down
Loading
Loading