Skip to content

Commit 47c46e2

Browse files
committed
mcp: implement URLElicitationRequired error with automatic retry
This change implements support for the MCP URL elicitation required error (-32042), enabling servers to request out-of-band user authorization and clients to automatically handle these requests with retry logic. Server-side: - Add CodeURLElicitationRequired constant - Add URLElicitationRequiredError() constructor. Client-side: - Add urlElicitationMiddleware to intercept URLElicitationRequired errors and retry calls. - Update callElicitationCompleteHandler to signal waiting operations. - Fix capability advertisement to support both form and URL modes. Fixes #623
1 parent 0ae7081 commit 47c46e2

File tree

5 files changed

+483
-4
lines changed

5 files changed

+483
-4
lines changed

mcp/client.go

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ package mcp
77
import (
88
"context"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"iter"
1213
"slices"
14+
"strings"
1315
"sync"
1416
"sync/atomic"
1517
"time"
@@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
4547
c := &Client{
4648
impl: impl,
4749
roots: newFeatureSet(func(r *Root) string { return r.URI }),
48-
sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession],
50+
sendingMethodHandler_: defaultSendingMethodHandler,
4951
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
5052
}
5153
if opts != nil {
@@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
134136
// {"form":{}} for backward compatibility, but we explicitly set the form
135137
// capability.
136138
caps.Elicitation.Form = &FormElicitationCapabilities{}
137-
} else if slices.Contains(modes, "url") {
139+
}
140+
if slices.Contains(modes, "url") {
138141
caps.Elicitation.URL = &URLElicitationCapabilities{}
139142
}
140143
}
@@ -206,6 +209,10 @@ type ClientSession struct {
206209
// No mutex is (currently) required to guard the session state, because it is
207210
// only set synchronously during Client.Connect.
208211
state clientSessionState
212+
213+
// Pending URL elicitations waiting for completion notifications.
214+
pendingElicitationsMu sync.Mutex
215+
pendingElicitations map[string]chan struct{}
209216
}
210217

211218
type clientSessionState struct {
@@ -250,6 +257,39 @@ func (cs *ClientSession) Wait() error {
250257
return cs.conn.Wait()
251258
}
252259

260+
// registerElicitationWaiter registers a waiter for an elicitation complete notification
261+
// with the given elicitation ID. It returns a function that waits for the notification
262+
// or context cancellation. This must be called before triggering the elicitation to avoid
263+
// a race condition where the notification arrives before the waiter is registered.
264+
func (cs *ClientSession) registerElicitationWaiter(elicitationID string) func(context.Context) error {
265+
// Create a channel for this elicitation.
266+
ch := make(chan struct{}, 1)
267+
268+
// Register the channel.
269+
cs.pendingElicitationsMu.Lock()
270+
if cs.pendingElicitations == nil {
271+
cs.pendingElicitations = make(map[string]chan struct{})
272+
}
273+
cs.pendingElicitations[elicitationID] = ch
274+
cs.pendingElicitationsMu.Unlock()
275+
276+
// Return a function that waits for completion and cleans up.
277+
return func(ctx context.Context) error {
278+
defer func() {
279+
cs.pendingElicitationsMu.Lock()
280+
delete(cs.pendingElicitations, elicitationID)
281+
cs.pendingElicitationsMu.Unlock()
282+
}()
283+
284+
select {
285+
case <-ctx.Done():
286+
return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err())
287+
case <-ch:
288+
return nil
289+
}
290+
}
291+
}
292+
253293
// startKeepalive starts the keepalive mechanism for this client session.
254294
func (cs *ClientSession) startKeepalive(interval time.Duration) {
255295
startKeepalive(cs, interval, &cs.keepaliveCancel)
@@ -309,6 +349,98 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (
309349
return c.opts.CreateMessageHandler(ctx, req)
310350
}
311351

352+
// urlElicitationMiddleware returns middleware that automatically handles URL elicitation
353+
// required errors by executing the elicitation handler, waiting for completion notifications,
354+
// and retrying the operation.
355+
//
356+
// This middleware should be added to clients that want automatic URL elicitation handling:
357+
//
358+
// client := mcp.NewClient(impl, opts)
359+
// client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
360+
//
361+
// TODO(rfindley): this isn't strictly necessary for the SEP, but may be
362+
// useful. Propose exporting it it.
363+
func urlElicitationMiddleware() Middleware {
364+
return func(next MethodHandler) MethodHandler {
365+
return func(ctx context.Context, method string, req Request) (Result, error) {
366+
// Call the underlying handler.
367+
res, err := next(ctx, method, req)
368+
if err == nil {
369+
return res, nil
370+
}
371+
372+
// Check if this is a URL elicitation required error.
373+
var rpcErr *jsonrpc.Error
374+
if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired {
375+
return res, err
376+
}
377+
378+
// Notifications don't support retries.
379+
if strings.HasPrefix(method, "notifications/") {
380+
return res, err
381+
}
382+
383+
// Extract the client session.
384+
cs, ok := req.GetSession().(*ClientSession)
385+
if !ok {
386+
return res, err
387+
}
388+
389+
// Check if the client has an elicitation handler.
390+
if cs.client.opts.ElicitationHandler == nil {
391+
return res, err
392+
}
393+
394+
// Parse the elicitations from the error data.
395+
var errorData struct {
396+
Elicitations []*ElicitParams `json:"elicitations"`
397+
}
398+
if rpcErr.Data != nil {
399+
if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil {
400+
return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err)
401+
}
402+
}
403+
404+
// Validate that all elicitations are URL mode.
405+
for _, elicit := range errorData.Elicitations {
406+
mode := elicit.Mode
407+
if mode == "" {
408+
mode = "form" // Default mode.
409+
}
410+
if mode != "url" {
411+
return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode)
412+
}
413+
}
414+
415+
// Register waiters for all elicitations before executing handlers
416+
// to avoid race condition where notification arrives before waiter is registered.
417+
awaitFuncs := make([]func(context.Context) error, 0, len(errorData.Elicitations))
418+
for _, elicitParams := range errorData.Elicitations {
419+
awaitFuncs = append(awaitFuncs, cs.registerElicitationWaiter(elicitParams.ElicitationID))
420+
}
421+
422+
// Execute the elicitation handler for each elicitation.
423+
for _, elicitParams := range errorData.Elicitations {
424+
elicitReq := newClientRequest(cs, elicitParams)
425+
_, elicitErr := cs.client.elicit(ctx, elicitReq)
426+
if elicitErr != nil {
427+
return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr)
428+
}
429+
}
430+
431+
// Wait for all elicitations to complete.
432+
for _, await := range awaitFuncs {
433+
if err := await(ctx); err != nil {
434+
return nil, err
435+
}
436+
}
437+
438+
// All elicitations complete, retry the original operation.
439+
return next(ctx, method, req)
440+
}
441+
}
442+
}
443+
312444
func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
313445
if c.opts.ElicitationHandler == nil {
314446
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"}
@@ -723,6 +855,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
723855
}
724856

