diff --git a/packages/types/src/providers/anthropic.ts b/packages/types/src/providers/anthropic.ts
index 766869c7e6..66abf95ba1 100644
--- a/packages/types/src/providers/anthropic.ts
+++ b/packages/types/src/providers/anthropic.ts
@@ -2,8 +2,9 @@ import type { ModelInfo } from "../model.js"
 
 // https://docs.anthropic.com/en/docs/about-claude/models
 
-export type AnthropicModelId = keyof typeof anthropicModels
-export const anthropicDefaultModelId: AnthropicModelId = "claude-sonnet-4-5"
+// Allow custom model IDs in addition to known models
+export type AnthropicModelId = keyof typeof anthropicModels | string
+export const anthropicDefaultModelId = "claude-sonnet-4-5" as const
 
 export const anthropicModels = {
 	"claude-sonnet-4-5": {
diff --git a/src/api/providers/__tests__/anthropic.spec.ts b/src/api/providers/__tests__/anthropic.spec.ts
index b05e50125b..503c38c031 100644
--- a/src/api/providers/__tests__/anthropic.spec.ts
+++ b/src/api/providers/__tests__/anthropic.spec.ts
@@ -288,5 +288,105 @@ describe("AnthropicHandler", () => {
 			expect(model.info.inputPrice).toBe(6.0)
 			expect(model.info.outputPrice).toBe(22.5)
 		})
+
+		it("should handle custom model ID not in predefined list", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "custom-claude-model",
+			})
+			const model = handler.getModel()
+			expect(model.id).toBe("custom-claude-model")
+			expect(model.info).toBeDefined()
+			// Should have sensible defaults
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(200000)
+			expect(model.info.supportsImages).toBe(true)
+			expect(model.info.supportsPromptCache).toBe(true)
+		})
+
+		it("should handle third-party Anthropic-compatible models", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "MiniMax-M2",
+			})
+			const model = handler.getModel()
+			expect(model.id).toBe("MiniMax-M2")
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(200000)
+		})
+
+		it("should handle custom Claude model with hyphen format", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-haiku-4-5",
+			})
+			const model = handler.getModel()
+			expect(model.id).toBe("claude-haiku-4-5")
+			expect(model.info).toBeDefined()
+			expect(model.info.supportsPromptCache).toBe(true)
+		})
+
+		it("should handle custom Claude model with date suffix", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-sonnet-4-5-20250929",
+			})
+			const model = handler.getModel()
+			expect(model.id).toBe("claude-sonnet-4-5-20250929")
+			expect(model.info).toBeDefined()
+			expect(model.info.supportsImages).toBe(true)
+		})
+	})
+
+	describe("createMessage with custom models", () => {
+		const systemPrompt = "You are a helpful assistant."
+
+		it("should handle custom model with prompt caching", async () => {
+			const customHandler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "custom-anthropic-model",
+			})
+
+			const stream = customHandler.createMessage(systemPrompt, [
+				{
+					role: "user",
+					content: [{ type: "text" as const, text: "Test message" }],
+				},
+			])
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should still process the stream correctly
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "custom-anthropic-model",
+					stream: true,
+				}),
+				expect.objectContaining({
+					headers: { "anthropic-beta": "prompt-caching-2024-07-31" },
+				}),
+			)
+		})
+
+		it("should use appropriate defaults for unknown model", async () => {
+			const customHandler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "third-party-claude-compatible",
+			})
+
+			await customHandler.completePrompt("Test prompt")
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "third-party-claude-compatible",
+					max_tokens: 8192,
+				}),
+			)
+		})
 	})
 })
diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts
index 0e767ce237..d0d823e3cb 100644
--- a/src/api/providers/anthropic.ts
+++ b/src/api/providers/anthropic.ts
@@ -43,7 +43,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 	): ApiStream {
 		let stream: AnthropicStream
 		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
-		let { id: modelId, betas = [], maxTokens, temperature, reasoning: thinking } = this.getModel()
+		let { id: modelId, betas = [], maxTokens, temperature, reasoning: thinking, info } = this.getModel()
 
 		// Add 1M context beta flag if enabled for Claude Sonnet 4 and 4.5
 		if (
@@ -53,98 +53,144 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 			betas.push("context-1m-2025-08-07")
 		}
 
-		switch (modelId) {
-			case "claude-sonnet-4-5":
-			case "claude-sonnet-4-20250514":
-			case "claude-opus-4-1-20250805":
-			case "claude-opus-4-20250514":
-			case "claude-3-7-sonnet-20250219":
-			case "claude-3-5-sonnet-20241022":
-			case "claude-3-5-haiku-20241022":
-			case "claude-3-opus-20240229":
-			case "claude-haiku-4-5-20251001":
-			case "claude-3-haiku-20240307": {
-				/**
-				 * The latest message will be the new user message, one before
-				 * will be the assistant message from a previous request, and
-				 * the user message before that will be a previously cached user
-				 * message. So we need to mark the latest user message as
-				 * ephemeral to cache it for the next request, and mark the
-				 * second to last user message as ephemeral to let the server
-				 * know the last message to retrieve from the cache for the
-				 * current request.
-				 */
-				const userMsgIndices = messages.reduce(
-					(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
-					[] as number[],
-				)
-
-				const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-				const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
-				stream = await this.client.messages.create(
-					{
-						model: modelId,
-						max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
-						temperature,
-						thinking,
-						// Setting cache breakpoint for system prompt so new tasks can reuse it.
-						system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }],
-						messages: messages.map((message, index) => {
-							if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
-								return {
-									...message,
-									content:
-										typeof message.content === "string"
-											? [{ type: "text", text: message.content, cache_control: cacheControl }]
-											: message.content.map((content, contentIndex) =>
-													contentIndex === message.content.length - 1
-														? { ...content, cache_control: cacheControl }
-														: content,
-												),
-								}
+		// Check if this is a known model that supports prompt caching
+		const supportsPromptCache = info.supportsPromptCache ?? false
+
+		// For custom models, check if they support prompt caching
+		if (supportsPromptCache) {
+			switch (modelId) {
+				case "claude-sonnet-4-5":
+				case "claude-sonnet-4-20250514":
+				case "claude-opus-4-1-20250805":
+				case "claude-opus-4-20250514":
+				case "claude-3-7-sonnet-20250219":
+				case "claude-3-5-sonnet-20241022":
+				case "claude-3-5-haiku-20241022":
+				case "claude-3-opus-20240229":
+				case "claude-haiku-4-5-20251001":
+				case "claude-3-haiku-20240307": {
+					/**
+					 * The latest message will be the new user message, one before
+					 * will be the assistant message from a previous request, and
+					 * the user message before that will be a previously cached user
+					 * message. So we need to mark the latest user message as
+					 * ephemeral to cache it for the next request, and mark the
+					 * second to last user message as ephemeral to let the server
+					 * know the last message to retrieve from the cache for the
+					 * current request.
+					 */
+					const userMsgIndices = messages.reduce(
+						(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+						[] as number[],
+					)
+
+					const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+					const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+					stream = await this.client.messages.create(
+						{
+							model: modelId,
+							max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
+							temperature,
+							thinking,
+							// Setting cache breakpoint for system prompt so new tasks can reuse it.
+							system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }],
+							messages: messages.map((message, index) => {
+								if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
+									return {
+										...message,
+										content:
+											typeof message.content === "string"
+												? [{ type: "text", text: message.content, cache_control: cacheControl }]
+												: message.content.map((content, contentIndex) =>
+														contentIndex === message.content.length - 1
+															? { ...content, cache_control: cacheControl }
+															: content,
+													),
+									}
+								}
+								return message
+							}),
+							stream: true,
+						},
+						(() => {
+							// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
+							// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
+							// https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
+
+							// Then check for models that support prompt caching
+							switch (modelId) {
+								case "claude-sonnet-4-5":
+								case "claude-sonnet-4-20250514":
+								case "claude-opus-4-1-20250805":
+								case "claude-opus-4-20250514":
+								case "claude-3-7-sonnet-20250219":
+								case "claude-3-5-sonnet-20241022":
+								case "claude-3-5-haiku-20241022":
+								case "claude-3-opus-20240229":
+								case "claude-haiku-4-5-20251001":
+								case "claude-3-haiku-20240307":
+									betas.push("prompt-caching-2024-07-31")
+									return { headers: { "anthropic-beta": betas.join(",") } }
+								default:
+									return undefined
 							}
-							return message
-						}),
-						stream: true,
-					},
-					(() => {
-						// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
-						// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
-						// https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
-
-						// Then check for models that support prompt caching
-						switch (modelId) {
-							case "claude-sonnet-4-5":
-							case "claude-sonnet-4-20250514":
-							case "claude-opus-4-1-20250805":
-							case "claude-opus-4-20250514":
-							case "claude-3-7-sonnet-20250219":
-							case "claude-3-5-sonnet-20241022":
-							case "claude-3-5-haiku-20241022":
-							case "claude-3-opus-20240229":
-							case "claude-haiku-4-5-20251001":
-							case "claude-3-haiku-20240307":
-								betas.push("prompt-caching-2024-07-31")
-								return { headers: { "anthropic-beta": betas.join(",") } }
-							default:
-								return undefined
-						}
-					})(),
-				)
-				break
-			}
-			default: {
-				stream = (await this.client.messages.create({
-					model: modelId,
-					max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
-					temperature,
-					system: [{ text: systemPrompt, type: "text" }],
-					messages,
-					stream: true,
-				})) as any
-				break
+						})(),
+					)
+					break
+				}
+				default: {
+					// Custom model with prompt caching support
+					const userMsgIndices = messages.reduce(
+						(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+						[] as number[],
+					)
+
+					const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+					const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+					stream = await this.client.messages.create(
+						{
+							model: modelId,
+							max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
+							temperature,
+							thinking,
+							// Setting cache breakpoint for system prompt so new tasks can reuse it.
+							system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }],
+							messages: messages.map((message, index) => {
+								if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
+									return {
+										...message,
+										content:
+											typeof message.content === "string"
+												? [{ type: "text", text: message.content, cache_control: cacheControl }]
+												: message.content.map((content, contentIndex) =>
+														contentIndex === message.content.length - 1
+															? { ...content, cache_control: cacheControl }
+															: content,
+													),
+									}
+								}
+								return message
+							}),
+							stream: true,
+						},
+						{ headers: { "anthropic-beta": "prompt-caching-2024-07-31" } },
+					)
+					break
+				}
 			}
