Skip to content

Commit 0a31b05

Browse files
committed
mcp: implement URLElicitationRequired error with automatic retry
This change implements support for the MCP URL elicitation required error (-32042), enabling servers to request out-of-band user authorization and clients to automatically handle these requests with retry logic. Server-side: - Add CodeURLElicitationRequired constant - Add URLElicitationRequiredError() constructor. Client-side: - Add urlElicitationMiddleware to intercept URLElicitationRequired errors and retry calls. - Update callElicitationCompleteHandler to signal waiting operations. - Fix capability advertisement to support both form and URL modes. Fixes #623
1 parent 0ae7081 commit 0a31b05

File tree

5 files changed

+492
-4
lines changed

5 files changed

+492
-4
lines changed

mcp/client.go

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ package mcp
77
import (
88
"context"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"iter"
1213
"slices"
14+
"strings"
1315
"sync"
1416
"sync/atomic"
1517
"time"
@@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
4547
c := &Client{
4648
impl: impl,
4749
roots: newFeatureSet(func(r *Root) string { return r.URI }),
48-
sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession],
50+
sendingMethodHandler_: defaultSendingMethodHandler,
4951
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
5052
}
5153
if opts != nil {
@@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
134136
// {"form":{}} for backward compatibility, but we explicitly set the form
135137
// capability.
136138
caps.Elicitation.Form = &FormElicitationCapabilities{}
137-
} else if slices.Contains(modes, "url") {
139+
}
140+
if slices.Contains(modes, "url") {
138141
caps.Elicitation.URL = &URLElicitationCapabilities{}
139142
}
140143
}
@@ -206,6 +209,10 @@ type ClientSession struct {
206209
// No mutex is (currently) required to guard the session state, because it is
207210
// only set synchronously during Client.Connect.
208211
state clientSessionState
212+
213+
// Pending URL elicitations waiting for completion notifications.
214+
pendingElicitationsMu sync.Mutex
215+
pendingElicitations map[string]chan struct{}
209216
}
210217

211218
type clientSessionState struct {
@@ -250,6 +257,46 @@ func (cs *ClientSession) Wait() error {
250257
return cs.conn.Wait()
251258
}
252259

260+
// registerElicitationWaiter registers a waiter for an elicitation complete
261+
// notification with the given elicitation ID. It returns two functions: an await
262+
// function that waits for the notification or context cancellation, and a cleanup
263+
// function that must be called to unregister the waiter. This must be called before
264+
// triggering the elicitation to avoid a race condition where the notification
265+
// arrives before the waiter is registered.
266+
//
267+
// The cleanup function must be called even if the await function is never called,
268+
// to prevent leaking the registration.
269+
func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) {
270+
// Create a channel for this elicitation.
271+
ch := make(chan struct{}, 1)
272+
273+
// Register the channel.
274+
cs.pendingElicitationsMu.Lock()
275+
if cs.pendingElicitations == nil {
276+
cs.pendingElicitations = make(map[string]chan struct{})
277+
}
278+
cs.pendingElicitations[elicitationID] = ch
279+
cs.pendingElicitationsMu.Unlock()
280+
281+
// Return await and cleanup functions.
282+
await = func(ctx context.Context) error {
283+
select {
284+
case <-ctx.Done():
285+
return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err())
286+
case <-ch:
287+
return nil
288+
}
289+
}
290+
291+
cleanup = func() {
292+
cs.pendingElicitationsMu.Lock()
293+
delete(cs.pendingElicitations, elicitationID)
294+
cs.pendingElicitationsMu.Unlock()
295+
}
296+
297+
return await, cleanup
298+
}
299+
253300
// startKeepalive starts the keepalive mechanism for this client session.
254301
func (cs *ClientSession) startKeepalive(interval time.Duration) {
255302
startKeepalive(cs, interval, &cs.keepaliveCancel)
@@ -309,6 +356,110 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (
309356
return c.opts.CreateMessageHandler(ctx, req)
310357
}
311358

359+
// urlElicitationMiddleware returns middleware that automatically handles URL elicitation
360+
// required errors by executing the elicitation handler, waiting for completion notifications,
361+
// and retrying the operation.
362+
//
363+
// This middleware should be added to clients that want automatic URL elicitation handling:
364+
//
365+
// client := mcp.NewClient(impl, opts)
366+
// client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
367+
//
368+
// TODO(rfindley): this isn't strictly necessary for the SEP, but may be
369+
// useful. Propose exporting it it.
370+
func urlElicitationMiddleware() Middleware {
371+
return func(next MethodHandler) MethodHandler {
372+
return func(ctx context.Context, method string, req Request) (Result, error) {
373+
// Call the underlying handler.
374+
res, err := next(ctx, method, req)
375+
if err == nil {
376+
return res, nil
377+
}
378+
379+
// Check if this is a URL elicitation required error.
380+
var rpcErr *jsonrpc.Error
381+
if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired {
382+
return res, err
383+
}
384+
385+
// Notifications don't support retries.
386+
if strings.HasPrefix(method, "notifications/") {
387+
return res, err
388+
}
389+
390+
// Extract the client session.
391+
cs, ok := req.GetSession().(*ClientSession)
392+
if !ok {
393+
return res, err
394+
}
395+
396+
// Check if the client has an elicitation handler.
397+
if cs.client.opts.ElicitationHandler == nil {
398+
return res, err
399+
}
400+
401+
// Parse the elicitations from the error data.
402+
var errorData struct {
403+
Elicitations []*ElicitParams `json:"elicitations"`
404+
}
405+
if rpcErr.Data != nil {
406+
if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil {
407+
return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err)
408+
}
409+
}
410+
411+
// Validate that all elicitations are URL mode.
412+
for _, elicit := range errorData.Elicitations {
413+
mode := elicit.Mode
414+
if mode == "" {
415+
mode = "form" // Default mode.
416+
}
417+
if mode != "url" {
418+
return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode)
419+
}
420+
}
421+
422+
// Register waiters for all elicitations before executing handlers
423+
// to avoid race condition where notification arrives before waiter is registered.
424+
type waiter struct {
425+
await func(context.Context) error
426+
cleanup func()
427+
}
428+
waiters := make([]waiter, 0, len(errorData.Elicitations))
429+
for _, elicitParams := range errorData.Elicitations {
430+
await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID)
431+
waiters = append(waiters, waiter{await: await, cleanup: cleanup})
432+
}
433+
434+
// Ensure cleanup happens even if we return early.
435+
defer func() {
436+
for _, w := range waiters {
437+
w.cleanup()
438+
}
439+
}()
440+
441+
// Execute the elicitation handler for each elicitation.
442+
for _, elicitParams := range errorData.Elicitations {
443+
elicitReq := newClientRequest(cs, elicitParams)
444+
_, elicitErr := cs.client.elicit(ctx, elicitReq)
445+
if elicitErr != nil {
446+
return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr)
447+
}
448+
}
449+
450+
// Wait for all elicitations to complete.
451+
for _, w := range waiters {
452+
if err := w.await(ctx); err != nil {
453+
return nil, err
454+
}
455+
}
456+
457+
// All elicitations complete, retry the original operation.
458+
return next(ctx, method, req)
459+
}
460+
}
461+
}
462+
312463
func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
313464
if c.opts.ElicitationHandler == nil {
314465
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"}
@@ -723,6 +874,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
723874
}
724875

