From a41429abce979aa4306dfa5bda336a1e9319beb1 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 6 Feb 2026 09:33:41 +0100 Subject: [PATCH 1/5] Add websocket proxying support --- Cargo.lock | 16 ++++ Cargo.toml | 1 + src/attested_get.rs | 6 +- src/file_server.rs | 3 +- src/http_version.rs | 47 +-------- src/lib.rs | 228 +++++++++++++++++++++++++++++++++++++++----- src/main.rs | 20 +++- 7 files changed, 248 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a8115de..9c61991 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -538,6 +538,7 @@ dependencies = [ "http", "http-body-util", "hyper", + "hyper-tungstenite", "hyper-util", "jsonrpsee", "num-bigint", @@ -1942,6 +1943,21 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc778da281a749ed28d2be73a9f2cd13030680a1574bc729debd1195e44f00e9" +dependencies = [ + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + [[package]] name = "hyper-util" version = "0.1.17" diff --git a/Cargo.toml b/Cargo.toml index 115fa8c..546468b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ alloy-transport-http = { version = "1.4.3", features = [ "hyper", ], optional = true } url = { version = "2.5.7", optional = true } +hyper-tungstenite = "0.19.0" [dev-dependencies] tempfile = "3.23.0" diff --git a/src/attested_get.rs b/src/attested_get.rs index 9dc22f9..948e730 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -1,5 +1,7 @@ //! A one-shot attested TLS proxy client which sends a single GET request and returns the response -use crate::{AttestationGenerator, AttestationVerifier, ProxyClient, ProxyError}; +use crate::{ + AttestationGenerator, AttestationVerifier, ProxyClient, ProxyClientHttpMode, ProxyError, +}; use tokio_rustls::rustls::pki_types::CertificateDer; /// Start a proxy-client, send a single HTTP GET request to the given path and return the @@ -17,6 +19,7 @@ pub async fn attested_get( AttestationGenerator::with_no_attestation(), attestation_verifier, remote_certificate, + ProxyClientHttpMode::Http2, ) .await?; @@ -103,6 +106,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, + ProxyClientHttpMode::Http2, ) .await .unwrap(); diff --git a/src/file_server.rs b/src/file_server.rs index 2170127..b8ff77f 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -52,7 +52,7 @@ pub(crate) async fn static_file_server(path: PathBuf) -> Result; type Http2Sender = hyper::client::conn::http2::SendRequest; -type Http1Connection = hyper::client::conn::http1::Connection< - TokioIo>, - hyper::body::Incoming, ->; - -type Http2Connection = hyper::client::conn::http2::Connection< - TokioIo>, - hyper::body::Incoming, - crate::TokioExecutor, ->; - /// A protocol version agnostic HTTP sender pub enum HttpSender { Http1(Http1Sender), @@ -88,34 +76,5 @@ impl HttpSender { } } -pin_project_lite::pin_project! { - /// A protocol version agnostic HTTP connection - #[project = HttpConnectionProj] - pub enum HttpConnection { - Http1 { #[pin] inner: Http1Connection }, - Http2 { #[pin] inner: Http2Connection }, - } -} - -impl From for HttpConnection { - fn from(inner: Http1Connection) -> Self { - Self::Http1 { inner } - } -} - -impl From for HttpConnection { - fn from(inner: Http2Connection) -> Self { - Self::Http2 { inner } - } -} - -impl Future for HttpConnection { - type Output = Result<(), hyper::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - HttpConnectionProj::Http1 { inner } => inner.poll(cx), - HttpConnectionProj::Http2 { inner } => inner.poll(cx), - } - } -} +/// A protocol version agnostic HTTP connection future +pub type HttpConnection = Pin> + Send>>; diff --git a/src/lib.rs b/src/lib.rs index f8f586d..5f82406 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,9 @@ use tokio::io; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::VerifierBuilderError; -use tokio_rustls::rustls::{pki_types::CertificateDer, ClientConfig, ServerConfig}; +use tokio_rustls::rustls::{ + self, pki_types::CertificateDer, ClientConfig, RootCertStore, ServerConfig, +}; use tracing::{debug, error, warn}; use crate::{ @@ -266,6 +268,7 @@ impl ProxyServer { .timer(hyper_util::rt::tokio::TokioTimer::new()) .keep_alive(true) .serve_connection(io, service) + .with_upgrades() .await?; } } @@ -273,11 +276,13 @@ impl ProxyServer { Ok(()) } - // Handle a request from the proxy client to the target server + /// Handle a request from the proxy client to the target server async fn handle_http_request( - req: hyper::Request, + mut req: hyper::Request, target: String, ) -> Result>, ProxyError> { + let inbound_upgrade = hyper::upgrade::on(&mut req); + // Connect to the target server let outbound = TcpStream::connect(target).await?; let outbound_io = TokioIo::new(outbound); @@ -286,6 +291,7 @@ impl ProxyServer { .await?; // Drive the connection + let conn = conn.with_upgrades(); tokio::spawn(async move { if let Err(e) = conn.await { warn!("Client connection error: {e}"); @@ -294,7 +300,32 @@ impl ProxyServer { // Forward the request from the proxy-client to the target server match sender.send_request(req).await { - Ok(resp) => Ok(resp.map(|b| b.boxed())), + Ok(mut resp) => { + if resp.status() == hyper::StatusCode::SWITCHING_PROTOCOLS { + let outbound_upgrade = hyper::upgrade::on(&mut resp); + tokio::spawn(async move { + let inbound = match inbound_upgrade.await { + Ok(io) => io, + Err(e) => { + warn!("Inbound upgrade failed: {e}"); + return; + } + }; + let outbound = match outbound_upgrade.await { + Ok(io) => io, + Err(e) => { + warn!("Outbound upgrade failed: {e}"); + return; + } + }; + + if let Err(e) = tunnel_upgraded_streams(inbound, outbound).await { + warn!("Upgrade tunnel failed: {e}"); + } + }); + } + Ok(resp.map(|b| b.boxed())) + } Err(e) => { warn!("send_request error: {e}"); let mut resp = Response::new(full(format!("Request failed: {e}"))); @@ -321,6 +352,15 @@ pub struct ProxyClient { requests_tx: mpsc::Sender, } +/// Controls which HTTP version the proxy client uses to talk to the proxy server. +#[derive(Clone, Copy, Debug)] +pub enum ProxyClientHttpMode { + /// Use HTTP/1.1 (supports WS upgrades). + Http1, + /// Use HTTP/2 (no WS upgrades). + Http2, +} + impl ProxyClient { /// Start with optional TLS client auth pub async fn new( @@ -330,16 +370,40 @@ impl ProxyClient { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, remote_certificate: Option>, + http_mode: ProxyClientHttpMode, ) -> Result { - let attested_tls_client = AttestedTlsClient::new( - cert_and_key, + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + let client_config = if let Some(ref cert_and_key) = cert_and_key { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_client_auth_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + Self::new_with_tls_config( + client_config, + address, + server_name, attestation_generator, attestation_verifier, - remote_certificate, + cert_and_key.map(|c| c.cert_chain), + http_mode, ) - .await?; - - Self::new_with_inner(address, attested_tls_client, &server_name).await + .await } /// Create a new proxy client with given TLS configuration @@ -350,17 +414,12 @@ impl ProxyClient { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, + http_mode: ProxyClientHttpMode, ) -> Result { - for protocol in [ALPN_H2, ALPN_HTTP11] { - let already_present = client_config - .alpn_protocols - .iter() - .any(|p| p.as_slice() == protocol); - - if !already_present { - client_config.alpn_protocols.push(protocol.to_vec()); - } - } + client_config.alpn_protocols = match http_mode { + ProxyClientHttpMode::Http1 => vec![ALPN_HTTP11.to_vec()], + ProxyClientHttpMode::Http2 => vec![ALPN_H2.to_vec()], + }; let attested_tls_client = AttestedTlsClient::new_with_tls_config( client_config, @@ -562,7 +621,7 @@ impl ProxyClient { }); let io = TokioIo::new(inbound); - http.serve_connection(io, service).await?; + http.serve_connection(io, service).with_upgrades().await?; Ok(()) } @@ -625,7 +684,7 @@ impl ProxyClient { let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); let outbound_io = TokioIo::new(tls_stream); - let (sender, conn) = match http_version { + let (sender, conn): (HttpSender, HttpConnection) = match http_version { HttpVersion::Http2 => { let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) .timer(hyper_util::rt::tokio::TokioTimer::new()) @@ -634,13 +693,16 @@ impl ProxyClient { .keep_alive_while_idle(true) .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; - (sender.into(), conn.into()) + (sender.into(), Box::pin(conn) as HttpConnection) } HttpVersion::Http1 => { let (sender, conn) = hyper::client::conn::http1::Builder::new() .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; - (sender.into(), conn.into()) + ( + sender.into(), + Box::pin(conn.with_upgrades()) as HttpConnection, + ) } }; @@ -650,15 +712,52 @@ impl ProxyClient { // Handle a request from the source client to the proxy server async fn handle_http_request( - req: hyper::Request, + mut req: hyper::Request, requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { + let inbound_upgrade = hyper::upgrade::on(&mut req); let (response_tx, response_rx) = oneshot::channel(); requests_tx.send((req, response_tx)).await?; - Ok(response_rx.await??) + let mut resp = response_rx.await??; + + if resp.status() == hyper::StatusCode::SWITCHING_PROTOCOLS { + let outbound_upgrade = hyper::upgrade::on(&mut resp); + tokio::spawn(async move { + let inbound = match inbound_upgrade.await { + Ok(io) => io, + Err(e) => { + warn!("Inbound upgrade failed: {e}"); + return; + } + }; + let outbound = match outbound_upgrade.await { + Ok(io) => io, + Err(e) => { + warn!("Outbound upgrade failed: {e}"); + return; + } + }; + + if let Err(e) = tunnel_upgraded_streams(inbound, outbound).await { + warn!("Upgrade tunnel failed: {e}"); + } + }); + } + + Ok(resp) } } +async fn tunnel_upgraded_streams( + inbound: hyper::upgrade::Upgraded, + outbound: hyper::upgrade::Upgraded, +) -> Result<(), std::io::Error> { + let mut inbound = TokioIo::new(inbound); + let mut outbound = TokioIo::new(outbound); + let _ = tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await?; + Ok(()) +} + /// Update a request/response header if we are able to encode the header value /// /// This avoids bailing on bad header values - the headers are simply not updated @@ -755,6 +854,11 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, }; + #[cfg(feature = "ws")] + use futures_util::{SinkExt, StreamExt}; + #[cfg(feature = "ws")] + use tokio_tungstenite::{accept_async, connect_async, tungstenite::Message}; + // Server has mock DCAP, client has no attestation and no client auth #[tokio::test] async fn http_proxy_with_server_attestation() { @@ -787,6 +891,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, + ProxyClientHttpMode::Http2, ) .await .unwrap(); @@ -865,6 +970,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), Some(client_cert_chain), + ProxyClientHttpMode::Http2, ) .await .unwrap(); @@ -936,6 +1042,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), None, + ProxyClientHttpMode::Http2, ) .await .unwrap(); @@ -1017,6 +1124,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), Some(client_cert_chain), + ProxyClientHttpMode::Http2, ) .await .unwrap(); @@ -1152,6 +1260,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, + ProxyClientHttpMode::Http2, ) .await; @@ -1213,6 +1322,7 @@ mod tests { AttestationGenerator::with_no_attestation(), attestation_verifier, None, + ProxyClientHttpMode::Http2, ) .await; @@ -1268,6 +1378,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, + ProxyClientHttpMode::Http2, ) .await .unwrap(); @@ -1351,6 +1462,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, + ProxyClientHttpMode::Http1, ) .await .unwrap(); @@ -1383,4 +1495,68 @@ mod tests { let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } + + #[cfg(feature = "ws")] + #[tokio::test] + async fn http_proxy_websocket_upgrade() { + init_tracing(); + + let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = target_listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _addr) = target_listener.accept().await.unwrap(); + let mut ws = accept_async(stream).await.unwrap(); + + if let Some(Ok(Message::Text(text))) = ws.next().await { + ws.send(Message::Text(text)).await.unwrap(); + } + }); + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + ProxyClientHttpMode::Http1, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let (mut ws, _response) = connect_async(format!("ws://{}", proxy_client_addr)) + .await + .unwrap(); + + ws.send(Message::Text("ping".into())).await.unwrap(); + let msg = ws.next().await.unwrap().unwrap(); + assert_eq!(msg.to_text().unwrap(), "ping"); + } } diff --git a/src/main.rs b/src/main.rs index 67160e9..1d81e6f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, ensure}; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use std::{ fs::File, net::{IpAddr, SocketAddr}, @@ -17,6 +17,7 @@ use attested_tls_proxy::{ health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, AttestationGenerator, ProxyClient, ProxyServer, + ProxyClientHttpMode, }; #[derive(Parser, Debug, Clone)] @@ -76,6 +77,9 @@ enum CliCommand { /// Enables verification of self-signed TLS certificates #[arg(long)] allow_self_signed: bool, + /// HTTP mode for the proxy client: http1 supports WS upgrades, http2 does not + #[arg(long, value_enum, default_value = "http2")] + http_mode: ClientHttpMode, }, /// Run a proxy server Server { @@ -150,6 +154,12 @@ enum CliCommand { }, } +#[derive(Debug, Clone, Copy, ValueEnum)] +enum ClientHttpMode { + Http1, + Http2, +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); @@ -211,6 +221,7 @@ async fn main() -> anyhow::Result<()> { dev_dummy_dcap, listen_addr_healthcheck, allow_self_signed, + http_mode, } => { let target_addr = target_addr .strip_prefix("https://") @@ -249,6 +260,11 @@ async fn main() -> anyhow::Result<()> { AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) .await?; + let http_mode = match http_mode { + ClientHttpMode::Http1 => ProxyClientHttpMode::Http1, + ClientHttpMode::Http2 => ProxyClientHttpMode::Http2, + }; + let client = if allow_self_signed { let client_tls_config = attested_tls_proxy::self_signed::client_tls_config_allow_self_signed()?; @@ -259,6 +275,7 @@ async fn main() -> anyhow::Result<()> { client_attestation_generator, attestation_verifier, None, + http_mode, ) .await? } else { @@ -269,6 +286,7 @@ async fn main() -> anyhow::Result<()> { client_attestation_generator, attestation_verifier, remote_tls_cert, + http_mode, ) .await? }; From 063e0fe2cf403ece0dc8bb5e8887d4622287b698 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 6 Feb 2026 10:01:18 +0100 Subject: [PATCH 2/5] Dont add an extra enum for http version - use existing one --- src/attested_get.rs | 6 +++--- src/file_server.rs | 4 ++-- src/lib.rs | 37 ++++++++++++++----------------------- src/main.rs | 6 +++--- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/attested_get.rs b/src/attested_get.rs index 948e730..8e3b605 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -1,6 +1,6 @@ //! A one-shot attested TLS proxy client which sends a single GET request and returns the response use crate::{ - AttestationGenerator, AttestationVerifier, ProxyClient, ProxyClientHttpMode, ProxyError, + http_version::HttpVersion, AttestationGenerator, AttestationVerifier, ProxyClient, ProxyError, }; use tokio_rustls::rustls::pki_types::CertificateDer; @@ -19,7 +19,7 @@ pub async fn attested_get( AttestationGenerator::with_no_attestation(), attestation_verifier, remote_certificate, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await?; @@ -106,7 +106,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); diff --git a/src/file_server.rs b/src/file_server.rs index b8ff77f..3efc6c2 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -52,7 +52,7 @@ pub(crate) async fn static_file_server(path: PathBuf) -> Result, } -/// Controls which HTTP version the proxy client uses to talk to the proxy server. -#[derive(Clone, Copy, Debug)] -pub enum ProxyClientHttpMode { - /// Use HTTP/1.1 (supports WS upgrades). - Http1, - /// Use HTTP/2 (no WS upgrades). - Http2, -} - impl ProxyClient { /// Start with optional TLS client auth pub async fn new( @@ -370,7 +361,7 @@ impl ProxyClient { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, remote_certificate: Option>, - http_mode: ProxyClientHttpMode, + http_mode: HttpVersion, ) -> Result { let root_store = match remote_certificate { Some(remote_certificate) => { @@ -414,11 +405,11 @@ impl ProxyClient { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, - http_mode: ProxyClientHttpMode, + http_mode: HttpVersion, ) -> Result { client_config.alpn_protocols = match http_mode { - ProxyClientHttpMode::Http1 => vec![ALPN_HTTP11.to_vec()], - ProxyClientHttpMode::Http2 => vec![ALPN_H2.to_vec()], + HttpVersion::Http1 => vec![ALPN_HTTP11.to_vec()], + HttpVersion::Http2 => vec![ALPN_H2.to_vec()], }; let attested_tls_client = AttestedTlsClient::new_with_tls_config( @@ -891,7 +882,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); @@ -970,7 +961,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), Some(client_cert_chain), - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); @@ -1042,7 +1033,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); @@ -1124,7 +1115,7 @@ mod tests { AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), Some(client_cert_chain), - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); @@ -1260,7 +1251,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await; @@ -1322,7 +1313,7 @@ mod tests { AttestationGenerator::with_no_attestation(), attestation_verifier, None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await; @@ -1378,7 +1369,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, - ProxyClientHttpMode::Http2, + HttpVersion::Http2, ) .await .unwrap(); @@ -1462,7 +1453,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), None, - ProxyClientHttpMode::Http1, + HttpVersion::Http1, ) .await .unwrap(); @@ -1540,7 +1531,7 @@ mod tests { AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), None, - ProxyClientHttpMode::Http1, + HttpVersion::Http1, ) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index 1d81e6f..c03d8ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,9 +15,9 @@ use attested_tls_proxy::{ attested_tls::{get_tls_cert, TlsCertAndKey}, file_server::attested_file_server, health_check, + http_version::HttpVersion, normalize_pem::normalize_private_key_pem_to_pkcs8, AttestationGenerator, ProxyClient, ProxyServer, - ProxyClientHttpMode, }; #[derive(Parser, Debug, Clone)] @@ -261,8 +261,8 @@ async fn main() -> anyhow::Result<()> { .await?; let http_mode = match http_mode { - ClientHttpMode::Http1 => ProxyClientHttpMode::Http1, - ClientHttpMode::Http2 => ProxyClientHttpMode::Http2, + ClientHttpMode::Http1 => HttpVersion::Http1, + ClientHttpMode::Http2 => HttpVersion::Http2, }; let client = if allow_self_signed { From ae95b94e976f2a45c3ba4aa6c1fb6893b22e8035 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 16 Feb 2026 12:34:49 +0100 Subject: [PATCH 3/5] Rm unneeded deps --- Cargo.lock | 17 ----------------- Cargo.toml | 2 -- 2 files changed, 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9c61991..297de43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -538,7 +538,6 @@ dependencies = [ "http", "http-body-util", "hyper", - "hyper-tungstenite", "hyper-util", "jsonrpsee", "num-bigint", @@ -547,7 +546,6 @@ dependencies = [ "p256", "parity-scale-codec", "pem-rfc7468", - "pin-project-lite", "pkcs1", "pkcs8", "rand_core 0.6.4", @@ -1943,21 +1941,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-tungstenite" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc778da281a749ed28d2be73a9f2cd13030680a1574bc729debd1195e44f00e9" -dependencies = [ - "http-body-util", - "hyper", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - [[package]] name = "hyper-util" version = "0.1.17" diff --git a/Cargo.toml b/Cargo.toml index 12ce062..477f8df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,6 @@ p256 = { version = "0.13.2", features = ["pkcs8"] } pkcs1 = "0.7.5" pkcs8 = "0.10.2" rcgen = "0.14.5" -pin-project-lite = "0.2.16" # For Azure vTPM attestation az-tdx-vtpm = { version = "0.7.4", optional = true } @@ -68,7 +67,6 @@ alloy-transport-http = { version = "1.4.3", features = [ "hyper", ], optional = true } url = { version = "2.5.7", optional = true } -hyper-tungstenite = "0.19.0" [dev-dependencies] tempfile = "3.23.0" From 460d7d4a1c92cea98a3f850553862cc94f236c81 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 16 Feb 2026 12:47:29 +0100 Subject: [PATCH 4/5] HTTP2 client prefers http2 but still offers http1.1 fallback --- src/lib.rs | 3 ++- src/main.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5f4bb33..379307d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -409,7 +409,8 @@ impl ProxyClient { ) -> Result { client_config.alpn_protocols = match http_mode { HttpVersion::Http1 => vec![ALPN_HTTP11.to_vec()], - HttpVersion::Http2 => vec![ALPN_H2.to_vec()], + // Prefer HTTP/2 but allow HTTP/1.1 fallback if the server does not support h2. + HttpVersion::Http2 => vec![ALPN_H2.to_vec(), ALPN_HTTP11.to_vec()], }; let attested_tls_client = AttestedTlsClient::new_with_tls_config( diff --git a/src/main.rs b/src/main.rs index c03d8ba..97492c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -77,7 +77,7 @@ enum CliCommand { /// Enables verification of self-signed TLS certificates #[arg(long)] allow_self_signed: bool, - /// HTTP mode for the proxy client: http1 supports WS upgrades, http2 does not + /// HTTP mode for the proxy client: http1 is HTTP/1.1 only, http2 prefers HTTP/2 with HTTP/1.1 fallback #[arg(long, value_enum, default_value = "http2")] http_mode: ClientHttpMode, }, From 13d6693154935e49d15462de224fba3ea02f9215 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 16 Feb 2026 13:20:42 +0100 Subject: [PATCH 5/5] WS upgrades use a fresh connection to avoid meddling with other client connections --- src/lib.rs | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 379307d..49ba101 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -517,7 +517,9 @@ impl ProxyClient { ATTESTATION_TYPE_HEADER, remote_attestation_type.as_str(), ); - (Ok(resp.map(|b| b.boxed())), false) + let should_reconnect = + resp.status() == hyper::StatusCode::SWITCHING_PROTOCOLS; + (Ok(resp.map(|b| b.boxed())), should_reconnect) } Err(e) => { warn!("Failed to send request to proxy-server: {e}"); @@ -535,7 +537,7 @@ impl ProxyClient { if should_reconnect { // Leave the inner loop and continue on the reconnect loop - warn!("Reconnecting to proxy-server due to failed request"); + warn!("Reconnecting to proxy-server to rotate upstream connection"); break; } } else { @@ -849,6 +851,10 @@ mod tests { #[cfg(feature = "ws")] use futures_util::{SinkExt, StreamExt}; #[cfg(feature = "ws")] + use std::sync::atomic::{AtomicUsize, Ordering}; + #[cfg(feature = "ws")] + use std::sync::Arc; + #[cfg(feature = "ws")] use tokio_tungstenite::{accept_async, connect_async, tungstenite::Message}; // Server has mock DCAP, client has no attestation and no client auth @@ -871,7 +877,6 @@ mod tests { .unwrap(); let proxy_addr = proxy_server.local_addr().unwrap(); - tokio::spawn(async move { proxy_server.accept().await.unwrap(); }); @@ -1240,7 +1245,6 @@ mod tests { .unwrap(); let proxy_addr = proxy_server.local_addr().unwrap(); - tokio::spawn(async move { proxy_server.accept().await.unwrap(); }); @@ -1551,4 +1555,101 @@ mod tests { let msg = ws.next().await.unwrap().unwrap(); assert_eq!(msg.to_text().unwrap(), "ping"); } + + #[cfg(feature = "ws")] + #[tokio::test] + async fn http_proxy_websocket_upgrade_two_clients_concurrent() { + init_tracing(); + + let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = target_listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _addr) = target_listener.accept().await.unwrap(); + tokio::spawn(async move { + let mut ws = accept_async(stream).await.unwrap(); + while let Some(msg) = ws.next().await { + let msg = msg.unwrap(); + if msg.is_text() { + ws.send(msg).await.unwrap(); + } + } + }); + } + }); + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + let upstream_accept_count = Arc::new(AtomicUsize::new(0)); + let upstream_accept_count_clone = upstream_accept_count.clone(); + + tokio::spawn(async move { + loop { + match proxy_server.accept().await { + Ok(_handle) => { + upstream_accept_count_clone.fetch_add(1, Ordering::SeqCst); + } + Err(_) => break, + } + } + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + HttpVersion::Http1, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + tokio::spawn(async move { + loop { + if proxy_client.accept().await.is_err() { + break; + } + } + }); + + let (mut ws1, _response1) = connect_async(format!("ws://{}", proxy_client_addr)) + .await + .unwrap(); + ws1.send(Message::Text("ping-1".into())).await.unwrap(); + let msg1 = ws1.next().await.unwrap().unwrap(); + assert_eq!(msg1.to_text().unwrap(), "ping-1"); + + let (mut ws2, _response2) = connect_async(format!("ws://{}", proxy_client_addr)) + .await + .unwrap(); + ws2.send(Message::Text("ping-2".into())).await.unwrap(); + let msg2 = ws2.next().await.unwrap().unwrap(); + assert_eq!(msg2.to_text().unwrap(), "ping-2"); + + ws1.send(Message::Text("ping-after-second-ws".into())) + .await + .unwrap(); + let msg_after_second_ws = ws1.next().await.unwrap().unwrap(); + assert_eq!(msg_after_second_ws.to_text().unwrap(), "ping-after-second-ws"); + + tokio::time::sleep(Duration::from_millis(200)).await; + assert!(upstream_accept_count.load(Ordering::SeqCst) >= 2); + } }