Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions jsonrpc/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type (
Request = jsonrpc2.Request
// Response is a JSON-RPC response.
Response = jsonrpc2.Response
// Error is a structured error in a JSON-RPC response.
Error = jsonrpc2.WireError
)

// MakeID coerces the given Go value to an ID. The value should be the
Expand All @@ -37,3 +39,18 @@ func EncodeMessage(msg Message) ([]byte, error) {
func DecodeMessage(data []byte) (Message, error) {
return jsonrpc2.DecodeMessage(data)
}

// Standard JSON-RPC 2.0 error codes.
// See https://www.jsonrpc.org/specification#error_object
const (
// CodeParseError indicates invalid JSON was received by the server.
CodeParseError = -32700
// CodeInvalidRequest indicates the JSON sent is not a valid Request object.
CodeInvalidRequest = -32600
// CodeMethodNotFound indicates the method does not exist or is not available.
CodeMethodNotFound = -32601
// CodeInvalidParams indicates invalid method parameter(s).
CodeInvalidParams = -32602
// CodeInternalError indicates an internal JSON-RPC error.
CodeInternalError = -32603
)
189 changes: 177 additions & 12 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"iter"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
c := &Client{
impl: impl,
roots: newFeatureSet(func(r *Root) string { return r.URI }),
sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession],
sendingMethodHandler_: defaultSendingMethodHandler,
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
}
if opts != nil {
Expand Down Expand Up @@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
// {"form":{}} for backward compatibility, but we explicitly set the form
// capability.
caps.Elicitation.Form = &FormElicitationCapabilities{}
} else if slices.Contains(modes, "url") {
}
if slices.Contains(modes, "url") {
caps.Elicitation.URL = &URLElicitationCapabilities{}
}
}
Expand Down Expand Up @@ -206,6 +209,10 @@ type ClientSession struct {
// No mutex is (currently) required to guard the session state, because it is
// only set synchronously during Client.Connect.
state clientSessionState

// Pending URL elicitations waiting for completion notifications.
pendingElicitationsMu sync.Mutex
pendingElicitations map[string]chan struct{}
}

type clientSessionState struct {
Expand Down Expand Up @@ -250,6 +257,46 @@ func (cs *ClientSession) Wait() error {
return cs.conn.Wait()
}

// registerElicitationWaiter registers a waiter for an elicitation complete
// notification with the given elicitation ID. It returns two functions: an await
// function that waits for the notification or context cancellation, and a cleanup
// function that must be called to unregister the waiter. This must be called before
// triggering the elicitation to avoid a race condition where the notification
// arrives before the waiter is registered.
//
// The cleanup function must be called even if the await function is never called,
// to prevent leaking the registration.
func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) {
// Create a channel for this elicitation.
ch := make(chan struct{}, 1)

// Register the channel.
cs.pendingElicitationsMu.Lock()
if cs.pendingElicitations == nil {
cs.pendingElicitations = make(map[string]chan struct{})
}
cs.pendingElicitations[elicitationID] = ch
cs.pendingElicitationsMu.Unlock()

// Return await and cleanup functions.
await = func(ctx context.Context) error {
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err())
case <-ch:
return nil
}
}

cleanup = func() {
cs.pendingElicitationsMu.Lock()
delete(cs.pendingElicitations, elicitationID)
cs.pendingElicitationsMu.Unlock()
}

return await, cleanup
}

// startKeepalive starts the keepalive mechanism for this client session.
func (cs *ClientSession) startKeepalive(interval time.Duration) {
startKeepalive(cs, interval, &cs.keepaliveCancel)
Expand Down Expand Up @@ -304,14 +351,118 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots
func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
if c.opts.CreateMessageHandler == nil {
// TODO: wrap or annotate this error? Pick a standard code?
return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support CreateMessage")
return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"}
}
return c.opts.CreateMessageHandler(ctx, req)
}

// urlElicitationMiddleware returns middleware that automatically handles URL elicitation
// required errors by executing the elicitation handler, waiting for completion notifications,
// and retrying the operation.
//
// This middleware should be added to clients that want automatic URL elicitation handling:
//
// client := mcp.NewClient(impl, opts)
// client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
//
// TODO(rfindley): this isn't strictly necessary for the SEP, but may be
// useful. Propose exporting it.
func urlElicitationMiddleware() Middleware {
return func(next MethodHandler) MethodHandler {
return func(ctx context.Context, method string, req Request) (Result, error) {
// Call the underlying handler.
res, err := next(ctx, method, req)
if err == nil {
return res, nil
}

// Check if this is a URL elicitation required error.
var rpcErr *jsonrpc.Error
if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired {
return res, err
}

// Notifications don't support retries.
if strings.HasPrefix(method, "notifications/") {
return res, err
}

// Extract the client session.
cs, ok := req.GetSession().(*ClientSession)
if !ok {
return res, err
}

// Check if the client has an elicitation handler.
if cs.client.opts.ElicitationHandler == nil {
return res, err
}

// Parse the elicitations from the error data.
var errorData struct {
Elicitations []*ElicitParams `json:"elicitations"`
}
if rpcErr.Data != nil {
if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil {
return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err)
}
}

// Validate that all elicitations are URL mode.
for _, elicit := range errorData.Elicitations {
mode := elicit.Mode
if mode == "" {
mode = "form" // Default mode.
}
if mode != "url" {
return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode)
}
}

// Register waiters for all elicitations before executing handlers
// to avoid race condition where notification arrives before waiter is registered.
type waiter struct {
await func(context.Context) error
cleanup func()
}
waiters := make([]waiter, 0, len(errorData.Elicitations))
for _, elicitParams := range errorData.Elicitations {
await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID)
waiters = append(waiters, waiter{await: await, cleanup: cleanup})
}

