1- //! reference: https://html.spec.whatwg.org/multipage/server-sent-events.html
2- use std:: { pin:: Pin , sync:: Arc } ;
1+ //! Reference: <https://html.spec.whatwg.org/multipage/server-sent-events.html>
2+ use std:: {
3+ pin:: Pin ,
4+ sync:: { Arc , RwLock } ,
5+ } ;
36
47use futures:: { StreamExt , future:: BoxFuture } ;
58use http:: Uri ;
6- use sse_stream:: Error as SseError ;
9+ use sse_stream:: { Error as SseError , Sse } ;
710use thiserror:: Error ;
811
912use super :: {
@@ -54,9 +57,13 @@ pub trait SseClient: Clone + Send + Sync + 'static {
5457 ) -> impl Future < Output = Result < BoxedSseResponse , SseTransportError < Self :: Error > > > + Send + ' _ ;
5558}
5659
60+ /// Helper that refreshes the POST endpoint whenever the server emits
61+ /// control frames during SSE reconnect; used together with
62+ /// [`SseAutoReconnectStream`].
5763struct SseClientReconnect < C > {
5864 pub client : C ,
5965 pub uri : Uri ,
66+ pub message_endpoint : Arc < RwLock < Uri > > ,
6067}
6168
6269impl < C : SseClient > SseStreamReconnect for SseClientReconnect < C > {
@@ -68,6 +75,37 @@ impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
6875 let last_event_id = last_event_id. map ( |s| s. to_owned ( ) ) ;
6976 Box :: pin ( async move { client. get_stream ( uri, last_event_id, None ) . await } )
7077 }
78+
79+ fn handle_control_event ( & mut self , event : & Sse ) -> Result < ( ) , Self :: Error > {
80+ if event. event . as_deref ( ) != Some ( "endpoint" ) {
81+ return Ok ( ( ) ) ;
82+ }
83+ let Some ( data) = event. data . as_ref ( ) else {
84+ return Ok ( ( ) ) ;
85+ } ;
86+ // Servers typically resend the message POST endpoint (often with a new
87+ // sessionId) when a stream reconnects. Reuse `message_endpoint` helper
88+ // to resolve it and update the shared URI.
89+ let new_endpoint = message_endpoint ( self . uri . clone ( ) , data. clone ( ) )
90+ . map_err ( SseTransportError :: InvalidUri ) ?;
91+ * self
92+ . message_endpoint
93+ . write ( )
94+ . expect ( "message endpoint lock poisoned" ) = new_endpoint;
95+ Ok ( ( ) )
96+ }
97+
98+ fn handle_stream_error (
99+ & mut self ,
100+ error : & ( dyn std:: error:: Error + ' static ) ,
101+ last_event_id : Option < & str > ,
102+ ) {
103+ tracing:: warn!(
104+ uri = %self . uri,
105+ last_event_id = last_event_id. unwrap_or( "" ) ,
106+ "sse stream error: {error}"
107+ ) ;
108+ }
71109}
72110type ServerMessageStream < C > = Pin < Box < SseAutoReconnectStream < SseClientReconnect < C > > > > ;
73111
@@ -81,7 +119,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
81119///
82120/// ## Using reqwest
83121///
84- /// ```rust
122+ /// ```rust,ignore
85123/// use rmcp::transport::SseClientTransport;
86124///
87125/// // Enable the reqwest feature in Cargo.toml:
@@ -95,7 +133,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
95133///
96134/// ## Using a custom HTTP client
97135///
98- /// ```rust
136+ /// ```rust,ignore
99137/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig};
100138/// use std::sync::Arc;
101139/// use futures::stream::BoxStream;
@@ -154,7 +192,9 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
154192pub struct SseClientTransport < C : SseClient > {
155193 client : C ,
156194 config : SseClientConfig ,
157- message_endpoint : Uri ,
195+ /// Current POST endpoint; refreshed when the server sends new endpoint
196+ /// control frames.
197+ message_endpoint : Arc < RwLock < Uri > > ,
158198 stream : Option < ServerMessageStream < C > > ,
159199}
160200
@@ -168,8 +208,16 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
168208 item : crate :: service:: TxJsonRpcMessage < RoleClient > ,
169209 ) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send + ' static {
170210 let client = self . client . clone ( ) ;
171- let uri = self . message_endpoint . clone ( ) ;
172- async move { client. post_message ( uri, item, None ) . await }
211+ let message_endpoint = self . message_endpoint . clone ( ) ;
212+ async move {
213+ let uri = {
214+ let guard = message_endpoint
215+ . read ( )
216+ . expect ( "message endpoint lock poisoned" ) ;
217+ guard. clone ( )
218+ } ;
219+ client. post_message ( uri, item, None ) . await
220+ }
173221 }
174222 async fn close ( & mut self ) -> Result < ( ) , Self :: Error > {
175223 self . stream . take ( ) ;
@@ -194,7 +242,7 @@ impl<C: SseClient> SseClientTransport<C> {
194242 let sse_endpoint = config. sse_endpoint . as_ref ( ) . parse :: < http:: Uri > ( ) ?;
195243
196244 let mut sse_stream = client. get_stream ( sse_endpoint. clone ( ) , None , None ) . await ?;
197- let message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
245+ let initial_message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
198246 let ep = endpoint. parse :: < http:: Uri > ( ) ?;
199247 let mut sse_endpoint_parts = sse_endpoint. clone ( ) . into_parts ( ) ;
200248 sse_endpoint_parts. path_and_query = ep. into_parts ( ) . path_and_query ;
@@ -214,12 +262,14 @@ impl<C: SseClient> SseClientTransport<C> {
214262 break message_endpoint ( sse_endpoint. clone ( ) , ep) ?;
215263 }
216264 } ;
265+ let message_endpoint = Arc :: new ( RwLock :: new ( initial_message_endpoint) ) ;
217266
218267 let stream = Box :: pin ( SseAutoReconnectStream :: new (
219268 sse_stream,
220269 SseClientReconnect {
221270 client : client. clone ( ) ,
222271 uri : sse_endpoint. clone ( ) ,
272+ message_endpoint : message_endpoint. clone ( ) ,
223273 } ,
224274 config. retry_policy . clone ( ) ,
225275 ) ) ;
@@ -274,7 +324,7 @@ pub struct SseClientConfig {
274324 /// and the server send the message endpoint event as `message?session_id=123`,
275325 /// then the message endpoint will be `http://example.com/message`.
276326 ///
277- /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN /docs/Web/API/URL/URL)
327+ /// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US /docs/Web/API/URL/URL)
278328 pub sse_endpoint : Arc < str > ,
279329 pub retry_policy : Arc < dyn SseRetryPolicy > ,
280330 /// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
@@ -293,8 +343,40 @@ impl Default for SseClientConfig {
293343
294344#[ cfg( test) ]
295345mod tests {
346+ use futures:: StreamExt ;
347+ use serde_json:: { Value , json} ;
348+
296349 use super :: * ;
297350
351+ #[ derive( Clone ) ]
352+ struct DummyClient ;
353+
354+ #[ derive( Debug , thiserror:: Error ) ]
355+ #[ error( "dummy error" ) ]
356+ struct DummyError ;
357+
358+ impl SseClient for DummyClient {
359+ type Error = DummyError ;
360+
361+ async fn post_message (
362+ & self ,
363+ _uri : Uri ,
364+ _message : ClientJsonRpcMessage ,
365+ _auth_token : Option < String > ,
366+ ) -> Result < ( ) , SseTransportError < Self :: Error > > {
367+ Ok ( ( ) )
368+ }
369+
370+ async fn get_stream (
371+ & self ,
372+ _uri : Uri ,
373+ _last_event_id : Option < String > ,
374+ _auth_token : Option < String > ,
375+ ) -> Result < BoxedSseResponse , SseTransportError < Self :: Error > > {
376+ unreachable ! ( "get_stream should not be called in this test" )
377+ }
378+ }
379+
298380 #[ test]
299381 fn test_message_endpoint ( ) {
300382 let base_url = "https://localhost/sse" . parse :: < http:: Uri > ( ) . unwrap ( ) ;
@@ -319,4 +401,58 @@ mod tests {
319401 . unwrap ( ) ;
320402 assert_eq ! ( result. to_string( ) , "http://example.com/xxx?sessionId=x" ) ;
321403 }
404+
405+ #[ test]
406+ fn handle_endpoint_control_event_updates_uri ( ) {
407+ let initial_endpoint = "https://example.com/message?sessionId=old"
408+ . parse :: < Uri > ( )
409+ . unwrap ( ) ;
410+ let shared_endpoint = Arc :: new ( RwLock :: new ( initial_endpoint) ) ;
411+ let mut reconnect = SseClientReconnect {
412+ client : DummyClient ,
413+ uri : "https://example.com/sse" . parse :: < Uri > ( ) . unwrap ( ) ,
414+ message_endpoint : shared_endpoint. clone ( ) ,
415+ } ;
416+
417+ let control_event = Sse :: default ( )
418+ . event ( "endpoint" )
419+ . data ( "/message?sessionId=new" ) ;
420+
421+ reconnect. handle_control_event ( & control_event) . unwrap ( ) ;
422+
423+ let guard = shared_endpoint. read ( ) . expect ( "lock poisoned" ) ;
424+ assert_eq ! (
425+ guard. to_string( ) ,
426+ "https://example.com/message?sessionId=new"
427+ ) ;
428+ }
429+
430+ #[ tokio:: test]
431+ async fn control_event_frames_are_skipped ( ) {
432+ let payload = json ! ( {
433+ "jsonrpc" : "2.0" ,
434+ "id" : 1 ,
435+ "result" : { "ok" : true }
436+ } )
437+ . to_string ( ) ;
438+
439+ let events = vec ! [
440+ Ok ( Sse :: default ( )
441+ . event( "endpoint" )
442+ . data( "/message?sessionId=reconnect" ) ) ,
443+ Ok ( Sse :: default ( ) . event( "message" ) . data( payload. clone( ) ) ) ,
444+ ] ;
445+
446+ let sse_src: BoxedSseResponse = futures:: stream:: iter ( events) . boxed ( ) ;
447+ let reconn_stream = SseAutoReconnectStream :: never_reconnect ( sse_src, DummyError ) ;
448+ futures:: pin_mut!( reconn_stream) ;
449+
450+ let message = reconn_stream. next ( ) . await . expect ( "stream item" ) . unwrap ( ) ;
451+ let actual: Value = serde_json:: to_value ( message) . expect ( "serialize actual message" ) ;
452+ // We only need to assert that a valid JSON-RPC response came through after
453+ // skipping control frames. The exact `result` shape depends on the SDK's
454+ // typed result enums and is not asserted here.
455+ assert_eq ! ( actual. get( "jsonrpc" ) , Some ( & Value :: String ( "2.0" . into( ) ) ) ) ;
456+ assert_eq ! ( actual. get( "id" ) , Some ( & Value :: Number ( 1u64 . into( ) ) ) ) ;
457+ }
322458}
0 commit comments