Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ uuid = { version = "1", features = ["v4"], optional = true }
http-body = { version = "1", optional = true }
http-body-util = { version = "0.1", optional = true }
bytes = { version = "1", optional = true }

# for unix socket transport
hyper = { version = "1", features = ["client", "http1"], optional = true }
hyper-util = { version = "0.1", features = ["tokio"], optional = true }

# macro
rmcp-macros = { workspace = true, optional = true }
[target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies]
Expand Down Expand Up @@ -111,6 +116,15 @@ client-side-sse = ["dep:sse-stream", "dep:http"]
# Streamable HTTP client
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"]
transport-streamable-http-client-unix-socket = [
"transport-streamable-http-client",
"dep:hyper",
"dep:hyper-util",
"dep:http-body-util",
"dep:http",
"dep:bytes",
"tokio/net",
]

transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
transport-io = ["transport-async-rw", "tokio/io-std"]
Expand Down Expand Up @@ -259,3 +273,12 @@ path = "tests/test_sse_concurrent_streams.rs"
name = "test_client_credentials"
required-features = ["auth"]
path = "tests/test_client_credentials.rs"

[[test]]
name = "test_unix_socket_transport"
required-features = [
"client",
"server",
"transport-streamable-http-client-unix-socket",
]
path = "tests/test_unix_socket_transport.rs"
2 changes: 2 additions & 0 deletions crates/rmcp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHt

#[cfg(feature = "transport-streamable-http-client")]
pub mod streamable_http_client;
#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))]
pub use common::unix_socket::UnixSocketHttpClient;
#[cfg(feature = "transport-streamable-http-client")]
pub use streamable_http_client::StreamableHttpClientTransport;

Expand Down
3 changes: 3 additions & 0 deletions crates/rmcp/src/transport/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ pub mod client_side_sse;

#[cfg(feature = "auth")]
pub mod auth;

#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))]
pub mod unix_socket;
119 changes: 119 additions & 0 deletions crates/rmcp/src/transport/common/http_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,122 @@ pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version";
pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
pub const JSON_MIME_TYPE: &str = "application/json";

/// Reserved headers that must not be overridden by user-supplied custom headers.
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
/// injects it after initialization.
pub(crate) const RESERVED_HEADERS: &[&str] = &[
"accept",
HEADER_SESSION_ID,
HEADER_MCP_PROTOCOL_VERSION, // allowed through by validate_custom_header; worker injects it post-init
HEADER_LAST_EVENT_ID,
];

/// Checks whether a custom header name is allowed.
/// Returns `Ok(())` if allowed, `Err(name)` if rejected as reserved.
/// `MCP-Protocol-Version` is reserved but allowed through (the worker injects it post-init).
#[cfg(feature = "client-side-sse")]
pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), String> {
if RESERVED_HEADERS
.iter()
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
{
if name
.as_str()
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
{
return Ok(());
}
return Err(name.to_string());
}
Ok(())
}