725876
func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) {
877+
// Check if there's a pending elicitation waiting for this notification.
878+
if cs, ok := req.GetSession().(*ClientSession); ok {
879+
cs.pendingElicitationsMu.Lock()
880+
if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists {
881+
select {
882+
case ch <- struct{}{}:
883+
default:
884+
// Channel already signaled.
885+
}
886+
}
887+
cs.pendingElicitationsMu.Unlock()
888+
}
889+
890+
// Call the user's handler if provided.
726891
if h := c.opts.ElicitationCompleteHandler; h != nil {
727892
h(ctx, req)
728893
}

mcp/client_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,61 @@ func TestClientCapabilities(t *testing.T) {
222222
Sampling: &SamplingCapabilities{},
223223
},
224224
},
225+
{
226+
name: "With form elicitation",
227+
configureClient: func(s *Client) {},
228+
clientOpts: ClientOptions{
229+
ElicitationModes: []string{"form"},
230+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
231+
return nil, nil
232+
},
233+
},
234+
wantCapabilities: &ClientCapabilities{
235+
Roots: struct {
236+
ListChanged bool "json:\"listChanged,omitempty\""
237+
}{ListChanged: true},
238+
Elicitation: &ElicitationCapabilities{
239+
Form: &FormElicitationCapabilities{},
240+
},
241+
},
242+
},
243+
{
244+
name: "With URL elicitation",
245+
configureClient: func(s *Client) {},
246+
clientOpts: ClientOptions{
247+
ElicitationModes: []string{"url"},
248+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
249+
return nil, nil
250+
},
251+
},
252+
wantCapabilities: &ClientCapabilities{
253+
Roots: struct {
254+
ListChanged bool "json:\"listChanged,omitempty\""
255+
}{ListChanged: true},
256+
Elicitation: &ElicitationCapabilities{
257+
URL: &URLElicitationCapabilities{},
258+
},
259+
},
260+
},
261+
{
262+
name: "With both form and URL elicitation",
263+
configureClient: func(s *Client) {},
264+
clientOpts: ClientOptions{
265+
ElicitationModes: []string{"form", "url"},
266+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
267+
return nil, nil
268+
},
269+
},
270+
wantCapabilities: &ClientCapabilities{
271+
Roots: struct {
272+
ListChanged bool "json:\"listChanged,omitempty\""
273+
}{ListChanged: true},
274+
Elicitation: &ElicitationCapabilities{
275+
Form: &FormElicitationCapabilities{},
276+
URL: &URLElicitationCapabilities{},
277+
},
278+
},
279+
},
225280
}
226281

227282
for _, tc := range testCases {

0 commit comments

Comments
 (0)