From 0ae70814dd4a3949689c7e9d91d9954fb79a38cd Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Sun, 23 Nov 2025 13:21:28 -0500 Subject: [PATCH 1/3] mcp: expose jsonrpc.Error type and standard error codes Expose the jsonrpc.Error type to allow access to underlying JSON-RPC error codes. Also, expose common JSON-RPC error codes in the jsonrpc package, and MCP error codes in the mcp package. + tests Fixes #452 --- jsonrpc/jsonrpc.go | 17 ++++++ mcp/client.go | 20 +++---- mcp/elicitation_test.go | 5 +- mcp/error_test.go | 117 ++++++++++++++++++++++++++++++++++++++++ mcp/mcp_test.go | 10 ++-- mcp/resource.go | 6 +-- mcp/server.go | 12 ++--- mcp/shared.go | 11 ++-- mcp/streamable_test.go | 2 +- mcp/tool_test.go | 14 ++--- mcp/transport.go | 3 +- 11 files changed, 177 insertions(+), 40 deletions(-) create mode 100644 mcp/error_test.go diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go index 1633d4e3..a9ea78fa 100644 --- a/jsonrpc/jsonrpc.go +++ b/jsonrpc/jsonrpc.go @@ -17,6 +17,8 @@ type ( Request = jsonrpc2.Request // Response is a JSON-RPC response. Response = jsonrpc2.Response + // Error is a structured error in a JSON-RPC response. + Error = jsonrpc2.WireError ) // MakeID coerces the given Go value to an ID. The value should be the @@ -37,3 +39,18 @@ func EncodeMessage(msg Message) ([]byte, error) { func DecodeMessage(data []byte) (Message, error) { return jsonrpc2.DecodeMessage(data) } + +// Standard JSON-RPC 2.0 error codes. +// See https://www.jsonrpc.org/specification#error_object +const ( + // CodeParseError indicates invalid JSON was received by the server. + CodeParseError = -32700 + // CodeInvalidRequest indicates the JSON sent is not a valid Request object. + CodeInvalidRequest = -32600 + // CodeMethodNotFound indicates the method does not exist or is not available. + CodeMethodNotFound = -32601 + // CodeInvalidParams indicates invalid method parameter(s). + CodeInvalidParams = -32602 + // CodeInternalError indicates an internal JSON-RPC error. + CodeInternalError = -32603 +) diff --git a/mcp/client.go b/mcp/client.go index cd784015..f454e320 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -304,14 +304,14 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? - return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support CreateMessage") + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} } return c.opts.CreateMessageHandler(ctx, req) } func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { if c.opts.ElicitationHandler == nil { - return nil, jsonrpc2.NewError(codeInvalidParams, "client does not support elicitation") + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"} } // Validate the elicitation parameters based on the mode. @@ -323,11 +323,11 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, switch mode { case "form": if req.Params.URL != "" { - return nil, jsonrpc2.NewError(codeInvalidParams, "URL must not be set for form elicitation") + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"} } schema, err := validateElicitSchema(req.Params.RequestedSchema) if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, err.Error()) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()} } res, err := c.opts.ElicitationHandler(ctx, req) if err != nil { @@ -337,28 +337,28 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, if schema != nil && res.Content != nil { resolved, err := schema.Resolve(nil) if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)} } if err := resolved.Validate(res.Content); err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)} } err = resolved.ApplyDefaults(&res.Content) if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)} } } return res, nil case "url": if req.Params.RequestedSchema != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, "requestedSchema must not be set for URL elicitation") + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"} } if req.Params.URL == "" { - return nil, jsonrpc2.NewError(codeInvalidParams, "URL must be set for URL elicitation") + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"} } // No schema validation for URL mode, just pass through to handler. return c.opts.ElicitationHandler(ctx, req) default: - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("unsupported elicitation mode: %q", mode)) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)} } } diff --git a/mcp/elicitation_test.go b/mcp/elicitation_test.go index de5f65d3..8da033ed 100644 --- a/mcp/elicitation_test.go +++ b/mcp/elicitation_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // TODO: migrate other elicitation tests here. @@ -74,7 +75,7 @@ func TestElicitationURLMode(t *testing.T) { Message: "URL is missing", }, wantErrMsg: "URL must be set for URL elicitation", - wantErrCode: codeInvalidParams, + wantErrCode: jsonrpc.CodeInvalidParams, }, { name: "schema not allowed", @@ -90,7 +91,7 @@ func TestElicitationURLMode(t *testing.T) { }, }, wantErrMsg: "requestedSchema must not be set for URL elicitation", - wantErrCode: codeInvalidParams, + wantErrCode: jsonrpc.CodeInvalidParams, }, } for _, tc := range testCases { diff --git a/mcp/error_test.go b/mcp/error_test.go new file mode 100644 index 00000000..64ca5a6f --- /dev/null +++ b/mcp/error_test.go @@ -0,0 +1,117 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "errors" + "testing" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// TestServerErrors validates that the server returns appropriate error codes +// for various invalid requests. +func TestServerErrors(t *testing.T) { + ctx := context.Background() + + // Set up a server with tools, prompts, and resources for testing + cs, _, cleanup := basicConnection(t, func(s *Server) { + // Add a tool with required parameters + type RequiredParams struct { + Name string `json:"name" jsonschema:"the name is required"` + } + handler := func(ctx context.Context, req *CallToolRequest, args RequiredParams) (*CallToolResult, any, error) { + return &CallToolResult{ + Content: []Content{&TextContent{Text: "success"}}, + }, nil, nil + } + AddTool(s, &Tool{Name: "validate", Description: "validates params"}, handler) + + // Add a prompt + s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) + + // Add a resource that returns ResourceNotFoundError + s.AddResource( + &Resource{URI: "file:///test.txt", Name: "test", MIMEType: "text/plain"}, + func(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + return nil, ResourceNotFoundError(req.Params.URI) + }, + ) + }) + defer cleanup() + + testCases := []struct { + name string + executeCall func() error + expectedCode int64 + }{ + { + name: "missing required param", + executeCall: func() error { + _, err := cs.CallTool(ctx, &CallToolParams{ + Name: "validate", + Arguments: map[string]any{}, // Missing required "name" field + }) + return err + }, + expectedCode: jsonrpc.CodeInvalidParams, + }, + { + name: "unknown tool", + executeCall: func() error { + _, err := cs.CallTool(ctx, &CallToolParams{ + Name: "nonexistent_tool", + Arguments: map[string]any{}, + }) + return err + }, + expectedCode: jsonrpc.CodeInvalidParams, + }, + { + name: "unknown prompt", + executeCall: func() error { + _, err := cs.GetPrompt(ctx, &GetPromptParams{ + Name: "nonexistent_prompt", + Arguments: map[string]string{}, + }) + return err + }, + expectedCode: jsonrpc.CodeInvalidParams, + }, + { + name: "resource not found", + executeCall: func() error { + _, err := cs.ReadResource(ctx, &ReadResourceParams{ + URI: "file:///test.txt", + }) + return err + }, + expectedCode: CodeResourceNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.executeCall() + if err == nil { + t.Fatal("expected error, got nil") + } + + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) { + t.Fatalf("expected jsonrpc.Error, got %T: %v", err, err) + } + + if rpcErr.Code != tc.expectedCode { + t.Errorf("expected error code %d, got %d", tc.expectedCode, rpcErr.Code) + } + + if rpcErr.Message == "" { + t.Error("expected non-empty error message") + } + }) + } +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index c2c949e8..3edfefd7 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -24,7 +24,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type hiParams struct { @@ -316,7 +316,7 @@ func TestEndToEnd(t *testing.T) { } { rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) if err != nil { - if code := errorCode(err); code == codeResourceNotFound { + if code := errorCode(err); code == CodeResourceNotFound { if tt.mimeType != "" { t.Errorf("%s: not found but expected it to be", tt.uri) } @@ -576,7 +576,7 @@ func errorCode(err error) int64 { if err == nil { return 0 } - var werr *jsonrpc2.WireError + var werr *jsonrpc.Error if errors.As(err, &werr) { return werr.Code } @@ -1367,8 +1367,8 @@ func TestElicitationSchemaValidation(t *testing.T) { t.Errorf("expected error for invalid schema %q, got nil", tc.name) return } - if code := errorCode(err); code != codeInvalidParams { - t.Errorf("got error code %d, want %d (CodeInvalidParams)", code, codeInvalidParams) + if code := errorCode(err); code != jsonrpc.CodeInvalidParams { + t.Errorf("got error code %d, want %d (CodeInvalidParams)", code, jsonrpc.CodeInvalidParams) } if !strings.Contains(err.Error(), tc.expectedError) { t.Errorf("error message %q does not contain expected text %q", err.Error(), tc.expectedError) diff --git a/mcp/resource.go b/mcp/resource.go index 8746edae..dc657f5d 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -15,8 +15,8 @@ import ( "path/filepath" "strings" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/yosida95/uritemplate/v3" ) @@ -40,8 +40,8 @@ type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceR // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. func ResourceNotFoundError(uri string) error { - return &jsonrpc2.WireError{ - Code: codeResourceNotFound, + return &jsonrpc.Error{ + Code: CodeResourceNotFound, Message: "Resource not found", Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), } diff --git a/mcp/server.go b/mcp/server.go index 254c2d5e..7e85f3c6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -293,12 +293,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // Call typed handler. res, out, err := h(ctx, req, in) // Handle server errors appropriately: - // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a structured error (like jsonrpc.Error), return it directly // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true // - This allows tools to distinguish between protocol errors and tool execution errors if err != nil { // Check if this is already a structured JSON-RPC error - if wireErr, ok := err.(*jsonrpc2.WireError); ok { + if wireErr, ok := err.(*jsonrpc.Error); ok { return nil, wireErr } // For regular errors, embed them in the tool result as per MCP spec @@ -542,8 +542,8 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm s.mu.Unlock() if !ok { // Return a proper JSON-RPC error with the correct error code - return nil, &jsonrpc2.WireError{ - Code: codeInvalidParams, + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } @@ -569,8 +569,8 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, &jsonrpc2.WireError{ - Code: codeInvalidParams, + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unknown tool %q", req.Params.Name), } } diff --git a/mcp/shared.go b/mcp/shared.go index 3fac40b2..0f2f64dd 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -331,16 +331,19 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont } } -// Error codes +// MCP-specific error codes. +const ( + // CodeResourceNotFound indicates that a requested resource could not be found. + CodeResourceNotFound = -32002 +) + +// Internal error codes const ( - codeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. // // TODO(rfindley): this code is wrong, and we should fix it to be // consistent with other SDKs. codeUnsupportedMethod = -31001 - // The error code for invalid parameters - codeInvalidParams = -32602 ) // notifySessions calls Notify on all the sessions. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e9c0cbda..8c66bf6e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -962,7 +962,7 @@ func TestStreamableServerTransport(t *testing.T) { method: "POST", messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ + wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc.Error{ Message: `method "tools/call" is invalid during session initialization`, })}, }, diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 80661d31..dfd859be 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -14,7 +14,7 @@ import ( "testing" "github.com/google/jsonschema-go/jsonschema" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestApplySchema(t *testing.T) { @@ -65,8 +65,8 @@ func TestToolErrorHandling(t *testing.T) { // Create a tool that returns a structured error structuredErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { - return nil, nil, &jsonrpc2.WireError{ - Code: codeInvalidParams, + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, Message: "internal server error", } } @@ -106,13 +106,13 @@ func TestToolErrorHandling(t *testing.T) { t.Fatal("expected error, got nil") } - var wireErr *jsonrpc2.WireError + var wireErr *jsonrpc.Error if !errors.As(err, &wireErr) { - t.Fatalf("expected WireError, got %[1]T: %[1]v", err) + t.Fatalf("expected jsonrpc.Error, got %[1]T: %[1]v", err) } - if wireErr.Code != codeInvalidParams { - t.Errorf("expected error code %d, got %d", codeInvalidParams, wireErr.Code) + if wireErr.Code != jsonrpc.CodeInvalidParams { + t.Errorf("expected error code %d, got %d", jsonrpc.CodeInvalidParams, wireErr.Code) } }) diff --git a/mcp/transport.go b/mcp/transport.go index cacd65fd..d6359f5b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -204,8 +204,7 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result a // call executes and awaits a jsonrpc2 call on the given connection, // translating errors into the mcp domain. func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { - // TODO: the "%w"s in this function effectively make jsonrpc2.WireError part of the API. - // Consider alternatives. + // The "%w"s in this function expose jsonrpc.Error as part of the API. call := conn.Call(ctx, method, params) err := call.Await(ctx, result) switch { From d7d18929517e0752eb0b2fd097086512a59e8610 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 24 Nov 2025 21:51:11 -0500 Subject: [PATCH 2/3] mcp: implement URLElicitationRequired error with automatic retry This change implements support for the MCP URL elicitation required error (-32042), enabling servers to request out-of-band user authorization and clients to automatically handle these requests with retry logic. Server-side: - Add CodeURLElicitationRequired constant - Add URLElicitationRequiredError() constructor. Client-side: - Add urlElicitationMiddleware to intercept URLElicitationRequired errors and retry calls. - Update callElicitationCompleteHandler to signal waiting operations. - Fix capability advertisement to support both form and URL modes. Fixes #623 --- mcp/client.go | 169 +++++++++++++++++++++++++++++++- mcp/client_test.go | 55 +++++++++++ mcp/error_test.go | 236 +++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 2 +- mcp/shared.go | 35 ++++++- 5 files changed, 493 insertions(+), 4 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index f454e320..2f2174c2 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -7,9 +7,11 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "iter" "slices" + "strings" "sync" "sync/atomic" "time" @@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { c := &Client{ impl: impl, roots: newFeatureSet(func(r *Root) string { return r.URI }), - sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession], + sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } if opts != nil { @@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities { // {"form":{}} for backward compatibility, but we explicitly set the form // capability. caps.Elicitation.Form = &FormElicitationCapabilities{} - } else if slices.Contains(modes, "url") { + } + if slices.Contains(modes, "url") { caps.Elicitation.URL = &URLElicitationCapabilities{} } } @@ -206,6 +209,10 @@ type ClientSession struct { // No mutex is (currently) required to guard the session state, because it is // only set synchronously during Client.Connect. state clientSessionState + + // Pending URL elicitations waiting for completion notifications. + pendingElicitationsMu sync.Mutex + pendingElicitations map[string]chan struct{} } type clientSessionState struct { @@ -250,6 +257,46 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +// registerElicitationWaiter registers a waiter for an elicitation complete +// notification with the given elicitation ID. It returns two functions: an await +// function that waits for the notification or context cancellation, and a cleanup +// function that must be called to unregister the waiter. This must be called before +// triggering the elicitation to avoid a race condition where the notification +// arrives before the waiter is registered. +// +// The cleanup function must be called even if the await function is never called, +// to prevent leaking the registration. +func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) { + // Create a channel for this elicitation. + ch := make(chan struct{}, 1) + + // Register the channel. + cs.pendingElicitationsMu.Lock() + if cs.pendingElicitations == nil { + cs.pendingElicitations = make(map[string]chan struct{}) + } + cs.pendingElicitations[elicitationID] = ch + cs.pendingElicitationsMu.Unlock() + + // Return await and cleanup functions. + await = func(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err()) + case <-ch: + return nil + } + } + + cleanup = func() { + cs.pendingElicitationsMu.Lock() + delete(cs.pendingElicitations, elicitationID) + cs.pendingElicitationsMu.Unlock() + } + + return await, cleanup +} + // startKeepalive starts the keepalive mechanism for this client session. func (cs *ClientSession) startKeepalive(interval time.Duration) { startKeepalive(cs, interval, &cs.keepaliveCancel) @@ -309,6 +356,110 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) ( return c.opts.CreateMessageHandler(ctx, req) } +// urlElicitationMiddleware returns middleware that automatically handles URL elicitation +// required errors by executing the elicitation handler, waiting for completion notifications, +// and retrying the operation. +// +// This middleware should be added to clients that want automatic URL elicitation handling: +// +// client := mcp.NewClient(impl, opts) +// client.AddSendingMiddleware(mcp.urlElicitationMiddleware()) +// +// TODO(rfindley): this isn't strictly necessary for the SEP, but may be +// useful. Propose exporting it it. +func urlElicitationMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + // Call the underlying handler. + res, err := next(ctx, method, req) + if err == nil { + return res, nil + } + + // Check if this is a URL elicitation required error. + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired { + return res, err + } + + // Notifications don't support retries. + if strings.HasPrefix(method, "notifications/") { + return res, err + } + + // Extract the client session. + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, err + } + + // Check if the client has an elicitation handler. + if cs.client.opts.ElicitationHandler == nil { + return res, err + } + + // Parse the elicitations from the error data. + var errorData struct { + Elicitations []*ElicitParams `json:"elicitations"` + } + if rpcErr.Data != nil { + if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil { + return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err) + } + } + + // Validate that all elicitations are URL mode. + for _, elicit := range errorData.Elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // Default mode. + } + if mode != "url" { + return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode) + } + } + + // Register waiters for all elicitations before executing handlers + // to avoid race condition where notification arrives before waiter is registered. + type waiter struct { + await func(context.Context) error + cleanup func() + } + waiters := make([]waiter, 0, len(errorData.Elicitations)) + for _, elicitParams := range errorData.Elicitations { + await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID) + waiters = append(waiters, waiter{await: await, cleanup: cleanup}) + } + + // Ensure cleanup happens even if we return early. + defer func() { + for _, w := range waiters { + w.cleanup() + } + }() + + // Execute the elicitation handler for each elicitation. + for _, elicitParams := range errorData.Elicitations { + elicitReq := newClientRequest(cs, elicitParams) + _, elicitErr := cs.client.elicit(ctx, elicitReq) + if elicitErr != nil { + return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr) + } + } + + // Wait for all elicitations to complete. + for _, w := range waiters { + if err := w.await(ctx); err != nil { + return nil, err + } + } + + // All elicitations complete, retry the original operation. + return next(ctx, method, req) + } + } +} + func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { if c.opts.ElicitationHandler == nil { return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"} @@ -723,6 +874,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa } func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) { + // Check if there's a pending elicitation waiting for this notification. + if cs, ok := req.GetSession().(*ClientSession); ok { + cs.pendingElicitationsMu.Lock() + if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists { + select { + case ch <- struct{}{}: + default: + // Channel already signaled. + } + } + cs.pendingElicitationsMu.Unlock() + } + + // Call the user's handler if provided. if h := c.opts.ElicitationCompleteHandler; h != nil { h(ctx, req) } diff --git a/mcp/client_test.go b/mcp/client_test.go index eaeedc81..a7fb68dc 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -222,6 +222,61 @@ func TestClientCapabilities(t *testing.T) { Sampling: &SamplingCapabilities{}, }, }, + { + name: "With form elicitation", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + ElicitationModes: []string{"form"}, + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Elicitation: &ElicitationCapabilities{ + Form: &FormElicitationCapabilities{}, + }, + }, + }, + { + name: "With URL elicitation", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + ElicitationModes: []string{"url"}, + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Elicitation: &ElicitationCapabilities{ + URL: &URLElicitationCapabilities{}, + }, + }, + }, + { + name: "With both form and URL elicitation", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + ElicitationModes: []string{"form", "url"}, + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Elicitation: &ElicitationCapabilities{ + Form: &FormElicitationCapabilities{}, + URL: &URLElicitationCapabilities{}, + }, + }, + }, } for _, tc := range testCases { diff --git a/mcp/error_test.go b/mcp/error_test.go index 64ca5a6f..3b2a4a62 100644 --- a/mcp/error_test.go +++ b/mcp/error_test.go @@ -6,7 +6,9 @@ package mcp import ( "context" + "encoding/json" "errors" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -115,3 +117,237 @@ func TestServerErrors(t *testing.T) { }) } } + +// TestURLElicitationRequired validates that URL elicitation required errors +// are properly created and handled by the client. +func TestURLElicitationRequired(t *testing.T) { + ctx := context.Background() + + t.Run("error creation", func(t *testing.T) { + elicitations := []*ElicitParams{ + { + Mode: "url", + Message: "Please authorize", + URL: "https://example.com/auth", + ElicitationID: "auth-123", + }, + } + + err := URLElicitationRequiredError(elicitations) + + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) { + t.Fatalf("expected jsonrpc.Error, got %T", err) + } + + if rpcErr.Code != CodeURLElicitationRequired { + t.Errorf("expected error code %d, got %d", CodeURLElicitationRequired, rpcErr.Code) + } + + if rpcErr.Message != "URL elicitation required" { + t.Errorf("expected message 'URL elicitation required', got %q", rpcErr.Message) + } + + if rpcErr.Data == nil { + t.Fatal("expected error data, got nil") + } + + // Verify the elicitations can be unmarshaled from the error data + var errorData struct { + Elicitations []*ElicitParams `json:"elicitations"` + } + if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil { + t.Fatalf("failed to unmarshal error data: %v", err) + } + + if len(errorData.Elicitations) != 1 { + t.Fatalf("expected 1 elicitation, got %d", len(errorData.Elicitations)) + } + + if errorData.Elicitations[0].URL != "https://example.com/auth" { + t.Errorf("expected URL 'https://example.com/auth', got %q", errorData.Elicitations[0].URL) + } + }) + + t.Run("error creation with non-URL mode panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic when creating URLElicitationRequiredError with non-URL mode") + } + }() + + // This should panic because mode is "form" + URLElicitationRequiredError([]*ElicitParams{ + { + Mode: "form", + Message: "This should panic", + ElicitationID: "bad-123", + }, + }) + }) + + t.Run("error creation with empty mode panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic when creating URLElicitationRequiredError with empty mode (defaults to form)") + } + }() + + // This should panic because empty mode defaults to "form" + URLElicitationRequiredError([]*ElicitParams{ + { + Message: "This should panic", + ElicitationID: "bad-123", + }, + }) + }) + + t.Run("client middleware", func(t *testing.T) { + // Declare ss outside so it can be captured in handlers. + var ss *ServerSession + + elicitCalled := false + elicitURL := "" + elicitID := "form-123" + + // Create client with elicitation handler and middleware. + client := NewClient(testImpl, &ClientOptions{ + ElicitationModes: []string{"url"}, + ElicitationHandler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + elicitCalled = true + elicitURL = req.Params.URL + + // Simulate the server sending elicitation complete notification. + // In a real scenario, this would happen out-of-band after the user + // completes the form submission. + go func() { + err := handleNotify(ctx, notificationElicitationComplete, + newServerRequest(ss, &ElicitationCompleteParams{ + ElicitationID: elicitID, + })) + if err != nil { + t.Errorf("failed to send elicitation complete notification: %v", err) + } + }() + + return &ElicitResult{Action: "accept"}, nil + }, + }) + // Add URL elicitation middleware for automatic retry. + client.AddSendingMiddleware(urlElicitationMiddleware()) + + callCount := 0 + + cs, serverSession, cleanup := basicClientServerConnection(t, + client, + nil, + func(s *Server) { + // Tool that requires form submission on first call, succeeds on second. + handler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { + callCount++ + if callCount == 1 { + // First call: require elicitation. + return nil, nil, URLElicitationRequiredError([]*ElicitParams{ + { + Mode: "url", + Message: "Please complete the form", + URL: "https://example.com/form", + ElicitationID: elicitID, + }, + }) + } + // Second call (after retry): return success. + return &CallToolResult{ + Content: []Content{&TextContent{Text: "form submitted"}}, + }, nil, nil + } + AddTool(s, &Tool{Name: "submit_form", Description: "requires form submission"}, handler) + + // Tool that returns invalid elicitation mode (form instead of URL). + badHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { + // Manually construct an error with form mode (bypassing validation). + data, _ := json.Marshal(map[string]any{ + "elicitations": []*ElicitParams{ + { + Mode: "form", + Message: "Invalid mode", + ElicitationID: "bad-form", + }, + }, + }) + return nil, nil, &jsonrpc.Error{ + Code: CodeURLElicitationRequired, + Message: "URL elicitation required", + Data: json.RawMessage(data), + } + } + AddTool(s, &Tool{Name: "bad_tool", Description: "returns invalid elicitation"}, badHandler) + }, + ) + ss = serverSession + defer cleanup() + + t.Run("auto-retry after elicitation", func(t *testing.T) { + // Reset state for this subtest. + elicitCalled = false + elicitURL = "" + callCount = 0 + + // Call the tool that requires URL elicitation. + result, err := cs.CallTool(ctx, &CallToolParams{ + Name: "submit_form", + Arguments: map[string]any{}, + }) + + // After automatic retry, the operation should succeed. + if err != nil { + t.Fatalf("expected success after retry, got error: %v", err) + } + + // Verify the elicitation handler was called. + if !elicitCalled { + t.Error("expected elicitation handler to be called") + } + + if elicitURL != "https://example.com/form" { + t.Errorf("expected elicit URL 'https://example.com/form', got %q", elicitURL) + } + + // Verify the tool was called twice (first attempt + retry). + if callCount != 2 { + t.Errorf("expected tool to be called 2 times, got %d", callCount) + } + + // Verify we got the successful result. + if len(result.Content) != 1 { + t.Fatalf("expected 1 content item, got %d", len(result.Content)) + } + + textContent, ok := result.Content[0].(*TextContent) + if !ok { + t.Fatalf("expected TextContent, got %T", result.Content[0]) + } + + if textContent.Text != "form submitted" { + t.Errorf("expected text 'form submitted', got %q", textContent.Text) + } + }) + + t.Run("reject non-URL mode", func(t *testing.T) { + // Call the tool that returns invalid elicitation mode. + _, err := cs.CallTool(ctx, &CallToolParams{ + Name: "bad_tool", + Arguments: map[string]any{}, + }) + + // Should get an error about invalid mode. + if err == nil { + t.Fatal("expected error for non-URL mode elicitation, got nil") + } + + if !strings.Contains(err.Error(), "URL mode") { + t.Errorf("expected error message to mention URL mode, got: %v", err) + } + }) + }) +} diff --git a/mcp/server.go b/mcp/server.go index 7e85f3c6..d4317222 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -146,7 +146,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), - sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], + sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), } diff --git a/mcp/shared.go b/mcp/shared.go index 0f2f64dd..e5901442 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -88,7 +88,7 @@ func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { } } -func defaultSendingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { +func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) { info, ok := req.GetSession().sendingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. @@ -335,8 +335,41 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont const ( // CodeResourceNotFound indicates that a requested resource could not be found. CodeResourceNotFound = -32002 + // CodeURLElicitationRequired indicates that the server requires URL elicitation + // before processing the request. The client should execute the elicitation handler + // with the elicitations provided in the error data. + CodeURLElicitationRequired = -32042 ) +// URLElicitationRequiredError returns an error indicating that URL elicitation is required +// before the request can be processed. The elicitations parameter should contain the +// elicitation requests that must be completed. +func URLElicitationRequiredError(elicitations []*ElicitParams) error { + // Validate that all elicitations are URL mode + for _, elicit := range elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // default mode + } + if mode != "url" { + panic(fmt.Sprintf("URLElicitationRequiredError requires all elicitations to be URL mode, got %q", mode)) + } + } + + data, err := json.Marshal(map[string]any{ + "elicitations": elicitations, + }) + if err != nil { + // This should never happen with valid ElicitParams + panic(fmt.Sprintf("failed to marshal elicitations: %v", err)) + } + return &jsonrpc.Error{ + Code: CodeURLElicitationRequired, + Message: "URL elicitation required", + Data: json.RawMessage(data), + } +} + // Internal error codes const ( // The error code if the method exists and was called properly, but the peer does not support it. From cf8bbd59f9e65549d4dcbdbf770aa7c409ed84a9 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 26 Nov 2025 14:28:44 -0500 Subject: [PATCH 3/3] address review comments; improve test error messages --- mcp/client.go | 2 +- mcp/error_test.go | 52 +++++++++++++++++++++++------------------------ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 2f2174c2..a3d36b2e 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -366,7 +366,7 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) ( // client.AddSendingMiddleware(mcp.urlElicitationMiddleware()) // // TODO(rfindley): this isn't strictly necessary for the SEP, but may be -// useful. Propose exporting it it. +// useful. Propose exporting it. func urlElicitationMiddleware() Middleware { return func(next MethodHandler) MethodHandler { return func(ctx context.Context, method string, req Request) (Result, error) { diff --git a/mcp/error_test.go b/mcp/error_test.go index 3b2a4a62..694ef104 100644 --- a/mcp/error_test.go +++ b/mcp/error_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "strings" + "sync/atomic" "testing" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -99,20 +100,20 @@ func TestServerErrors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { err := tc.executeCall() if err == nil { - t.Fatal("expected error, got nil") + t.Fatal("got nil error, want non-nil") } var rpcErr *jsonrpc.Error if !errors.As(err, &rpcErr) { - t.Fatalf("expected jsonrpc.Error, got %T: %v", err, err) + t.Fatalf("got error type %T, want jsonrpc.Error: %v", err, err) } if rpcErr.Code != tc.expectedCode { - t.Errorf("expected error code %d, got %d", tc.expectedCode, rpcErr.Code) + t.Errorf("got error code %d, want %d", rpcErr.Code, tc.expectedCode) } if rpcErr.Message == "" { - t.Error("expected non-empty error message") + t.Error("got empty error message, want non-empty") } }) } @@ -137,19 +138,19 @@ func TestURLElicitationRequired(t *testing.T) { var rpcErr *jsonrpc.Error if !errors.As(err, &rpcErr) { - t.Fatalf("expected jsonrpc.Error, got %T", err) + t.Fatalf("got error type %T, want jsonrpc.Error", err) } if rpcErr.Code != CodeURLElicitationRequired { - t.Errorf("expected error code %d, got %d", CodeURLElicitationRequired, rpcErr.Code) + t.Errorf("got error code %d, want %d", rpcErr.Code, CodeURLElicitationRequired) } if rpcErr.Message != "URL elicitation required" { - t.Errorf("expected message 'URL elicitation required', got %q", rpcErr.Message) + t.Errorf("got message %q, want 'URL elicitation required'", rpcErr.Message) } if rpcErr.Data == nil { - t.Fatal("expected error data, got nil") + t.Fatal("got nil error data, want non-nil") } // Verify the elicitations can be unmarshaled from the error data @@ -161,18 +162,18 @@ func TestURLElicitationRequired(t *testing.T) { } if len(errorData.Elicitations) != 1 { - t.Fatalf("expected 1 elicitation, got %d", len(errorData.Elicitations)) + t.Fatalf("got %d elicitations, want 1", len(errorData.Elicitations)) } if errorData.Elicitations[0].URL != "https://example.com/auth" { - t.Errorf("expected URL 'https://example.com/auth', got %q", errorData.Elicitations[0].URL) + t.Errorf("got URL %q, want 'https://example.com/auth'", errorData.Elicitations[0].URL) } }) t.Run("error creation with non-URL mode panics", func(t *testing.T) { defer func() { if r := recover(); r == nil { - t.Error("expected panic when creating URLElicitationRequiredError with non-URL mode") + t.Error("got no panic when creating URLElicitationRequiredError with non-URL mode, want panic") } }() @@ -189,7 +190,7 @@ func TestURLElicitationRequired(t *testing.T) { t.Run("error creation with empty mode panics", func(t *testing.T) { defer func() { if r := recover(); r == nil { - t.Error("expected panic when creating URLElicitationRequiredError with empty mode (defaults to form)") + t.Error("got no panic when creating URLElicitationRequiredError with empty mode (defaults to form), want panic") } }() @@ -236,7 +237,7 @@ func TestURLElicitationRequired(t *testing.T) { // Add URL elicitation middleware for automatic retry. client.AddSendingMiddleware(urlElicitationMiddleware()) - callCount := 0 + var callCount atomic.Int32 cs, serverSession, cleanup := basicClientServerConnection(t, client, @@ -244,8 +245,7 @@ func TestURLElicitationRequired(t *testing.T) { func(s *Server) { // Tool that requires form submission on first call, succeeds on second. handler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { - callCount++ - if callCount == 1 { + if callCount.Add(1) == 1 { // First call: require elicitation. return nil, nil, URLElicitationRequiredError([]*ElicitParams{ { @@ -291,7 +291,7 @@ func TestURLElicitationRequired(t *testing.T) { // Reset state for this subtest. elicitCalled = false elicitURL = "" - callCount = 0 + callCount.Store(0) // Call the tool that requires URL elicitation. result, err := cs.CallTool(ctx, &CallToolParams{ @@ -301,35 +301,35 @@ func TestURLElicitationRequired(t *testing.T) { // After automatic retry, the operation should succeed. if err != nil { - t.Fatalf("expected success after retry, got error: %v", err) + t.Fatalf("CallTool failed: %v", err) } // Verify the elicitation handler was called. if !elicitCalled { - t.Error("expected elicitation handler to be called") + t.Error("elicitation handler not called") } if elicitURL != "https://example.com/form" { - t.Errorf("expected elicit URL 'https://example.com/form', got %q", elicitURL) + t.Errorf("got elicit URL %q, want 'https://example.com/form'", elicitURL) } // Verify the tool was called twice (first attempt + retry). - if callCount != 2 { - t.Errorf("expected tool to be called 2 times, got %d", callCount) + if got, want := callCount.Load(), int32(2); got != want { + t.Errorf("CallTool(): with retry, got %d tool calls, want %d", got, want) } // Verify we got the successful result. if len(result.Content) != 1 { - t.Fatalf("expected 1 content item, got %d", len(result.Content)) + t.Fatalf("CallTool(): got %d content items, want 1", len(result.Content)) } textContent, ok := result.Content[0].(*TextContent) if !ok { - t.Fatalf("expected TextContent, got %T", result.Content[0]) + t.Fatalf("CallTool(): got content type %T, want TextContent", result.Content[0]) } if textContent.Text != "form submitted" { - t.Errorf("expected text 'form submitted', got %q", textContent.Text) + t.Errorf("CallTool(): got text %q, want 'form submitted'", textContent.Text) } }) @@ -342,11 +342,11 @@ func TestURLElicitationRequired(t *testing.T) { // Should get an error about invalid mode. if err == nil { - t.Fatal("expected error for non-URL mode elicitation, got nil") + t.Fatal("got nil error for non-URL mode elicitation, want error") } if !strings.Contains(err.Error(), "URL mode") { - t.Errorf("expected error message to mention URL mode, got: %v", err) + t.Errorf("got error %v, want mention of URL mode", err) } }) })