Skip to content

Commit aa7d29e

Browse files
committed
Add new_server_side_from_acceptor to allow init from a pre-existing Acceptor
1 parent 0fe165a commit aa7d29e

File tree

1 file changed

+79
-4
lines changed

1 file changed

+79
-4
lines changed

src/stream.rs

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ impl TlsStream {
175175
}
176176

177177
async fn accept(
178+
mut acceptor: Acceptor,
178179
tcp_handshake: &TcpStream,
179180
server_config_provider: ServerConfigProvider,
180181
) -> Result<ServerConnection, io::Error> {
181-
let mut acceptor = Acceptor::default();
182182
loop {
183183
tcp_handshake.readable().await?;
184184
// Stop if connection was closed by client
@@ -234,6 +234,7 @@ impl TlsStream {
234234
}
235235

236236
fn new_server_acceptor(
237+
acceptor: Acceptor,
237238
tcp: TcpStream,
238239
server_config_provider: ServerConfigProvider,
239240
buffer_size: Option<NonZeroUsize>,
@@ -250,7 +251,8 @@ impl TlsStream {
250251
let handshake_send = handshake.clone();
251252

252253
let handle = spawn(async move {
253-
let tls = Self::accept(&tcp_handshake, server_config_provider).await;
254+
let tls =
255+
Self::accept(acceptor, &tcp_handshake, server_config_provider).await;
254256
let res = send_handshake(
255257
tcp_handshake,
256258
tls.map(rustls::Connection::Server),
@@ -352,6 +354,28 @@ impl TlsStream {
352354
buffer_size: Option<NonZeroUsize>,
353355
) -> Self {
354356
Self::new_server_acceptor(
357+
Acceptor::default(),
358+
tcp,
359+
server_config_provider,
360+
buffer_size,
361+
TestOptions::default(),
362+
)
363+
}
364+
365+
/// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically
366+
/// based on the [`ClientHello`] message. This may be used to provide a different server
367+
/// certificate or ALPN configuration depending on the requested hostname.
368+
///
369+
/// This allows the caller to provide an [`Acceptor`] which may be non-default in some
370+
/// way, perhaps stuffed with prefix bytes or a full handshake to emulate.
371+
pub fn new_server_side_from_acceptor(
372+
acceptor: Acceptor,
373+
tcp: TcpStream,
374+
server_config_provider: ServerConfigProvider,
375+
buffer_size: Option<NonZeroUsize>,
376+
) -> Self {
377+
Self::new_server_acceptor(
378+
acceptor,
355379
tcp,
356380
server_config_provider,
357381
buffer_size,
@@ -1353,6 +1377,53 @@ pub(super) mod tests {
13531377
(server, client)
13541378
}
13551379

1380+
async fn tls_pair_alpn_from_acceptor(
1381+
server_alpn: fn(
1382+
ClientHello,
1383+
) -> Result<&'static [&'static str], &'static str>,
1384+
server_buffer_size: Option<NonZeroUsize>,
1385+
client_alpn: &[&str],
1386+
client_buffer_size: Option<NonZeroUsize>,
1387+
) -> (TlsStream, TlsStream) {
1388+
let (mut server, client) = tcp_pair().await;
1389+
1390+
// Create the client first because we need the ClientHello. This will
1391+
// boot the client's handshake task and write to the socket.
1392+
let client = TlsStream::new_client_side_test_options(
1393+
client,
1394+
client_config(client_alpn).into(),
1395+
"example.com".try_into().unwrap(),
1396+
client_buffer_size,
1397+
TestOptions::default(),
1398+
);
1399+
1400+
// Read 8 bytes from the start of the server connection and then
1401+
// feed them to an Acceptor. Pass that acceptor when we create the
1402+
// TlsStream which will populate the rest of the ClientHello and
1403+
// properly handshake.
1404+
let mut prefix = [0; 8];
1405+
server
1406+
.read_exact(&mut prefix)
1407+
.await
1408+
.expect("Failed to read prefix");
1409+
let mut acceptor = Acceptor::default();
1410+
assert_eq!(
1411+
acceptor.read_tls(&mut prefix.as_slice()).unwrap(),
1412+
prefix.len()
1413+
);
1414+
1415+
let server = TlsStream::new_server_side_from_acceptor(
1416+
acceptor,
1417+
server,
1418+
Arc::new(move |client_hello| {
1419+
Box::pin(make_config(server_alpn(client_hello)))
1420+
}),
1421+
server_buffer_size,
1422+
);
1423+
1424+
(server, client)
1425+
}
1426+
13561427
async fn tls_pair_handshake_buffer_size(
13571428
server_buffer_size: Option<NonZeroUsize>,
13581429
client_buffer_size: Option<NonZeroUsize>,
@@ -1485,9 +1556,13 @@ pub(super) mod tests {
14851556
#[ntest::timeout(60000)]
14861557
async fn test_client_server_alpn_acceptor(
14871558
#[case] alpn: &'static str,
1559+
#[values(true, false)] use_from: bool,
14881560
) -> TestResult {
1489-
let (mut server, mut client) =
1490-
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await;
1561+
let (mut server, mut client) = if use_from {
1562+
tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await
1563+
} else {
1564+
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await
1565+
};
14911566
let a = spawn(async move {
14921567
if alpn == "c" {
14931568
server.handshake().await.expect_err("expected failure");

0 commit comments

Comments
 (0)