@@ -433,6 +433,14 @@ impl TlsStream {
433
433
( read, write)
434
434
}
435
435
436
+ /// If the stream is open, returns the underlying rustls connection.
437
+ pub fn connection ( & self ) -> Option < & rustls:: Connection > {
438
+ match & self . state {
439
+ TlsStreamState :: Open ( stm) => Some ( stm. connection ( ) ) ,
440
+ _ => None ,
441
+ }
442
+ }
443
+
436
444
pub async fn into_inner ( mut self ) -> io:: Result < ( TcpStream , Connection ) > {
437
445
poll_fn ( |cx| self . poll_pending_handshake ( cx) ) . await ?;
438
446
match std:: mem:: replace ( & mut self . state , TlsStreamState :: Closed ) {
@@ -489,12 +497,14 @@ impl TlsStream {
489
497
& mut self ,
490
498
cx : & mut Context ,
491
499
) -> Poll < io:: Result < TlsHandshake > > {
500
+ // Transition to the open state if necessary
501
+ ready ! ( self . poll_pending_handshake( cx) ?) ;
502
+
492
503
// TODO(mmastrac): Handshake shouldn't need to be cloned
493
504
match & * self . handshake . handshake . lock ( ) . unwrap ( ) {
494
505
None => {
495
506
// Register both wakers just in case we get split
496
507
self . handshake . rx_waker . register ( cx. waker ( ) ) ;
497
- self . handshake . tx_waker . register ( cx. waker ( ) ) ;
498
508
Poll :: Pending
499
509
}
500
510
Some ( handshake) => Poll :: Ready ( clone_result ( handshake) ) ,
@@ -1619,6 +1629,25 @@ pub(super) mod tests {
1619
1629
Ok ( ( ) )
1620
1630
}
1621
1631
1632
+ /// Test that the handshake fails, and we get the correct errors on both ends.
1633
+ #[ tokio:: test]
1634
+ #[ ntest:: timeout( 60000 ) ]
1635
+ async fn test_client_server_raw_connection ( ) -> TestResult {
1636
+ let ( mut server, mut client) =
1637
+ tls_pair_alpn ( & [ "a" ] , None , & [ "a" ] , None ) . await ;
1638
+
1639
+ assert ! ( server. connection( ) . is_none( ) ) ;
1640
+ assert ! ( client. connection( ) . is_none( ) ) ;
1641
+
1642
+ server. handshake ( ) . await ?;
1643
+ client. handshake ( ) . await ?;
1644
+
1645
+ assert ! ( server. connection( ) . is_some( ) ) ;
1646
+ assert ! ( client. connection( ) . is_some( ) ) ;
1647
+
1648
+ Ok ( ( ) )
1649
+ }
1650
+
1622
1651
#[ tokio:: test]
1623
1652
async fn test_peer_and_local_addresses ( ) {
1624
1653
let ( server, client) = tls_pair_slow_handshake ( true , true , true ) . await ;
0 commit comments