Skip to content

Commit 99e000c

Browse files
committed
Allow access to underlying connection after handshake
1 parent 6269131 commit 99e000c

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/connection_stream.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ impl ConnectionStream {
104104
(Arc::try_unwrap(self.tcp).unwrap(), self.tls)
105105
}
106106

107+
pub fn connection(&self) -> &Connection {
108+
&self.tls
109+
}
110+
107111
pub(crate) fn tcp_stream(&self) -> &Arc<TcpStream> {
108112
&self.tcp
109113
}

src/stream.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,14 @@ impl TlsStream {
433433
(read, write)
434434
}
435435

436+
/// If the stream is open, returns the underlying rustls connection.
437+
pub fn connection(&self) -> Option<&rustls::Connection> {
438+
match &self.state {
439+
TlsStreamState::Open(stm) => Some(stm.connection()),
440+
_ => None,
441+
}
442+
}
443+
436444
pub async fn into_inner(mut self) -> io::Result<(TcpStream, Connection)> {
437445
poll_fn(|cx| self.poll_pending_handshake(cx)).await?;
438446
match std::mem::replace(&mut self.state, TlsStreamState::Closed) {
@@ -489,12 +497,14 @@ impl TlsStream {
489497
&mut self,
490498
cx: &mut Context,
491499
) -> Poll<io::Result<TlsHandshake>> {
500+
// Transition to the open state if necessary
501+
ready!(self.poll_pending_handshake(cx)?);
502+
492503
// TODO(mmastrac): Handshake shouldn't need to be cloned
493504
match &*self.handshake.handshake.lock().unwrap() {
494505
None => {
495506
// Register both wakers just in case we get split
496507
self.handshake.rx_waker.register(cx.waker());
497-
self.handshake.tx_waker.register(cx.waker());
498508
Poll::Pending
499509
}
500510
Some(handshake) => Poll::Ready(clone_result(handshake)),
@@ -1619,6 +1629,25 @@ pub(super) mod tests {
16191629
Ok(())
16201630
}
16211631

1632+
/// Test that the handshake fails, and we get the correct errors on both ends.
1633+
#[tokio::test]
1634+
#[ntest::timeout(60000)]
1635+
async fn test_client_server_raw_connection() -> TestResult {
1636+
let (mut server, mut client) =
1637+
tls_pair_alpn(&["a"], None, &["a"], None).await;
1638+
1639+
assert!(server.connection().is_none());
1640+
assert!(client.connection().is_none());
1641+
1642+
server.handshake().await?;
1643+
client.handshake().await?;
1644+
1645+
assert!(server.connection().is_some());
1646+
assert!(client.connection().is_some());
1647+
1648+
Ok(())
1649+
}
1650+
16221651
#[tokio::test]
16231652
async fn test_peer_and_local_addresses() {
16241653
let (server, client) = tls_pair_slow_handshake(true, true, true).await;

0 commit comments

Comments
 (0)