Skip to content
Open
4 changes: 4 additions & 0 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ stability_scope!(BETA {
/// # Warning
///
/// If the sink returns an error, part of the message may still be delivered.
/// After any error, the sink is no longer reusable and subsequent sends will
/// return [Error::Closed].
fn send(
&mut self,
bufs: impl Into<IoBufs> + Send,
Expand All @@ -604,6 +606,8 @@ stability_scope!(BETA {
/// # Warning
///
/// If the stream returns an error, partially read data may be discarded.
/// After any error, the stream is no longer reusable and subsequent receives
/// will return [Error::Closed].
fn recv(&mut self, len: usize) -> impl Future<Output = Result<IoBufs, Error>> + Send;

/// Peek at buffered data without consuming.
Expand Down
56 changes: 54 additions & 2 deletions runtime/src/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ impl Channel {
(
Sink {
channel: channel.clone(),
poisoned: false,
},
Stream {
channel,
buffer: BytesMut::new(),
poisoned: false,
},
)
}
Expand All @@ -58,15 +60,21 @@ impl Channel {
/// A mock sink that implements the Sink trait.
pub struct Sink {
channel: Arc<Mutex<Channel>>,
poisoned: bool,
}

impl SinkTrait for Sink {
async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
if self.poisoned {
return Err(Error::Closed);
}

let (os_send, data) = {
let mut channel = self.channel.lock();

// If the receiver is dead, we cannot send any more messages.
if !channel.stream_alive {
self.poisoned = true;
return Err(Error::Closed);
}

Expand Down Expand Up @@ -94,7 +102,10 @@ impl SinkTrait for Sink {
};

// Resolve the waiter.
os_send.send(data).map_err(|_| Error::SendFailed)?;
os_send.send(data).map_err(|_| {
self.poisoned = true;
Error::SendFailed
})?;
Ok(())
}
}
Expand All @@ -114,10 +125,15 @@ pub struct Stream {
channel: Arc<Mutex<Channel>>,
/// Local buffer for data that has been received but not yet consumed.
buffer: BytesMut,
poisoned: bool,
}

impl StreamTrait for Stream {
async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
if self.poisoned {
return Err(Error::Closed);
}

let os_recv = {
let mut channel = self.channel.lock();

Expand All @@ -141,6 +157,7 @@ impl StreamTrait for Stream {

// If the sink is dead, we cannot receive any more messages.
if !channel.sink_alive {
self.poisoned = true;
return Err(Error::Closed);
}

Expand All @@ -153,7 +170,10 @@ impl StreamTrait for Stream {
};

// Wait for the waiter to be resolved.
let data = os_recv.await.map_err(|_| Error::Closed)?;
let data = os_recv.await.map_err(|_| {
self.poisoned = true;
Error::Closed
})?;
self.buffer.extend_from_slice(&data);

assert!(self.buffer.len() >= len);
Expand Down Expand Up @@ -238,6 +258,8 @@ mod tests {
async {
let result = stream.recv(5).await;
assert!(matches!(result, Err(Error::Closed)));
let result = stream.recv(5).await;
assert!(matches!(result, Err(Error::Closed)));
},
async {
// Wait for the stream to start waiting
Expand All @@ -257,6 +279,8 @@ mod tests {
executor.start(|_| async move {
let result = stream.recv(5).await;
assert!(matches!(result, Err(Error::Closed)));
let result = stream.recv(5).await;
assert!(matches!(result, Err(Error::Closed)));
});
}

Expand Down Expand Up @@ -285,6 +309,8 @@ mod tests {
// Try to send a message. The stream is dropped, so this should fail.
let result = sink.send(b"hello world".as_slice()).await;
assert!(matches!(result, Err(Error::Closed)));
let result = sink.send(b"hello world".as_slice()).await;
assert!(matches!(result, Err(Error::Closed)));
});
}

Expand All @@ -297,6 +323,32 @@ mod tests {
executor.start(|_| async move {
let result = sink.send(b"hello world".as_slice()).await;
assert!(matches!(result, Err(Error::Closed)));
let result = sink.send(b"hello world".as_slice()).await;
assert!(matches!(result, Err(Error::Closed)));
});
}

#[test]
fn test_send_error_canceled_recv_poisoned() {
let (mut sink, mut stream) = Channel::init();

let executor = deterministic::Runner::default();
executor.start(|context| async move {
// Cancel a pending recv without dropping the stream.
select! {
v = stream.recv(5) => {
panic!("unexpected value: {v:?}");
},
_ = context.sleep(Duration::from_millis(50)) => {},
};

// The first send hits the canceled waiter and fails.
let result = sink.send(b"hello".as_slice()).await;
assert!(matches!(result, Err(Error::SendFailed)));

// After any send error, the mock sink must remain closed.
let result = sink.send(b"world".as_slice()).await;
assert!(matches!(result, Err(Error::Closed)));
});
}

Expand Down
42 changes: 39 additions & 3 deletions runtime/src/network/deterministic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,46 @@ const EPHEMERAL_PORT_RANGE: Range<u16> = 32768..61000;
/// Implementation of [crate::Sink] for a deterministic [Network].
pub struct Sink {
sender: mocks::Sink,
poisoned: bool,
}

impl crate::Sink for Sink {
async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
self.sender.send(bufs).await.map_err(|_| Error::SendFailed)
if self.poisoned {
return Err(Error::Closed);
}

let result = self.sender.send(bufs).await.map_err(|_| Error::SendFailed);

// A failed send leaves the write half unusable.
if result.is_err() {
self.poisoned = true;
}

result
}
}

/// Implementation of [crate::Stream] for a deterministic [Network].
pub struct Stream {
receiver: mocks::Stream,
poisoned: bool,
}

impl crate::Stream for Stream {
async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
self.receiver.recv(len).await.map_err(|_| Error::RecvFailed)
if self.poisoned {
return Err(Error::Closed);
}

let result = self.receiver.recv(len).await.map_err(|_| Error::RecvFailed);

// A failed recv leaves the read half unusable.
if result.is_err() {
self.poisoned = true;
}

result
}

fn peek(&self, max_len: usize) -> &[u8] {
Expand All @@ -48,7 +72,17 @@ impl crate::Listener for Listener {

async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (socket, sender, receiver) = self.listener.recv().await.ok_or(Error::ReadFailed)?;
Ok((socket, Sink { sender }, Stream { receiver }))
Ok((
socket,
Sink {
sender,
poisoned: false,
},
Stream {
receiver,
poisoned: false,
},
))
}

fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
Expand Down Expand Up @@ -137,9 +171,11 @@ impl crate::Network for Network {
Ok((
Sink {
sender: listener_sender,
poisoned: false,
},
Stream {
receiver: dialer_receiver,
poisoned: false,
},
))
}
Expand Down
Loading
Loading