From f7dcf2d89bbb1d1849d628ad79cfdccc875a6b04 Mon Sep 17 00:00:00 2001 From: Joe Parks Date: Thu, 12 Mar 2026 12:18:29 -0700 Subject: [PATCH] feat: add custom socket transport support for Postgres and MySQL Add `connect_socket()` methods to `PgConnection` and `MySqlConnection` that accept any pre-connected socket implementing the `Socket` trait. This enables using custom transport layers (e.g., vsock for AWS Nitro Enclaves, QUIC, or other non-TCP/UDS transports) without forking sqlx. Re-export `Socket` and `ReadBuf` traits from `sqlx::net` so users can implement custom socket types. --- sqlx-core/src/net/mod.rs | 3 ++- sqlx-core/src/net/socket/mod.rs | 14 ++++++++++ sqlx-mysql/src/connection/establish.rs | 18 +++++++++++-- sqlx-mysql/src/connection/mod.rs | 32 ++++++++++++++++++++++ sqlx-mysql/src/options/connect.rs | 33 ++++++++++++++--------- sqlx-postgres/src/connection/establish.rs | 18 +++++++++++-- sqlx-postgres/src/connection/mod.rs | 30 +++++++++++++++++++++ sqlx-postgres/src/connection/stream.rs | 16 +++++++++++ src/lib.rs | 6 +++++ 9 files changed, 152 insertions(+), 18 deletions(-) diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..cdf3687c44 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -2,5 +2,6 @@ mod socket; pub mod tls; pub use socket::{ - connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, + connect_socket, connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, + WriteBuffer, }; diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0f9aae61b4..72e4fb304c 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -202,6 +202,20 @@ pub async fn connect_tcp( } } +/// Connect using a pre-connected socket that implements [`Socket`]. +/// +/// This allows using custom transport layers (e.g., vsock, QUIC, or any +/// `AsyncRead + AsyncWrite` type) with SQLx database connections. +/// +/// The socket will be passed through the `with_socket` handler, which +/// typically performs TLS upgrade negotiation. +pub async fn connect_socket( + socket: S, + with_socket: Ws, +) -> crate::Result { + Ok(with_socket.with_socket(socket).await) +} + /// Open a TCP socket to `host` and `port`. /// /// If `host` is a hostname, attempt to connect to each address it resolves to. diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..c67b6690f2 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -22,7 +22,21 @@ impl MySqlConnection { let stream = handshake?; - Ok(Self { + Ok(Self::establish_with_stream(stream, options)) + } + + pub(crate) async fn establish_with_socket( + socket: S, + options: &MySqlConnectOptions, + ) -> Result { + let do_handshake = DoHandshake::new(options)?; + let stream = do_handshake.with_socket(socket).await?; + + Ok(Self::establish_with_stream(stream, options)) + } + + fn establish_with_stream(stream: MySqlStream, options: &MySqlConnectOptions) -> Self { + Self { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, @@ -30,7 +44,7 @@ impl MySqlConnection { cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), - }) + } } } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 569ad32722..cb21c4ee77 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -52,6 +52,38 @@ pub(crate) struct MySqlConnectionInner { } impl MySqlConnection { + /// Connect to a MySQL database using a pre-connected socket. + /// + /// This allows using custom transport layers such as vsock, QUIC, + /// or any type that implements [`sqlx_core::net::Socket`]. + /// + /// The provided socket will go through TLS upgrade negotiation based on the + /// SSL mode configured in `options`. + /// + /// # Example + /// + /// ```rust,ignore + /// use sqlx::mysql::{MySqlConnectOptions, MySqlConnection}; + /// + /// # async fn example() -> sqlx::Result<()> { + /// let socket: tokio::net::TcpStream = todo!(); + /// let options = MySqlConnectOptions::new() + /// .username("root") + /// .database("mydb"); + /// + /// let _conn = MySqlConnection::connect_socket(socket, &options).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect_socket( + socket: S, + options: &MySqlConnectOptions, + ) -> Result { + let mut conn = Self::establish_with_socket(socket, options).await?; + options.configure_session(&mut conn).await?; + Ok(conn) + } + pub(crate) fn in_transaction(&self) -> bool { self.inner .status_flags diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index f3b0492781..86ce390ad5 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -24,9 +24,26 @@ impl ConnectOptions for MySqlConnectOptions { { let mut conn = MySqlConnection::establish(self).await?; - // After the connection is established, we initialize by configuring a few - // connection parameters + self.configure_session(&mut conn).await?; + Ok(conn) + } + + fn log_statements(mut self, level: LevelFilter) -> Self { + self.log_settings.log_statements(level); + self + } + + fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { + self.log_settings.log_slow_statements(level, duration); + self + } +} + +impl MySqlConnectOptions { + /// After the connection is established, initialize by configuring + /// connection parameters (sql_mode, time_zone, charset). + pub(crate) async fn configure_session(&self, conn: &mut MySqlConnection) -> Result<(), Error> { // https://mariadb.com/kb/en/sql-mode/ // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. @@ -88,16 +105,6 @@ impl ConnectOptions for MySqlConnectOptions { .await?; } - Ok(conn) - } - - fn log_statements(mut self, level: LevelFilter) -> Self { - self.log_settings.log_statements(level); - self - } - - fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { - self.log_settings.log_slow_statements(level, duration); - self + Ok(()) } } diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 634b71de4b..3cdbdcad53 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -7,6 +7,7 @@ use crate::io::StatementId; use crate::message::{ Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, }; +use crate::net::Socket; use crate::{PgConnectOptions, PgConnection}; use super::PgConnectionInner; @@ -16,9 +17,22 @@ use super::PgConnectionInner; impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { - // Upgrade to TLS if we were asked to and the server supports it - let mut stream = PgStream::connect(options).await?; + let stream = PgStream::connect(options).await?; + Self::establish_with_stream(stream, options).await + } + + pub(crate) async fn establish_with_socket( + socket: S, + options: &PgConnectOptions, + ) -> Result { + let stream = PgStream::connect_socket(socket, options).await?; + Self::establish_with_stream(stream, options).await + } + async fn establish_with_stream( + mut stream: PgStream, + options: &PgConnectOptions, + ) -> Result { // To begin a session, a frontend opens a connection to the server // and sends a startup message. diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index d5db20ad05..fcefde7a8c 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -85,6 +85,36 @@ pub(crate) struct TableColumns { } impl PgConnection { + /// Connect to a PostgreSQL database using a pre-connected socket. + /// + /// This allows using custom transport layers such as vsock, QUIC, + /// or any type that implements [`sqlx_core::net::Socket`]. + /// + /// The provided socket will go through TLS upgrade negotiation based on the + /// SSL mode configured in `options`. + /// + /// # Example + /// + /// ```rust,ignore + /// use sqlx::postgres::{PgConnectOptions, PgConnection}; + /// + /// # async fn example() -> sqlx::Result<()> { + /// let socket: tokio::net::TcpStream = todo!(); + /// let options = PgConnectOptions::new() + /// .username("postgres") + /// .database("mydb"); + /// + /// let _conn = PgConnection::connect_socket(socket, &options).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect_socket( + socket: S, + options: &PgConnectOptions, + ) -> Result { + Self::establish_with_socket(socket, options).await + } + /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { self.inner.stream.server_version_num diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index e8a1aedc47..8921fe79d3 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -57,6 +57,22 @@ impl PgStream { }) } + pub(super) async fn connect_socket( + socket: S, + options: &PgConnectOptions, + ) -> Result { + let socket = net::connect_socket(socket, MaybeUpgradeTls(options)).await?; + + let socket = socket?; + + Ok(Self { + inner: BufferedSocket::new(socket), + notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, + }) + } + #[inline(always)] pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { self.write(EncodeMessage(message)) diff --git a/src/lib.rs b/src/lib.rs index 438463210d..56d23dd73f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -153,6 +153,12 @@ pub mod decode { pub use self::decode::Decode; +/// Networking traits for custom transport implementations. +pub mod net { + pub use sqlx_core::io::ReadBuf; + pub use sqlx_core::net::Socket; +} + /// Types and traits for the `query` family of functions and macros. pub mod query { pub use sqlx_core::query::{Map, Query};