From 3c01ee70d781d83e7dca8990071919c28e364a7e Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 27 Jan 2025 16:43:54 -0500 Subject: [PATCH] Fix race condition between handshake and waker registration --- src/stream.rs | 201 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 134 insertions(+), 67 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index 2848201..2ba6440 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -24,14 +24,12 @@ use rustls::ClientConnection; use rustls::Connection; use rustls::ServerConfig; use rustls::ServerConnection; -use std::cell::Cell; use std::fmt::Debug; use std::io; use std::io::ErrorKind; use std::io::Write; use std::num::NonZeroUsize; use std::pin::Pin; -use std::rc::Rc; use std::sync::Arc; use std::sync::Mutex; use std::task::ready; @@ -46,19 +44,66 @@ use tokio::task::spawn_blocking; use tokio::task::JoinError; use tokio::task::JoinHandle; -#[derive(Clone, Default)] -struct SharedWaker(Rc>>); -unsafe impl Send for SharedWaker {} +/// The handshake may block read and write operations and requires us to track +/// which wakers are pending so that we can wake them to re-poll their +/// operations after the handshake completes. +#[derive(Clone)] +struct DeferredWakers { + wakers: Arc>, +} + +#[derive(Default)] +enum DeferredWakersInner { + /// If the deferred wakers have been woken already, we don't want + /// to re-register them and instead just wake them in place to + /// prevent races. + #[default] + Woke, + /// No deferred wakers have been woken. + Pending(Option, Option), +} -impl SharedWaker { +impl DeferredWakers { pub fn wake(&self) { - if let Some(waker) = self.0.take() { - waker.wake(); + match std::mem::take(&mut *self.wakers.lock().unwrap()) { + DeferredWakersInner::Pending(mut read, mut write) => { + if let Some(read) = read.take() { + read.wake(); + } + if let Some(write) = write.take() { + write.wake(); + } + } + DeferredWakersInner::Woke => {} + } + } + + /// Register the read waker if pending, or wake immediately if the deferred wakers have been woken. + pub fn set_read_waker(&self, waker: &Waker) { + let mut lock = self.wakers.lock().unwrap(); + match &mut *lock { + DeferredWakersInner::Pending(read, _write) => *read = Some(waker.clone()), + DeferredWakersInner::Woke => waker.wake_by_ref(), + } + } + + /// Register the write waker if pending, or wake immediately if the deferred wakers have been woken. + pub fn set_write_waker(&self, waker: &Waker) { + let mut lock = self.wakers.lock().unwrap(); + match &mut *lock { + DeferredWakersInner::Pending(_read, write) => { + *write = Some(waker.clone()) + } + DeferredWakersInner::Woke => waker.wake_by_ref(), } } +} - pub fn set_waker(&self, waker: &Waker) { - self.0.set(Some(waker.clone())) +impl Default for DeferredWakers { + fn default() -> Self { + Self { + wakers: Arc::new(Mutex::new(DeferredWakersInner::Pending(None, None))), + } } } @@ -74,8 +119,7 @@ enum TlsStreamState { // TODO(mmastrac): We should be buffered in the Connection, not the Vec, as this results in a double-copy. Handshaking { handle: JoinHandle>, - read_waker: SharedWaker, - write_waker: SharedWaker, + wakers: DeferredWakers, write_buf: Vec, tcp: Arc, }, @@ -141,10 +185,8 @@ impl TlsStream { ) -> Self { tls.set_buffer_limit(buffer_size.map(|s| s.get())); let handshake = Arc::new(HandshakeWatch::default()); - let read_waker = SharedWaker::default(); - let write_waker = SharedWaker::default(); - let read_waker_clone = read_waker.clone(); - let write_waker_clone = write_waker.clone(); + let wakers = DeferredWakers::default(); + let wakers_clone = wakers.clone(); let tcp = Arc::new(tcp); let tcp_handshake = tcp.clone(); @@ -155,8 +197,7 @@ impl TlsStream { .await; // We may have read/writes blocked on the handshake, so wake them all up - read_waker_clone.wake(); - write_waker_clone.wake(); + wakers_clone.wake(); res }); @@ -164,8 +205,7 @@ impl TlsStream { Self { state: TlsStreamState::Handshaking { handle, - read_waker, - write_waker, + wakers, write_buf: vec![], tcp, }, @@ -241,10 +281,8 @@ impl TlsStream { test_options: TestOptions, ) -> Self { let handshake = Arc::new(HandshakeWatch::default()); - let read_waker = SharedWaker::default(); - let write_waker = SharedWaker::default(); - let read_waker_clone = read_waker.clone(); - let write_waker_clone = write_waker.clone(); + let wakers = DeferredWakers::default(); + let wakers_clone = wakers.clone(); let tcp = Arc::new(tcp); let tcp_handshake = tcp.clone(); @@ -262,8 +300,7 @@ impl TlsStream { .await; // We may have read/writes blocked on the handshake, so wake them all up - read_waker_clone.wake(); - write_waker_clone.wake(); + wakers_clone.wake(); res }); @@ -271,8 +308,7 @@ impl TlsStream { Self { state: TlsStreamState::Handshaking { handle, - read_waker, - write_waker, + wakers, write_buf: vec![], tcp, }, @@ -520,8 +556,7 @@ impl TlsStream { trace!("finalize handshake"); match std::mem::replace(&mut self.state, TlsStreamState::Closed) { TlsStreamState::Handshaking { - read_waker, - write_waker, + wakers, write_buf: buf, .. } => { @@ -549,8 +584,7 @@ impl TlsStream { // We need to save all the data we wrote before the connection. The stream has an internal buffer // that matches our buffer, so it can accept it all. stm.write_buf_fully(&buf); - read_waker.wake(); - write_waker.wake(); + wakers.wake(); self.state = TlsStreamState::Open(stm); Ok(()) } @@ -621,13 +655,11 @@ impl TlsStream { match state { TlsStreamState::Handshaking { handle, - read_waker, - write_waker, + wakers, write_buf: buf, .. } => { - read_waker.wake(); - write_waker.wake(); + wakers.wake(); match handle.await { Ok(Ok(result)) => { // TODO(mmastrac): if we split ConnectionStream we can remove this Arc and use reclaim2 @@ -748,9 +780,7 @@ impl AsyncRead for TlsStream { ) -> Poll> { loop { break match &mut self.state { - TlsStreamState::Handshaking { - handle, read_waker, .. - } => { + TlsStreamState::Handshaking { handle, wakers, .. } => { // If the handshake completed, we want to finalize it and then continue if handle.is_finished() { // This may return Pending if we've exhausted the co-op budget @@ -760,7 +790,8 @@ impl AsyncRead for TlsStream { } // Handshake is still blocking us - read_waker.set_waker(cx.waker()); + wakers.set_read_waker(cx.waker()); + Poll::Pending } TlsStreamState::Open(ref mut stm) => { @@ -791,7 +822,7 @@ impl AsyncWrite for TlsStream { break match &mut self.state { TlsStreamState::Handshaking { handle, - write_waker, + wakers, write_buf, .. } => { @@ -802,11 +833,12 @@ impl AsyncWrite for TlsStream { self.finalize_handshake(res)?; continue; } + if let Some(buffer_size) = buffer_size { let remaining = buffer_size.get() - write_buf.len(); if remaining == 0 { // No room to write, so store the waker for whenever the handshake is done - write_waker.set_waker(cx.waker()); + wakers.set_write_waker(cx.waker()); trace!("write limit"); Poll::Pending } else { @@ -845,7 +877,7 @@ impl AsyncWrite for TlsStream { break match &mut self.state { TlsStreamState::Handshaking { handle, - write_waker, + wakers, write_buf, .. } => { @@ -860,7 +892,7 @@ impl AsyncWrite for TlsStream { let mut remaining = buffer_size.get() - write_buf.len(); if remaining == 0 { // No room to write, so store the waker for whenever the handshake is done - write_waker.set_waker(cx.waker()); + wakers.set_write_waker(cx.waker()); trace!("write limit"); Poll::Pending } else { @@ -900,11 +932,7 @@ impl AsyncWrite for TlsStream { ) -> Poll> { loop { break match &mut self.state { - TlsStreamState::Handshaking { - write_waker, - handle, - .. - } => { + TlsStreamState::Handshaking { wakers, handle, .. } => { // If the handshake completed, we want to finalize it and then continue if handle.is_finished() { // This may return Pending if we've exhausted the co-op budget @@ -913,7 +941,7 @@ impl AsyncWrite for TlsStream { continue; } - write_waker.set_waker(cx.waker()); + wakers.set_write_waker(cx.waker()); Poll::Pending } TlsStreamState::Open(stm) => stm.poll_flush(cx), @@ -1287,6 +1315,7 @@ pub(super) mod tests { delay_handshake: bool, slow_server: bool, slow_client: bool, + buffer: bool, ) -> (TlsStream, TlsStream) { let (server, client) = tcp_pair().await; let server_test_options = TestOptions { @@ -1299,17 +1328,23 @@ pub(super) mod tests { slow_handshake_read: slow_client, slow_handshake_write: slow_client, }; + let buffer_size = if buffer { + NonZeroUsize::new(1024) + } else { + None + }; + let server = TlsStream::new_server_side_test_options( server, server_config(&[]).into(), - None, + buffer_size, server_test_options, ); let client = TlsStream::new_client_side_test_options( client, client_config(&[]).into(), "example.com".try_into().unwrap(), - None, + buffer_size, client_test_options, ); @@ -1463,17 +1498,14 @@ pub(super) mod tests { /// Test that automatic state transition works: send and receive work as expected without waiting /// for the handshake #[rstest] - #[case(false, false)] - #[case(false, true)] - #[case(true, false)] - #[case(true, true)] #[tokio::test] async fn test_client_server( - #[case] server_slow: bool, - #[case] client_slow: bool, + #[values(true, false)] server_slow: bool, + #[values(true, false)] client_slow: bool, + #[values(true, false)] buffer: bool, ) -> TestResult { let (mut server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, buffer).await; let a = spawn(async move { server.write_all(b"hello?").await.unwrap(); let mut buf = [0; 6]; @@ -1503,6 +1535,40 @@ pub(super) mod tests { Ok(()) } + #[rstest] + #[tokio::test(flavor = "multi_thread")] + #[ntest::timeout(60000)] + async fn test_read_with_buffered_write( + #[values(true, false)] delay_handshake: bool, + #[values(true, false)] slow_server: bool, + #[values(true, false)] slow_client: bool, + #[values(true, false)] buffer: bool, + ) -> TestResult { + let (mut server, mut client) = tls_pair_slow_handshake( + delay_handshake, + slow_server, + slow_client, + buffer, + ) + .await; + + let a = tokio::task::spawn(async move { + server.read_u8().await.unwrap(); + server.write_u8(1).await.unwrap(); + }); + + let b = tokio::task::spawn(async move { + let buf = [0; 1024]; + client.write_all(&buf).await.unwrap(); + client.read_u8().await.unwrap(); + }); + + a.await.unwrap(); + b.await.unwrap(); + + Ok(()) + } + /// Test that the handshake works, and we get the correct ALPN negotiated values. #[tokio::test] #[ntest::timeout(60000)] @@ -1621,7 +1687,8 @@ pub(super) mod tests { #[tokio::test] async fn test_peer_and_local_addresses() { - let (server, client) = tls_pair_slow_handshake(true, true, true).await; + let (server, client) = + tls_pair_slow_handshake(true, true, true, false).await; // Use a barrier to keep the client and server sockets alive until the end let barrier = Arc::new(Barrier::new(2)); let barrier_clone = barrier.clone(); @@ -1666,7 +1733,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (mut server, client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; let a = spawn(async move { server.shutdown().await.unwrap(); // While this races the handshake, we are not going to expose a handshake EOF to the stream in a @@ -1723,7 +1790,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; let a = spawn(async move { drop(server); }); @@ -1751,7 +1818,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (mut server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; let (tx, rx) = tokio::sync::oneshot::channel(); let a = spawn(async move { server.write_all(b"hello?").await.unwrap(); @@ -1795,7 +1862,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (mut server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; let (tx, rx) = tokio::sync::oneshot::channel(); let a = spawn(async move { // Shut down after the handshake @@ -1831,7 +1898,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (mut server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; let a = spawn(async move { let mut futures = FuturesUnordered::new(); @@ -1862,7 +1929,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (server, mut client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; // The server will spawn a task to complete the handshake and then go away drop(server); client.handshake().await?; @@ -1904,7 +1971,7 @@ pub(super) mod tests { #[case] client_slow: bool, ) -> TestResult { let (mut server, client) = - tls_pair_slow_handshake(false, server_slow, client_slow).await; + tls_pair_slow_handshake(false, server_slow, client_slow, false).await; drop(client); // The client will spawn a task to complete the handshake and then go away server.handshake().await?;