Skip to content
Open
24 changes: 24 additions & 0 deletions src/helpers/searchErrorHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { type DbOperationArgs, MongoDBToolBase } from "../tools/mongodb/mongodbTool.js";
import type { ToolArgs } from "../tools/tool.js";

export abstract class MongoDBToolWithSearchErrorHandler extends MongoDBToolBase {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we want this. In #626 we are using the MongoDBToolBase, and it would be best to have all errors centralised in one single place.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just now stumbled upon it while I was going through your PR. Certainly that's more preferable.

protected handleError(
error: unknown,
args: ToolArgs<typeof DbOperationArgs>
): Promise<CallToolResult> | CallToolResult {
const CTA = this.server?.areLocalAtlasToolsAvailable() ? "`atlas-local` tools" : "Atlas CLI";
if (error instanceof Error && "codeName" in error && error.codeName === "SearchNotEnabled") {
return {
content: [
{
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
type: "text",
},
],
isError: true,
};
}
return super.handleError(error, args);
}
}
8 changes: 7 additions & 1 deletion src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
UnsubscribeRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
import assert from "assert";
import type { ToolBase, ToolConstructorParams } from "./tools/tool.js";
import type { ToolBase, ToolCategory, ToolConstructorParams } from "./tools/tool.js";
import { validateConnectionString } from "./helpers/connectionOptions.js";
import { packageInfo } from "./common/packageInfo.js";
import { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js";
Expand Down Expand Up @@ -174,6 +174,12 @@ export class Server {
this.mcpServer.sendResourceListChanged();
}

public areLocalAtlasToolsAvailable(): boolean {
// TODO: remove hacky casts once we merge the local dev tools
const atlasLocalCategory = "atlas-local" as unknown as ToolCategory;
return !!this.tools.filter((tool) => tool.category === atlasLocalCategory).length;
}

public sendResourceUpdated(uri: string): void {
this.session.logger.info({
id: LogId.resourceUpdateFailure,
Expand Down
25 changes: 3 additions & 22 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolCategory } from "../../tool.js";
import { DbOperationArgs } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";
import { MongoDBToolWithSearchErrorHandler } from "../../../helpers/searchErrorHandler.js";

export class CreateIndexTool extends MongoDBToolBase {
export class CreateIndexTool extends MongoDBToolWithSearchErrorHandler {
private vectorSearchIndexDefinition = z.object({
type: z.literal("vectorSearch"),
fields: z
Expand Down Expand Up @@ -113,25 +113,6 @@ export class CreateIndexTool extends MongoDBToolBase {
break;
case "vectorSearch":
{
const isVectorSearchSupported = await this.session.isSearchSupported();
if (!isVectorSearchSupported) {
// TODO: remove hacky casts once we merge the local dev tools
const isLocalAtlasAvailable =
(this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory))
.length ?? 0) > 0;

const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI";
return {
content: [
{
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
type: "text",
},
],
isError: true,
};
}

indexes = await provider.createSearchIndexes(database, collection, [
{
name,
Expand Down
65 changes: 57 additions & 8 deletions src/tools/mongodb/delete/dropIndex.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,44 @@
import z from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { DbOperationArgs } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, formatUntrustedData, FeatureFlags } from "../../tool.js";
import { ListSearchIndexesTool } from "../search/listSearchIndexes.js";
import { MongoDBToolWithSearchErrorHandler } from "../../../helpers/searchErrorHandler.js";

export class DropIndexTool extends MongoDBToolBase {
export class DropIndexTool extends MongoDBToolWithSearchErrorHandler {
public name = "drop-index";
protected description = "Drop an index for the provided database and collection.";
protected argsShape = {
...DbOperationArgs,
indexName: z.string().nonempty().describe("The name of the index to be dropped."),
type: this.isFeatureFlagEnabled(FeatureFlags.VectorSearch)
? z
.enum(["classic", "search"])
.describe(
"The type of index to be deleted. Use 'classic' for standard indexes and 'search' for atlas search and vector search indexes."
)
: z
.literal("classic")
.default("classic")
.describe("The type of index to be deleted. Is always set to 'classic'."),
};
public operationType: OperationType = "delete";

protected async execute({
database,
collection,
indexName,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
protected async execute(toolArgs: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
switch (toolArgs.type) {
case "classic":
return this.dropClassicIndex(provider, toolArgs);
case "search":
return this.dropSearchIndex(provider, toolArgs);
}
}

private async dropClassicIndex(
provider: NodeDriverServiceProvider,
{ database, collection, indexName }: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> {
const result = await provider.runCommand(database, {
dropIndexes: collection,
index: indexName,
Expand All @@ -35,6 +56,34 @@ export class DropIndexTool extends MongoDBToolBase {
};
}

private async dropSearchIndex(
provider: NodeDriverServiceProvider,
{ database, collection, indexName }: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> {
const searchIndexes = await ListSearchIndexesTool.getSearchIndexes(provider, database, collection);
const indexDoesNotExist = !searchIndexes.find((index) => index.name === indexName);
if (indexDoesNotExist) {
return {
content: formatUntrustedData(
"Index does not exist in the provided namespace.",
JSON.stringify({ indexName, namespace: `${database}.${collection}` })
),
isError: true,
};
}

await provider.dropSearchIndex(database, collection, indexName);
return {
content: formatUntrustedData(
"Successfully dropped the index from the provided namespace.",
JSON.stringify({
indexName,
namespace: `${database}.${collection}`,
})
),
};
}

protected getConfirmationMessage({ database, collection, indexName }: ToolArgs<typeof this.argsShape>): string {
return (
`You are about to drop the \`${indexName}\` index from the \`${database}.${collection}\` namespace:\n\n` +
Expand Down
51 changes: 28 additions & 23 deletions src/tools/mongodb/search/listSearchIndexes.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import type { ToolArgs, OperationType } from "../../tool.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import { formatUntrustedData } from "../../tool.js";
import { EJSON } from "bson";

export type SearchIndexStatus = {
export type SearchIndexWithStatus = {
name: string;
type: string;
status: string;
Expand All @@ -20,14 +21,13 @@ export class ListSearchIndexesTool extends MongoDBToolBase {

protected async execute({ database, collection }: ToolArgs<typeof DbOperationArgs>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const indexes = await provider.getSearchIndexes(database, collection);
const trimmedIndexDefinitions = this.pickRelevantInformation(indexes);
const searchIndexes = await ListSearchIndexesTool.getSearchIndexes(provider, database, collection);

if (trimmedIndexDefinitions.length > 0) {
if (searchIndexes.length > 0) {
return {
content: formatUntrustedData(
`Found ${trimmedIndexDefinitions.length} search and vector search indexes in ${database}.${collection}`,
trimmedIndexDefinitions.map((index) => EJSON.stringify(index)).join("\n")
`Found ${searchIndexes.length} search and vector search indexes in ${database}.${collection}`,
searchIndexes.map((index) => EJSON.stringify(index)).join("\n")
),
};
} else {
Expand All @@ -45,22 +45,6 @@ export class ListSearchIndexesTool extends MongoDBToolBase {
return process.env.VITEST === "true";
}

/**
* Atlas Search index status contains a lot of information that is not relevant for the agent at this stage.
* Like for example, the status on each of the dedicated nodes. We only care about the main status, if it's
* queryable and the index name. We are also picking the index definition as it can be used by the agent to
* understand which fields are available for searching.
**/
protected pickRelevantInformation(indexes: Record<string, unknown>[]): SearchIndexStatus[] {
return indexes.map((index) => ({
name: (index["name"] ?? "default") as string,
type: (index["type"] ?? "UNKNOWN") as string,
status: (index["status"] ?? "UNKNOWN") as string,
queryable: (index["queryable"] ?? false) as boolean,
latestDefinition: index["latestDefinition"] as Document,
}));
}

protected handleError(
error: unknown,
args: ToolArgs<typeof DbOperationArgs>
Expand All @@ -71,11 +55,32 @@ export class ListSearchIndexesTool extends MongoDBToolBase {
{
text: "This MongoDB cluster does not support Search Indexes. Make sure you are using an Atlas Cluster, either remotely in Atlas or using the Atlas Local image, or your cluster supports MongoDB Search.",
type: "text",
isError: true,
},
],
isError: true,
};
}
return super.handleError(error, args);
}

static async getSearchIndexes(
provider: NodeDriverServiceProvider,
database: string,
collection: string
): Promise<SearchIndexWithStatus[]> {
const searchIndexes = await provider.getSearchIndexes(database, collection);
/**
* Atlas Search index status contains a lot of information that is not relevant for the agent at this stage.
* Like for example, the status on each of the dedicated nodes. We only care about the main status, if it's
* queryable and the index name. We are also picking the index definition as it can be used by the agent to
* understand which fields are available for searching.
**/
return searchIndexes.map<SearchIndexWithStatus>((index) => ({
name: (index["name"] ?? "default") as string,
type: (index["type"] ?? "UNKNOWN") as string,
status: (index["status"] ?? "UNKNOWN") as string,
queryable: (index["queryable"] ?? false) as boolean,
latestDefinition: index["latestDefinition"] as Document,
}));
}
}
61 changes: 61 additions & 0 deletions tests/integration/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { Collection } from "mongodb";
import { CompositeLogger } from "../../src/common/logger.js";
import { ExportsManager } from "../../src/common/exportsManager.js";
import { Session } from "../../src/common/session.js";
Expand All @@ -22,6 +23,9 @@ import { Keychain } from "../../src/common/keychain.js";
import { Elicitation } from "../../src/elicitation.js";
import type { MockClientCapabilities, createMockElicitInput } from "../utils/elicitationMocks.js";

export const DEFAULT_WAIT_TIMEOUT = 1000;
export const DEFAULT_RETRY_INTERVAL = 100;

export const driverOptions = setupDriverConfig({
config,
defaults: defaultDriverOptionsFromConfig,
Expand Down Expand Up @@ -417,3 +421,60 @@ export function getDataFromUntrustedContent(content: string): string {
export function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}

export async function waitUntilSearchManagementServiceIsReady(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we using waitUntilSearchIsReady

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably you should take a look at mongodbHelpers.ts and see what can be reused and refactor.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the same methods. I just renamed / refactored them to re-use parts of the logic.

collection: Collection,
timeout: number = DEFAULT_WAIT_TIMEOUT,
interval: number = DEFAULT_RETRY_INTERVAL
): Promise<void> {
await vi.waitFor(async () => await collection.listSearchIndexes({}).toArray(), { timeout, interval });
}

async function waitUntilSearchIndexIs(
collection: Collection,
searchIndex: string,
indexValidator: (index: { name: string; queryable: boolean }) => boolean,
timeout: number,
interval: number
): Promise<void> {
await vi.waitFor(
async () => {
const searchIndexes = (await collection.listSearchIndexes(searchIndex).toArray()) as {
name: string;
queryable: boolean;
}[];

if (!searchIndexes.some((index) => indexValidator(index))) {
throw new Error("Search index did not pass validation");
}
},
{
timeout,
interval,
}
);
}

export async function waitUntilSearchIndexIsListed(
collection: Collection,
searchIndex: string,
timeout: number = DEFAULT_WAIT_TIMEOUT,
interval: number = DEFAULT_RETRY_INTERVAL
): Promise<void> {
return waitUntilSearchIndexIs(collection, searchIndex, (index) => index.name === searchIndex, timeout, interval);
}

export async function waitUntilSearchIndexIsQueryable(
collection: Collection,
searchIndex: string,
timeout: number = DEFAULT_WAIT_TIMEOUT,
interval: number = DEFAULT_RETRY_INTERVAL
): Promise<void> {
return waitUntilSearchIndexIs(
collection,
searchIndex,
(index) => index.name === searchIndex && index.queryable,
timeout,
interval
);
}
Loading
Loading