/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value.
/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {
let header_lowercase = header.to_ascii_lowercase();
let scope_key = "scope=";

if let Some(pos) = header_lowercase.find(scope_key) {
let start = pos + scope_key.len();
let value_slice = &header[start..];

if let Some(stripped) = value_slice.strip_prefix('"') {
if let Some(end_quote) = stripped.find('"') {
return Some(stripped[..end_quote].to_string());
}
} else {
let end = value_slice
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
.unwrap_or(value_slice.len());
if end > 0 {
return Some(value_slice[..end].to_string());
}
}
}

None
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn extract_scope_quoted() {
let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#;
assert_eq!(
extract_scope_from_header(header),
Some("files:read files:write".to_string())
);
}

#[test]
fn extract_scope_unquoted() {
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
assert_eq!(
extract_scope_from_header(header),
Some("read:data".to_string())
);
}

#[test]
fn extract_scope_missing() {
let header = r#"Bearer error="invalid_token""#;
assert_eq!(extract_scope_from_header(header), None);
}

#[test]
fn extract_scope_empty_header() {
assert_eq!(extract_scope_from_header("Bearer"), None);
}

#[cfg(feature = "client-side-sse")]
#[test]
fn validate_rejects_reserved_accept() {
let name = http::HeaderName::from_static("accept");
assert!(validate_custom_header(&name).is_err());
}

#[cfg(feature = "client-side-sse")]
#[test]
fn validate_rejects_reserved_session_id() {
let name = http::HeaderName::from_static("mcp-session-id");
assert!(validate_custom_header(&name).is_err());
}

#[cfg(feature = "client-side-sse")]
#[test]
fn validate_allows_mcp_protocol_version() {
let name = http::HeaderName::from_static("mcp-protocol-version");
assert!(validate_custom_header(&name).is_ok());
}

#[cfg(feature = "client-side-sse")]
#[test]
fn validate_allows_custom_header() {
let name = http::HeaderName::from_static("x-custom");
assert!(validate_custom_header(&name).is_ok());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::{
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
transport::{
common::http_header::{
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
HEADER_SESSION_ID, JSON_MIME_TYPE,
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
extract_scope_from_header, validate_custom_header,
},
streamable_http_client::*,
},
Expand All @@ -22,38 +22,13 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
}
}

/// Reserved headers that must not be overridden by user-supplied custom headers.
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
/// injects it after initialization.
const RESERVED_HEADERS: &[&str] = &[
"accept",
HEADER_SESSION_ID,
HEADER_MCP_PROTOCOL_VERSION,
HEADER_LAST_EVENT_ID,
];

/// Applies custom headers to a request builder, rejecting reserved headers
/// except `MCP-Protocol-Version` (which the worker injects after init).
/// Applies custom headers to a request builder, rejecting reserved headers.
fn apply_custom_headers(
mut builder: reqwest::RequestBuilder,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<reqwest::RequestBuilder, StreamableHttpError<reqwest::Error>> {
for (name, value) in custom_headers {
if RESERVED_HEADERS
.iter()
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
{
if name
.as_str()
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
{
builder = builder.header(name, value);
continue;
}
return Err(StreamableHttpError::ReservedHeaderConflict(
name.to_string(),
));
}
validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?;
builder = builder.header(name, value);
}
Ok(builder)
Expand Down Expand Up @@ -280,66 +255,10 @@ impl StreamableHttpClientTransport<reqwest::Client> {
}
}

/// extract scope parameter from WWW-Authenticate header
fn extract_scope_from_header(header: &str) -> Option<String> {
let header_lowercase = header.to_ascii_lowercase();
let scope_key = "scope=";

if let Some(pos) = header_lowercase.find(scope_key) {
let start = pos + scope_key.len();
let value_slice = &header[start..];

if let Some(stripped) = value_slice.strip_prefix('"') {
if let Some(end_quote) = stripped.find('"') {
return Some(stripped[..end_quote].to_string());
}
} else {
let end = value_slice
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
.unwrap_or(value_slice.len());
if end > 0 {
return Some(value_slice[..end].to_string());
}
}
}

None
}

#[cfg(test)]
mod tests {
use super::extract_scope_from_header;
use crate::transport::streamable_http_client::InsufficientScopeError;

#[test]
fn extract_scope_quoted() {
let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#;
assert_eq!(
extract_scope_from_header(header),
Some("files:read files:write".to_string())
);
}

#[test]
fn extract_scope_unquoted() {
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
assert_eq!(
extract_scope_from_header(header),
Some("read:data".to_string())
);
}

#[test]
fn extract_scope_missing() {
let header = r#"Bearer error="invalid_token""#;
assert_eq!(extract_scope_from_header(header), None);
}

#[test]
fn extract_scope_empty_header() {
assert_eq!(extract_scope_from_header("Bearer"), None);
}

#[test]
fn insufficient_scope_error_can_upgrade() {
let with_scope = InsufficientScopeError {
Expand Down
Loading