725857
func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) {
858+
// Check if there's a pending elicitation waiting for this notification.
859+
if cs, ok := req.GetSession().(*ClientSession); ok {
860+
cs.pendingElicitationsMu.Lock()
861+
if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists {
862+
select {
863+
case ch <- struct{}{}:
864+
default:
865+
// Channel already signaled.
866+
}
867+
}
868+
cs.pendingElicitationsMu.Unlock()
869+
}
870+
871+
// Call the user's handler if provided.
726872
if h := c.opts.ElicitationCompleteHandler; h != nil {
727873
h(ctx, req)
728874
}

mcp/client_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,61 @@ func TestClientCapabilities(t *testing.T) {
222222
Sampling: &SamplingCapabilities{},
223223
},
224224
},
225+
{
226+
name: "With form elicitation",
227+
configureClient: func(s *Client) {},
228+
clientOpts: ClientOptions{
229+
ElicitationModes: []string{"form"},
230+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
231+
return nil, nil
232+
},
233+
},
234+
wantCapabilities: &ClientCapabilities{
235+
Roots: struct {
236+
ListChanged bool "json:\"listChanged,omitempty\""
237+
}{ListChanged: true},
238+
Elicitation: &ElicitationCapabilities{
239+
Form: &FormElicitationCapabilities{},
240+
},
241+
},
242+
},
243+
{
244+
name: "With URL elicitation",
245+
configureClient: func(s *Client) {},
246+
clientOpts: ClientOptions{
247+
ElicitationModes: []string{"url"},
248+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
249+
return nil, nil
250+
},
251+
},
252+
wantCapabilities: &ClientCapabilities{
253+
Roots: struct {
254+
ListChanged bool "json:\"listChanged,omitempty\""
255+
}{ListChanged: true},
256+
Elicitation: &ElicitationCapabilities{
257+
URL: &URLElicitationCapabilities{},
258+
},
259+
},
260+
},
261+
{
262+
name: "With both form and URL elicitation",
263+
configureClient: func(s *Client) {},
264+
clientOpts: ClientOptions{
265+
ElicitationModes: []string{"form", "url"},
266+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
267+
return nil, nil
268+
},
269+
},
270+
wantCapabilities: &ClientCapabilities{
271+
Roots: struct {
272+
ListChanged bool "json:\"listChanged,omitempty\""
273+
}{ListChanged: true},
274+
Elicitation: &ElicitationCapabilities{
275+
Form: &FormElicitationCapabilities{},
276+
URL: &URLElicitationCapabilities{},
277+
},
278+
},
279+
},
225280
}
226281

227282
for _, tc := range testCases {

0 commit comments

Comments
 (0)