From aa7d29ee22b91c49ee3ecb3b0a0232f8c8a82ba0 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 27 Jan 2025 15:05:49 -0500 Subject: [PATCH] Add new_server_side_from_acceptor to allow init from a pre-existing Acceptor --- src/stream.rs | 83 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 4 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index df8af43..2848201 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -175,10 +175,10 @@ impl TlsStream { } async fn accept( + mut acceptor: Acceptor, tcp_handshake: &TcpStream, server_config_provider: ServerConfigProvider, ) -> Result { - let mut acceptor = Acceptor::default(); loop { tcp_handshake.readable().await?; // Stop if connection was closed by client @@ -234,6 +234,7 @@ impl TlsStream { } fn new_server_acceptor( + acceptor: Acceptor, tcp: TcpStream, server_config_provider: ServerConfigProvider, buffer_size: Option, @@ -250,7 +251,8 @@ impl TlsStream { let handshake_send = handshake.clone(); let handle = spawn(async move { - let tls = Self::accept(&tcp_handshake, server_config_provider).await; + let tls = + Self::accept(acceptor, &tcp_handshake, server_config_provider).await; let res = send_handshake( tcp_handshake, tls.map(rustls::Connection::Server), @@ -352,6 +354,28 @@ impl TlsStream { buffer_size: Option, ) -> Self { Self::new_server_acceptor( + Acceptor::default(), + tcp, + server_config_provider, + buffer_size, + TestOptions::default(), + ) + } + + /// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically + /// based on the [`ClientHello`] message. This may be used to provide a different server + /// certificate or ALPN configuration depending on the requested hostname. + /// + /// This allows the caller to provide an [`Acceptor`] which may be non-default in some + /// way, perhaps stuffed with prefix bytes or a full handshake to emulate. + pub fn new_server_side_from_acceptor( + acceptor: Acceptor, + tcp: TcpStream, + server_config_provider: ServerConfigProvider, + buffer_size: Option, + ) -> Self { + Self::new_server_acceptor( + acceptor, tcp, server_config_provider, buffer_size, @@ -1353,6 +1377,53 @@ pub(super) mod tests { (server, client) } + async fn tls_pair_alpn_from_acceptor( + server_alpn: fn( + ClientHello, + ) -> Result<&'static [&'static str], &'static str>, + server_buffer_size: Option, + client_alpn: &[&str], + client_buffer_size: Option, + ) -> (TlsStream, TlsStream) { + let (mut server, client) = tcp_pair().await; + + // Create the client first because we need the ClientHello. This will + // boot the client's handshake task and write to the socket. + let client = TlsStream::new_client_side_test_options( + client, + client_config(client_alpn).into(), + "example.com".try_into().unwrap(), + client_buffer_size, + TestOptions::default(), + ); + + // Read 8 bytes from the start of the server connection and then + // feed them to an Acceptor. Pass that acceptor when we create the + // TlsStream which will populate the rest of the ClientHello and + // properly handshake. + let mut prefix = [0; 8]; + server + .read_exact(&mut prefix) + .await + .expect("Failed to read prefix"); + let mut acceptor = Acceptor::default(); + assert_eq!( + acceptor.read_tls(&mut prefix.as_slice()).unwrap(), + prefix.len() + ); + + let server = TlsStream::new_server_side_from_acceptor( + acceptor, + server, + Arc::new(move |client_hello| { + Box::pin(make_config(server_alpn(client_hello))) + }), + server_buffer_size, + ); + + (server, client) + } + async fn tls_pair_handshake_buffer_size( server_buffer_size: Option, client_buffer_size: Option, @@ -1485,9 +1556,13 @@ pub(super) mod tests { #[ntest::timeout(60000)] async fn test_client_server_alpn_acceptor( #[case] alpn: &'static str, + #[values(true, false)] use_from: bool, ) -> TestResult { - let (mut server, mut client) = - tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await; + let (mut server, mut client) = if use_from { + tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await + } else { + tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await + }; let a = spawn(async move { if alpn == "c" { server.handshake().await.expect_err("expected failure");