// Ensure cleanup happens even if we return early.
defer func() {
for _, w := range waiters {
w.cleanup()
}
}()

// Execute the elicitation handler for each elicitation.
for _, elicitParams := range errorData.Elicitations {
elicitReq := newClientRequest(cs, elicitParams)
_, elicitErr := cs.client.elicit(ctx, elicitReq)
if elicitErr != nil {
return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr)
}
}

// Wait for all elicitations to complete.
for _, w := range waiters {
if err := w.await(ctx); err != nil {
return nil, err
}
}

// All elicitations complete, retry the original operation.
return next(ctx, method, req)
}
}
}

func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
if c.opts.ElicitationHandler == nil {
return nil, jsonrpc2.NewError(codeInvalidParams, "client does not support elicitation")
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"}
}

// Validate the elicitation parameters based on the mode.
Expand All @@ -323,11 +474,11 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
switch mode {
case "form":
if req.Params.URL != "" {
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must not be set for form elicitation")
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"}
}
schema, err := validateElicitSchema(req.Params.RequestedSchema)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, err.Error())
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()}
}
res, err := c.opts.ElicitationHandler(ctx, req)
if err != nil {
Expand All @@ -337,28 +488,28 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
if schema != nil && res.Content != nil {
resolved, err := schema.Resolve(nil)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)}
}
if err := resolved.Validate(res.Content); err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)}
}
err = resolved.ApplyDefaults(&res.Content)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)}
}
}
return res, nil
case "url":
if req.Params.RequestedSchema != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, "requestedSchema must not be set for URL elicitation")
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"}
}
if req.Params.URL == "" {
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must be set for URL elicitation")
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"}
}
// No schema validation for URL mode, just pass through to handler.
return c.opts.ElicitationHandler(ctx, req)
default:
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("unsupported elicitation mode: %q", mode))
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)}
}
}

Expand Down Expand Up @@ -723,6 +874,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
}

func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) {
// Check if there's a pending elicitation waiting for this notification.
if cs, ok := req.GetSession().(*ClientSession); ok {
cs.pendingElicitationsMu.Lock()
if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists {
select {
case ch <- struct{}{}:
default:
// Channel already signaled.
}
}
cs.pendingElicitationsMu.Unlock()
}

// Call the user's handler if provided.
if h := c.opts.ElicitationCompleteHandler; h != nil {
h(ctx, req)
}
Expand Down
55 changes: 55 additions & 0 deletions mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,61 @@ func TestClientCapabilities(t *testing.T) {
Sampling: &SamplingCapabilities{},
},
},
{
name: "With form elicitation",
configureClient: func(s *Client) {},
clientOpts: ClientOptions{
ElicitationModes: []string{"form"},
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
return nil, nil
},
},
wantCapabilities: &ClientCapabilities{
Roots: struct {
ListChanged bool "json:\"listChanged,omitempty\""
}{ListChanged: true},
Elicitation: &ElicitationCapabilities{
Form: &FormElicitationCapabilities{},
},
},
},
{
name: "With URL elicitation",
configureClient: func(s *Client) {},
clientOpts: ClientOptions{
ElicitationModes: []string{"url"},
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
return nil, nil
},
},
wantCapabilities: &ClientCapabilities{
Roots: struct {
ListChanged bool "json:\"listChanged,omitempty\""
}{ListChanged: true},
Elicitation: &ElicitationCapabilities{
URL: &URLElicitationCapabilities{},
},
},
},
{
name: "With both form and URL elicitation",
configureClient: func(s *Client) {},
clientOpts: ClientOptions{
ElicitationModes: []string{"form", "url"},
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
return nil, nil
},
},
wantCapabilities: &ClientCapabilities{
Roots: struct {
ListChanged bool "json:\"listChanged,omitempty\""
}{ListChanged: true},
Elicitation: &ElicitationCapabilities{
Form: &FormElicitationCapabilities{},
URL: &URLElicitationCapabilities{},
},
},
},
}

for _, tc := range testCases {
Expand Down
5 changes: 3 additions & 2 deletions mcp/elicitation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

// TODO: migrate other elicitation tests here.
Expand Down Expand Up @@ -74,7 +75,7 @@ func TestElicitationURLMode(t *testing.T) {
Message: "URL is missing",
},
wantErrMsg: "URL must be set for URL elicitation",
wantErrCode: codeInvalidParams,
wantErrCode: jsonrpc.CodeInvalidParams,
},
{
name: "schema not allowed",
Expand All @@ -90,7 +91,7 @@ func TestElicitationURLMode(t *testing.T) {
},
},
wantErrMsg: "requestedSchema must not be set for URL elicitation",
wantErrCode: codeInvalidParams,
wantErrCode: jsonrpc.CodeInvalidParams,
},
}
for _, tc := range testCases {
Expand Down
Loading