+		} else {
+			// Models without prompt caching support (or unknown custom models without the flag)
+			stream = (await this.client.messages.create({
+				model: modelId,
+				max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
+				temperature,
+				thinking,
+				system: [{ text: systemPrompt, type: "text" }],
+				messages,
+				stream: true,
+			})) as any
 		}
 
 		let inputTokens = 0
@@ -249,23 +295,45 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 	getModel() {
 		const modelId = this.options.apiModelId
 
-		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
-		let info: ModelInfo = anthropicModels[id]
-
-		// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
-		if ((id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") && this.options.anthropicBeta1MContext) {
-			// Use the tier pricing for 1M context
-			const tier = info.tiers?.[0]
-			if (tier) {
-				info = {
-					...info,
-					contextWindow: tier.contextWindow,
-					inputPrice: tier.inputPrice,
-					outputPrice: tier.outputPrice,
-					cacheWritesPrice: tier.cacheWritesPrice,
-					cacheReadsPrice: tier.cacheReadsPrice,
+		let id: string
+		let info: ModelInfo
+
+		// Check if modelId is a known model
+		if (modelId && modelId in anthropicModels) {
+			id = modelId as AnthropicModelId
+			info = anthropicModels[id as keyof typeof anthropicModels]
+
+			// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
+			if (
+				(id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") &&
+				this.options.anthropicBeta1MContext
+			) {
+				// Use the tier pricing for 1M context
+				const tier = info.tiers?.[0]
+				if (tier) {
+					info = {
+						...info,
+						contextWindow: tier.contextWindow,
+						inputPrice: tier.inputPrice,
+						outputPrice: tier.outputPrice,
+						cacheWritesPrice: tier.cacheWritesPrice,
+						cacheReadsPrice: tier.cacheReadsPrice,
+					}
 				}
 			}
+		} else if (modelId) {
+			// Custom model - use sensible defaults
+			id = modelId
+			info = {
+				maxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
+				contextWindow: 200_000, // Default Anthropic context window
+				supportsImages: true, // Assume modern capabilities
+				supportsPromptCache: true, // Most Anthropic-compatible APIs support caching
+			}
+		} else {
+			// No model specified - use default
+			id = anthropicDefaultModelId
+			info = anthropicModels[anthropicDefaultModelId]
 		}
 
 		const params = getModelParams({