diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index d011c225a..85f3f69a6 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -131,22 +131,6 @@ where .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string())) } -/// Helper function to expect a request from the stream -async fn expect_request( - transport: &mut T, - context: &str, -) -> Result<(ClientRequest, RequestId), ServerInitializeError> -where - T: Transport, -{ - let msg = expect_next_message(transport, context).await?; - let msg_clone = msg.clone(); - msg.into_request() - .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some( - msg_clone, - ))) -} - pub async fn serve_server_with_ct( service: S, transport: T, @@ -177,8 +161,35 @@ where let mut transport = transport.into_transport(); let id_provider = >::default(); - // Get initialize request - let (request, id) = expect_request(&mut transport, "initialized request").await?; + // Get initialize request; the MCP spec permits ping before initialize. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle#initialization + let (request, id) = loop { + let msg = expect_next_message(&mut transport, "initialize request").await?; + match msg { + ClientJsonRpcMessage::Request(req) + if matches!(req.request, ClientRequest::PingRequest(_)) => + { + transport + .send(ServerJsonRpcMessage::response( + ServerResult::EmptyResult(EmptyResult {}), + req.id, + )) + .await + .map_err(|error| { + ServerInitializeError::transport::( + error, + "sending pre-init ping response", + ) + })?; + } + ClientJsonRpcMessage::Request(req) => break (req.request, req.id), + other => { + return Err(ServerInitializeError::ExpectedInitializeRequest(Some( + other, + ))); + } + } + }; let ClientRequest::InitializeRequest(peer_info) = &request else { return Err(ServerInitializeError::ExpectedInitializeRequest(Some( diff --git a/crates/rmcp/tests/test_server_initialization.rs b/crates/rmcp/tests/test_server_initialization.rs index c07501f0b..88a8e45b2 100644 --- a/crates/rmcp/tests/test_server_initialization.rs +++ b/crates/rmcp/tests/test_server_initialization.rs @@ -96,6 +96,47 @@ async fn server_init_succeeds_after_set_level_before_initialized() { result.unwrap().cancel().await.unwrap(); } +// Server responds with EmptyResult to ping received before initialize request. +#[tokio::test] +async fn server_init_ping_response_is_empty_result_before_initialize() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + client.send(ping_request(1)).await.unwrap(); + + let response = client.receive().await.unwrap(); + assert!( + matches!( + response, + ServerJsonRpcMessage::Response(ref r) + if matches!(r.result, ServerResult::EmptyResult(_)) + ), + "expected EmptyResult for pre-initialize ping, got: {response:?}" + ); +} + +// Server initializes successfully when ping is sent before the initialize request. +#[tokio::test] +async fn server_init_succeeds_after_ping_before_initialize() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = + tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + client.send(ping_request(1)).await.unwrap(); + let _pong = client.receive().await.unwrap(); + do_initialize(&mut client).await; + client.send(initialized_notification()).await.unwrap(); + + let result = server_handle.await.unwrap(); + assert!( + result.is_ok(), + "server should initialize successfully after pre-initialize ping" + ); + result.unwrap().cancel().await.unwrap(); +} + // Server responds with EmptyResult to ping received before initialized. #[tokio::test] async fn server_init_ping_response_is_empty_result() {