Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new_server_side_from_acceptor to allow init from a pre-existing Acceptor #33

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add new_server_side_from_acceptor to allow init from a pre-existing A…
…cceptor
mmastrac committed Jan 27, 2025
commit aa7d29ee22b91c49ee3ecb3b0a0232f8c8a82ba0
83 changes: 79 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
@@ -175,10 +175,10 @@ impl TlsStream {
}

async fn accept(
mut acceptor: Acceptor,
tcp_handshake: &TcpStream,
server_config_provider: ServerConfigProvider,
) -> Result<ServerConnection, io::Error> {
let mut acceptor = Acceptor::default();
loop {
tcp_handshake.readable().await?;
// Stop if connection was closed by client
@@ -234,6 +234,7 @@ impl TlsStream {
}

fn new_server_acceptor(
acceptor: Acceptor,
tcp: TcpStream,
server_config_provider: ServerConfigProvider,
buffer_size: Option<NonZeroUsize>,
@@ -250,7 +251,8 @@ impl TlsStream {
let handshake_send = handshake.clone();

let handle = spawn(async move {
let tls = Self::accept(&tcp_handshake, server_config_provider).await;
let tls =
Self::accept(acceptor, &tcp_handshake, server_config_provider).await;
let res = send_handshake(
tcp_handshake,
tls.map(rustls::Connection::Server),
@@ -352,6 +354,28 @@ impl TlsStream {
buffer_size: Option<NonZeroUsize>,
) -> Self {
Self::new_server_acceptor(
Acceptor::default(),
tcp,
server_config_provider,
buffer_size,
TestOptions::default(),
)
}

/// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically
/// based on the [`ClientHello`] message. This may be used to provide a different server
/// certificate or ALPN configuration depending on the requested hostname.
///
/// This allows the caller to provide an [`Acceptor`] which may be non-default in some
/// way, perhaps stuffed with prefix bytes or a full handshake to emulate.
pub fn new_server_side_from_acceptor(
acceptor: Acceptor,
tcp: TcpStream,
server_config_provider: ServerConfigProvider,
buffer_size: Option<NonZeroUsize>,
) -> Self {
Self::new_server_acceptor(
acceptor,
tcp,
server_config_provider,
buffer_size,
@@ -1353,6 +1377,53 @@ pub(super) mod tests {
(server, client)
}

async fn tls_pair_alpn_from_acceptor(
server_alpn: fn(
ClientHello,
) -> Result<&'static [&'static str], &'static str>,
server_buffer_size: Option<NonZeroUsize>,
client_alpn: &[&str],
client_buffer_size: Option<NonZeroUsize>,
) -> (TlsStream, TlsStream) {
let (mut server, client) = tcp_pair().await;

// Create the client first because we need the ClientHello. This will
// boot the client's handshake task and write to the socket.
let client = TlsStream::new_client_side_test_options(
client,
client_config(client_alpn).into(),
"example.com".try_into().unwrap(),
client_buffer_size,
TestOptions::default(),
);

// Read 8 bytes from the start of the server connection and then
// feed them to an Acceptor. Pass that acceptor when we create the
// TlsStream which will populate the rest of the ClientHello and
// properly handshake.
let mut prefix = [0; 8];
server
.read_exact(&mut prefix)
.await
.expect("Failed to read prefix");
let mut acceptor = Acceptor::default();
assert_eq!(
acceptor.read_tls(&mut prefix.as_slice()).unwrap(),
prefix.len()
);

let server = TlsStream::new_server_side_from_acceptor(
acceptor,
server,
Arc::new(move |client_hello| {
Box::pin(make_config(server_alpn(client_hello)))
}),
server_buffer_size,
);

(server, client)
}

async fn tls_pair_handshake_buffer_size(
server_buffer_size: Option<NonZeroUsize>,
client_buffer_size: Option<NonZeroUsize>,
@@ -1485,9 +1556,13 @@ pub(super) mod tests {
#[ntest::timeout(60000)]
async fn test_client_server_alpn_acceptor(
#[case] alpn: &'static str,
#[values(true, false)] use_from: bool,
) -> TestResult {
let (mut server, mut client) =
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await;
let (mut server, mut client) = if use_from {
tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await
} else {
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await
};
let a = spawn(async move {
if alpn == "c" {
server.handshake().await.expect_err("expected failure");