@@ -7,9 +7,11 @@ package mcp
77import (
88 "context"
99 "encoding/json"
10+ "errors"
1011 "fmt"
1112 "iter"
1213 "slices"
14+ "strings"
1315 "sync"
1416 "sync/atomic"
1517 "time"
@@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
4547 c := & Client {
4648 impl : impl ,
4749 roots : newFeatureSet (func (r * Root ) string { return r .URI }),
48- sendingMethodHandler_ : defaultSendingMethodHandler [ * ClientSession ] ,
50+ sendingMethodHandler_ : defaultSendingMethodHandler ,
4951 receivingMethodHandler_ : defaultReceivingMethodHandler [* ClientSession ],
5052 }
5153 if opts != nil {
@@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
134136 // {"form":{}} for backward compatibility, but we explicitly set the form
135137 // capability.
136138 caps .Elicitation .Form = & FormElicitationCapabilities {}
137- } else if slices .Contains (modes , "url" ) {
139+ }
140+ if slices .Contains (modes , "url" ) {
138141 caps .Elicitation .URL = & URLElicitationCapabilities {}
139142 }
140143 }
@@ -206,6 +209,10 @@ type ClientSession struct {
206209 // No mutex is (currently) required to guard the session state, because it is
207210 // only set synchronously during Client.Connect.
208211 state clientSessionState
212+
213+ // Pending URL elicitations waiting for completion notifications.
214+ pendingElicitationsMu sync.Mutex
215+ pendingElicitations map [string ]chan struct {}
209216}
210217
211218type clientSessionState struct {
@@ -250,6 +257,46 @@ func (cs *ClientSession) Wait() error {
250257 return cs .conn .Wait ()
251258}
252259
260+ // registerElicitationWaiter registers a waiter for an elicitation complete
261+ // notification with the given elicitation ID. It returns two functions: an await
262+ // function that waits for the notification or context cancellation, and a cleanup
263+ // function that must be called to unregister the waiter. This must be called before
264+ // triggering the elicitation to avoid a race condition where the notification
265+ // arrives before the waiter is registered.
266+ //
267+ // The cleanup function must be called even if the await function is never called,
268+ // to prevent leaking the registration.
269+ func (cs * ClientSession ) registerElicitationWaiter (elicitationID string ) (await func (context.Context ) error , cleanup func ()) {
270+ // Create a channel for this elicitation.
271+ ch := make (chan struct {}, 1 )
272+
273+ // Register the channel.
274+ cs .pendingElicitationsMu .Lock ()
275+ if cs .pendingElicitations == nil {
276+ cs .pendingElicitations = make (map [string ]chan struct {})
277+ }
278+ cs .pendingElicitations [elicitationID ] = ch
279+ cs .pendingElicitationsMu .Unlock ()
280+
281+ // Return await and cleanup functions.
282+ await = func (ctx context.Context ) error {
283+ select {
284+ case <- ctx .Done ():
285+ return fmt .Errorf ("context cancelled while waiting for elicitation completion: %w" , ctx .Err ())
286+ case <- ch :
287+ return nil
288+ }
289+ }
290+
291+ cleanup = func () {
292+ cs .pendingElicitationsMu .Lock ()
293+ delete (cs .pendingElicitations , elicitationID )
294+ cs .pendingElicitationsMu .Unlock ()
295+ }
296+
297+ return await , cleanup
298+ }
299+
253300// startKeepalive starts the keepalive mechanism for this client session.
254301func (cs * ClientSession ) startKeepalive (interval time.Duration ) {
255302 startKeepalive (cs , interval , & cs .keepaliveCancel )
@@ -309,6 +356,110 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (
309356 return c .opts .CreateMessageHandler (ctx , req )
310357}
311358
359+ // urlElicitationMiddleware returns middleware that automatically handles URL elicitation
360+ // required errors by executing the elicitation handler, waiting for completion notifications,
361+ // and retrying the operation.
362+ //
363+ // This middleware should be added to clients that want automatic URL elicitation handling:
364+ //
365+ // client := mcp.NewClient(impl, opts)
366+ // client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
367+ //
368+ // TODO(rfindley): this isn't strictly necessary for the SEP, but may be
369+ // useful. Propose exporting it it.
370+ func urlElicitationMiddleware () Middleware {
371+ return func (next MethodHandler ) MethodHandler {
372+ return func (ctx context.Context , method string , req Request ) (Result , error ) {
373+ // Call the underlying handler.
374+ res , err := next (ctx , method , req )
375+ if err == nil {
376+ return res , nil
377+ }
378+
379+ // Check if this is a URL elicitation required error.
380+ var rpcErr * jsonrpc.Error
381+ if ! errors .As (err , & rpcErr ) || rpcErr .Code != CodeURLElicitationRequired {
382+ return res , err
383+ }
384+
385+ // Notifications don't support retries.
386+ if strings .HasPrefix (method , "notifications/" ) {
387+ return res , err
388+ }
389+
390+ // Extract the client session.
391+ cs , ok := req .GetSession ().(* ClientSession )
392+ if ! ok {
393+ return res , err
394+ }
395+
396+ // Check if the client has an elicitation handler.
397+ if cs .client .opts .ElicitationHandler == nil {
398+ return res , err
399+ }
400+
401+ // Parse the elicitations from the error data.
402+ var errorData struct {
403+ Elicitations []* ElicitParams `json:"elicitations"`
404+ }
405+ if rpcErr .Data != nil {
406+ if err := json .Unmarshal (rpcErr .Data , & errorData ); err != nil {
407+ return nil , fmt .Errorf ("failed to parse URL elicitation error data: %w" , err )
408+ }
409+ }
410+
411+ // Validate that all elicitations are URL mode.
412+ for _ , elicit := range errorData .Elicitations {
413+ mode := elicit .Mode
414+ if mode == "" {
415+ mode = "form" // Default mode.
416+ }
417+ if mode != "url" {
418+ return nil , fmt .Errorf ("URLElicitationRequired error must only contain URL mode elicitations, got %q" , mode )
419+ }
420+ }
421+
422+ // Register waiters for all elicitations before executing handlers
423+ // to avoid race condition where notification arrives before waiter is registered.
424+ type waiter struct {
425+ await func (context.Context ) error
426+ cleanup func ()
427+ }
428+ waiters := make ([]waiter , 0 , len (errorData .Elicitations ))
429+ for _ , elicitParams := range errorData .Elicitations {
430+ await , cleanup := cs .registerElicitationWaiter (elicitParams .ElicitationID )
431+ waiters = append (waiters , waiter {await : await , cleanup : cleanup })
432+ }
433+
434+ // Ensure cleanup happens even if we return early.
435+ defer func () {
436+ for _ , w := range waiters {
437+ w .cleanup ()
438+ }
439+ }()
440+
441+ // Execute the elicitation handler for each elicitation.
442+ for _ , elicitParams := range errorData .Elicitations {
443+ elicitReq := newClientRequest (cs , elicitParams )
444+ _ , elicitErr := cs .client .elicit (ctx , elicitReq )
445+ if elicitErr != nil {
446+ return nil , fmt .Errorf ("URL elicitation failed: %w" , elicitErr )
447+ }
448+ }
449+
450+ // Wait for all elicitations to complete.
451+ for _ , w := range waiters {
452+ if err := w .await (ctx ); err != nil {
453+ return nil , err
454+ }
455+ }
456+
457+ // All elicitations complete, retry the original operation.
458+ return next (ctx , method , req )
459+ }
460+ }
461+ }
462+
312463func (c * Client ) elicit (ctx context.Context , req * ElicitRequest ) (* ElicitResult , error ) {
313464 if c .opts .ElicitationHandler == nil {
314465 return nil , & jsonrpc.Error {Code : jsonrpc .CodeInvalidParams , Message : "client does not support elicitation" }
@@ -723,6 +874,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
723874}
724875
725876func (c * Client ) callElicitationCompleteHandler (ctx context.Context , req * ElicitationCompleteNotificationRequest ) (Result , error ) {
877+ // Check if there's a pending elicitation waiting for this notification.
878+ if cs , ok := req .GetSession ().(* ClientSession ); ok {
879+ cs .pendingElicitationsMu .Lock ()
880+ if ch , exists := cs .pendingElicitations [req .Params .ElicitationID ]; exists {
881+ select {
882+ case ch <- struct {}{}:
883+ default :
884+ // Channel already signaled.
885+ }
886+ }
887+ cs .pendingElicitationsMu .Unlock ()
888+ }
889+
890+ // Call the user's handler if provided.
726891 if h := c .opts .ElicitationCompleteHandler ; h != nil {
727892 h (ctx , req )
728893 }
0 commit comments