Skip to content

Commit

Permalink
Add new_server_side_from_acceptor to allow init from a pre-existing A…
Browse files Browse the repository at this point in the history
…cceptor (#33)

In certain cases, a TLS server may need to pre-process the incoming bytes 
(eg: it may want to peek to see if this is raw HTTP, or HTTP over SSL) before
deciding to hand it off to a TlsStream. Because TlsStream uses raw
TcpSocket, there's no way to re-inject any sniffed bytes that were read()
rather than peek()'d from the socket.

This adds a new new_server_side_from_acceptor where an initial Acceptor
can be passed in. It may contain either the full handshake or just a few prefix
bytes that were sniffed from the socket earlier.

We test this by creating a new TLS pair where the socket sniffs the first 8
bytes of the client handshake before firing up its TlsStream.
  • Loading branch information
mmastrac authored Jan 27, 2025
1 parent 0fe165a commit ee5798c
Showing 1 changed file with 79 additions and 4 deletions.
83 changes: 79 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ impl TlsStream {
}

async fn accept(
mut acceptor: Acceptor,
tcp_handshake: &TcpStream,
server_config_provider: ServerConfigProvider,
) -> Result<ServerConnection, io::Error> {
let mut acceptor = Acceptor::default();
loop {
tcp_handshake.readable().await?;
// Stop if connection was closed by client
Expand Down Expand Up @@ -234,6 +234,7 @@ impl TlsStream {
}

fn new_server_acceptor(
acceptor: Acceptor,
tcp: TcpStream,
server_config_provider: ServerConfigProvider,
buffer_size: Option<NonZeroUsize>,
Expand All @@ -250,7 +251,8 @@ impl TlsStream {
let handshake_send = handshake.clone();

let handle = spawn(async move {
let tls = Self::accept(&tcp_handshake, server_config_provider).await;
let tls =
Self::accept(acceptor, &tcp_handshake, server_config_provider).await;
let res = send_handshake(
tcp_handshake,
tls.map(rustls::Connection::Server),
Expand Down Expand Up @@ -352,6 +354,28 @@ impl TlsStream {
buffer_size: Option<NonZeroUsize>,
) -> Self {
Self::new_server_acceptor(
Acceptor::default(),
tcp,
server_config_provider,
buffer_size,
TestOptions::default(),
)
}

/// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically
/// based on the [`ClientHello`] message. This may be used to provide a different server
/// certificate or ALPN configuration depending on the requested hostname.
///
/// This allows the caller to provide an [`Acceptor`] which may be non-default in some
/// way, perhaps stuffed with prefix bytes or a full handshake to emulate.
pub fn new_server_side_from_acceptor(
acceptor: Acceptor,
tcp: TcpStream,
server_config_provider: ServerConfigProvider,
buffer_size: Option<NonZeroUsize>,
) -> Self {
Self::new_server_acceptor(
acceptor,
tcp,
server_config_provider,
buffer_size,
Expand Down Expand Up @@ -1353,6 +1377,53 @@ pub(super) mod tests {
(server, client)
}

async fn tls_pair_alpn_from_acceptor(
server_alpn: fn(
ClientHello,
) -> Result<&'static [&'static str], &'static str>,
server_buffer_size: Option<NonZeroUsize>,
client_alpn: &[&str],
client_buffer_size: Option<NonZeroUsize>,
) -> (TlsStream, TlsStream) {
let (mut server, client) = tcp_pair().await;

// Create the client first because we need the ClientHello. This will
// boot the client's handshake task and write to the socket.
let client = TlsStream::new_client_side_test_options(
client,
client_config(client_alpn).into(),
"example.com".try_into().unwrap(),
client_buffer_size,
TestOptions::default(),
);

// Read 8 bytes from the start of the server connection and then
// feed them to an Acceptor. Pass that acceptor when we create the
// TlsStream which will populate the rest of the ClientHello and
// properly handshake.
let mut prefix = [0; 8];
server
.read_exact(&mut prefix)
.await
.expect("Failed to read prefix");
let mut acceptor = Acceptor::default();
assert_eq!(
acceptor.read_tls(&mut prefix.as_slice()).unwrap(),
prefix.len()
);

let server = TlsStream::new_server_side_from_acceptor(
acceptor,
server,
Arc::new(move |client_hello| {
Box::pin(make_config(server_alpn(client_hello)))
}),
server_buffer_size,
);

(server, client)
}

async fn tls_pair_handshake_buffer_size(
server_buffer_size: Option<NonZeroUsize>,
client_buffer_size: Option<NonZeroUsize>,
Expand Down Expand Up @@ -1485,9 +1556,13 @@ pub(super) mod tests {
#[ntest::timeout(60000)]
async fn test_client_server_alpn_acceptor(
#[case] alpn: &'static str,
#[values(true, false)] use_from: bool,
) -> TestResult {
let (mut server, mut client) =
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await;
let (mut server, mut client) = if use_from {
tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await
} else {
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await
};
let a = spawn(async move {
if alpn == "c" {
server.handshake().await.expect_err("expected failure");
Expand Down

0 comments on commit ee5798c

Please sign in to comment.