@@ -7,9 +7,11 @@ package mcp
77import (
88 "context"
99 "encoding/json"
10+ "errors"
1011 "fmt"
1112 "iter"
1213 "slices"
14+ "strings"
1315 "sync"
1416 "sync/atomic"
1517 "time"
@@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
4547 c := & Client {
4648 impl : impl ,
4749 roots : newFeatureSet (func (r * Root ) string { return r .URI }),
48- sendingMethodHandler_ : defaultSendingMethodHandler [ * ClientSession ] ,
50+ sendingMethodHandler_ : defaultSendingMethodHandler ,
4951 receivingMethodHandler_ : defaultReceivingMethodHandler [* ClientSession ],
5052 }
5153 if opts != nil {
@@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
134136 // {"form":{}} for backward compatibility, but we explicitly set the form
135137 // capability.
136138 caps .Elicitation .Form = & FormElicitationCapabilities {}
137- } else if slices .Contains (modes , "url" ) {
139+ }
140+ if slices .Contains (modes , "url" ) {
138141 caps .Elicitation .URL = & URLElicitationCapabilities {}
139142 }
140143 }
@@ -206,6 +209,10 @@ type ClientSession struct {
206209 // No mutex is (currently) required to guard the session state, because it is
207210 // only set synchronously during Client.Connect.
208211 state clientSessionState
212+
213+ // Pending URL elicitations waiting for completion notifications.
214+ pendingElicitationsMu sync.Mutex
215+ pendingElicitations map [string ]chan struct {}
209216}
210217
211218type clientSessionState struct {
@@ -250,6 +257,39 @@ func (cs *ClientSession) Wait() error {
250257 return cs .conn .Wait ()
251258}
252259
260+ // registerElicitationWaiter registers a waiter for an elicitation complete notification
261+ // with the given elicitation ID. It returns a function that waits for the notification
262+ // or context cancellation. This must be called before triggering the elicitation to avoid
263+ // a race condition where the notification arrives before the waiter is registered.
264+ func (cs * ClientSession ) registerElicitationWaiter (elicitationID string ) func (context.Context ) error {
265+ // Create a channel for this elicitation.
266+ ch := make (chan struct {}, 1 )
267+
268+ // Register the channel.
269+ cs .pendingElicitationsMu .Lock ()
270+ if cs .pendingElicitations == nil {
271+ cs .pendingElicitations = make (map [string ]chan struct {})
272+ }
273+ cs .pendingElicitations [elicitationID ] = ch
274+ cs .pendingElicitationsMu .Unlock ()
275+
276+ // Return a function that waits for completion and cleans up.
277+ return func (ctx context.Context ) error {
278+ defer func () {
279+ cs .pendingElicitationsMu .Lock ()
280+ delete (cs .pendingElicitations , elicitationID )
281+ cs .pendingElicitationsMu .Unlock ()
282+ }()
283+
284+ select {
285+ case <- ctx .Done ():
286+ return fmt .Errorf ("context cancelled while waiting for elicitation completion: %w" , ctx .Err ())
287+ case <- ch :
288+ return nil
289+ }
290+ }
291+ }
292+
253293// startKeepalive starts the keepalive mechanism for this client session.
254294func (cs * ClientSession ) startKeepalive (interval time.Duration ) {
255295 startKeepalive (cs , interval , & cs .keepaliveCancel )
@@ -309,6 +349,98 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (
309349 return c .opts .CreateMessageHandler (ctx , req )
310350}
311351
352+ // urlElicitationMiddleware returns middleware that automatically handles URL elicitation
353+ // required errors by executing the elicitation handler, waiting for completion notifications,
354+ // and retrying the operation.
355+ //
356+ // This middleware should be added to clients that want automatic URL elicitation handling:
357+ //
358+ // client := mcp.NewClient(impl, opts)
359+ // client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
360+ //
361+ // TODO(rfindley): this isn't strictly necessary for the SEP, but may be
362+ // useful. Propose exporting it it.
363+ func urlElicitationMiddleware () Middleware {
364+ return func (next MethodHandler ) MethodHandler {
365+ return func (ctx context.Context , method string , req Request ) (Result , error ) {
366+ // Call the underlying handler.
367+ res , err := next (ctx , method , req )
368+ if err == nil {
369+ return res , nil
370+ }
371+
372+ // Check if this is a URL elicitation required error.
373+ var rpcErr * jsonrpc.Error
374+ if ! errors .As (err , & rpcErr ) || rpcErr .Code != CodeURLElicitationRequired {
375+ return res , err
376+ }
377+
378+ // Notifications don't support retries.
379+ if strings .HasPrefix (method , "notifications/" ) {
380+ return res , err
381+ }
382+
383+ // Extract the client session.
384+ cs , ok := req .GetSession ().(* ClientSession )
385+ if ! ok {
386+ return res , err
387+ }
388+
389+ // Check if the client has an elicitation handler.
390+ if cs .client .opts .ElicitationHandler == nil {
391+ return res , err
392+ }
393+
394+ // Parse the elicitations from the error data.
395+ var errorData struct {
396+ Elicitations []* ElicitParams `json:"elicitations"`
397+ }
398+ if rpcErr .Data != nil {
399+ if err := json .Unmarshal (rpcErr .Data , & errorData ); err != nil {
400+ return nil , fmt .Errorf ("failed to parse URL elicitation error data: %w" , err )
401+ }
402+ }
403+
404+ // Validate that all elicitations are URL mode.
405+ for _ , elicit := range errorData .Elicitations {
406+ mode := elicit .Mode
407+ if mode == "" {
408+ mode = "form" // Default mode.
409+ }
410+ if mode != "url" {
411+ return nil , fmt .Errorf ("URLElicitationRequired error must only contain URL mode elicitations, got %q" , mode )
412+ }
413+ }
414+
415+ // Register waiters for all elicitations before executing handlers
416+ // to avoid race condition where notification arrives before waiter is registered.
417+ awaitFuncs := make ([]func (context.Context ) error , 0 , len (errorData .Elicitations ))
418+ for _ , elicitParams := range errorData .Elicitations {
419+ awaitFuncs = append (awaitFuncs , cs .registerElicitationWaiter (elicitParams .ElicitationID ))
420+ }
421+
422+ // Execute the elicitation handler for each elicitation.
423+ for _ , elicitParams := range errorData .Elicitations {
424+ elicitReq := newClientRequest (cs , elicitParams )
425+ _ , elicitErr := cs .client .elicit (ctx , elicitReq )
426+ if elicitErr != nil {
427+ return nil , fmt .Errorf ("URL elicitation failed: %w" , elicitErr )
428+ }
429+ }
430+
431+ // Wait for all elicitations to complete.
432+ for _ , await := range awaitFuncs {
433+ if err := await (ctx ); err != nil {
434+ return nil , err
435+ }
436+ }
437+
438+ // All elicitations complete, retry the original operation.
439+ return next (ctx , method , req )
440+ }
441+ }
442+ }
443+
312444func (c * Client ) elicit (ctx context.Context , req * ElicitRequest ) (* ElicitResult , error ) {
313445 if c .opts .ElicitationHandler == nil {
314446 return nil , & jsonrpc.Error {Code : jsonrpc .CodeInvalidParams , Message : "client does not support elicitation" }
@@ -723,6 +855,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
723855}
724856
725857func (c * Client ) callElicitationCompleteHandler (ctx context.Context , req * ElicitationCompleteNotificationRequest ) (Result , error ) {
858+ // Check if there's a pending elicitation waiting for this notification.
859+ if cs , ok := req .GetSession ().(* ClientSession ); ok {
860+ cs .pendingElicitationsMu .Lock ()
861+ if ch , exists := cs .pendingElicitations [req .Params .ElicitationID ]; exists {
862+ select {
863+ case ch <- struct {}{}:
864+ default :
865+ // Channel already signaled.
866+ }
867+ }
868+ cs .pendingElicitationsMu .Unlock ()
869+ }
870+
871+ // Call the user's handler if provided.
726872 if h := c .opts .ElicitationCompleteHandler ; h != nil {
727873 h (ctx , req )
728874 }
0 commit comments