Skip to content

Commit

Permalink
util: add a channel body (#140)
Browse files Browse the repository at this point in the history
* Add channel body

* review: use `sync::oneshot` for error channel

this applies a review suggestion here:
https://github.com/hyperium/http-body/pull/100/files#r1399781061

this commit refactors the channel-backed body in #100, changing the
`mpsc::Receiver<E>` used to transmit an error into a
`oneshot::Receiver<E>`.

this should improve memory usage, and make the channel a smaller
structure.

in order to achieve this, some minor adjustments are made:

* use pin projection, projecting pinnedness to the oneshot receiver,
  polling it via `core::future::Future::poll(..)` to yield a body frame.

* add `Debug` bounds were needed.

as an alternative, see tokio-rs/tokio#7059, which proposed a
`poll_recv(..)` inherent method for a oneshot channel receiver.

* review: use `&mut self` method receivers

this applies a review suggestion here:
https://github.com/hyperium/http-body/pull/100/files#r1399780355

this commit refactors the channel-backed body in #100, changing the
signature of `send_*` methods on the sender to require a mutable
reference.

* review: fix `<Channel<D, E> as Body>::poll_frame()`

see: #140 (comment)

this commit adds test coverage exposing the bug, and tightens the
pattern used to match frames yielded by the data channel.

now, when the channel is closed, a `None` will flow onwards and poll the
error channel. `None` will be returned when the error channel is closed,
which also indicates that the associated `Sender` has been dropped.

---------

Co-authored-by: David Pedersen <[email protected]>
  • Loading branch information
cratelyn and davidpdrsn authored Jan 17, 2025
1 parent 7339aec commit 86fdf00
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
8 changes: 8 additions & 0 deletions http-body-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ keywords = ["http"]
categories = ["web-programming"]
rust-version = "1.61"

[features]
default = []
channel = ["dep:tokio"]
full = ["channel"]

[dependencies]
bytes = "1"
futures-core = { version = "0.3", default-features = false }
http = "1"
http-body = { version = "1", path = "../http-body" }
pin-project-lite = "0.2"

# optional dependencies
tokio = { version = "1", features = ["sync"], optional = true }

[dev-dependencies]
futures-util = { version = "0.3", default-features = false }
tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] }
234 changes: 234 additions & 0 deletions http-body-util/src/channel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
//! A body backed by a channel.
use std::{
fmt::Display,
pin::Pin,
task::{Context, Poll},
};

use bytes::Buf;
use http::HeaderMap;
use http_body::{Body, Frame};
use pin_project_lite::pin_project;
use tokio::sync::{mpsc, oneshot};

pin_project! {
/// A body backed by a channel.
pub struct Channel<D, E = std::convert::Infallible> {
rx_frame: mpsc::Receiver<Frame<D>>,
#[pin]
rx_error: oneshot::Receiver<E>,
}
}

impl<D, E> Channel<D, E> {
/// Create a new channel body.
///
/// The channel will buffer up to the provided number of messages. Once the buffer is full,
/// attempts to send new messages will wait until a message is received from the channel. The
/// provided buffer capacity must be at least 1.
pub fn new(buffer: usize) -> (Sender<D, E>, Self) {
let (tx_frame, rx_frame) = mpsc::channel(buffer);
let (tx_error, rx_error) = oneshot::channel();
(Sender { tx_frame, tx_error }, Self { rx_frame, rx_error })
}
}

impl<D, E> Body for Channel<D, E>
where
D: Buf,
{
type Data = D;
type Error = E;

fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project();

match this.rx_frame.poll_recv(cx) {
Poll::Ready(frame @ Some(_)) => return Poll::Ready(frame.map(Ok)),
Poll::Ready(None) | Poll::Pending => {}
}

use core::future::Future;
match this.rx_error.poll(cx) {
Poll::Ready(Ok(error)) => return Poll::Ready(Some(Err(error))),
Poll::Ready(Err(_)) => return Poll::Ready(None),
Poll::Pending => {}
}

Poll::Pending
}
}

impl<D, E: std::fmt::Debug> std::fmt::Debug for Channel<D, E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Channel")
.field("rx_frame", &self.rx_frame)
.field("rx_error", &self.rx_error)
.finish()
}
}

/// A sender half created through [`Channel::new`].
pub struct Sender<D, E = std::convert::Infallible> {
tx_frame: mpsc::Sender<Frame<D>>,
tx_error: oneshot::Sender<E>,
}

impl<D, E> Sender<D, E> {
/// Send a frame on the channel.
pub async fn send(&mut self, frame: Frame<D>) -> Result<(), SendError> {
self.tx_frame.send(frame).await.map_err(|_| SendError)
}

/// Send data on data channel.
pub async fn send_data(&mut self, buf: D) -> Result<(), SendError> {
self.send(Frame::data(buf)).await
}

/// Send trailers on trailers channel.
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), SendError> {
self.send(Frame::trailers(trailers)).await
}

/// Aborts the body in an abnormal fashion.
pub fn abort(self, error: E) {
self.tx_error.send(error).ok();
}
}

impl<D, E: std::fmt::Debug> std::fmt::Debug for Sender<D, E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Sender")
.field("tx_frame", &self.tx_frame)
.field("tx_error", &self.tx_error)
.finish()
}
}

/// The error returned if [`Sender`] fails to send because the receiver is closed.
#[derive(Debug)]
#[non_exhaustive]
pub struct SendError;

impl Display for SendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to send frame")
}
}

impl std::error::Error for SendError {}

#[cfg(test)]
mod tests {
use bytes::Bytes;
use http::{HeaderName, HeaderValue};

use crate::BodyExt;

use super::*;

#[tokio::test]
async fn empty() {
let (tx, body) = Channel::<Bytes>::new(1024);
drop(tx);

let collected = body.collect().await.unwrap();
assert!(collected.trailers().is_none());
assert!(collected.to_bytes().is_empty());
}

#[tokio::test]
async fn can_send_data() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
});

let collected = body.collect().await.unwrap();
assert!(collected.trailers().is_none());
assert_eq!(collected.to_bytes(), "Hello!");
}

#[tokio::test]
async fn can_send_trailers() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);
tx.send_trailers(trailers).await.unwrap();
});

let collected = body.collect().await.unwrap();
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
assert!(collected.to_bytes().is_empty());
}

#[tokio::test]
async fn can_send_both_data_and_trailers() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);
tx.send_trailers(trailers).await.unwrap();
});

let collected = body.collect().await.unwrap();
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
assert_eq!(collected.to_bytes(), "Hello!");
}

/// A stand-in for an error type, for unit tests.
type Error = &'static str;
/// An example error message.
const MSG: Error = "oh no";

#[tokio::test]
async fn aborts_before_trailers() {
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
tx.abort(MSG);
});

let err = body.collect().await.unwrap_err();
assert_eq!(err, MSG);
}

#[tokio::test]
async fn aborts_after_trailers() {
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);
tx.send_trailers(trailers).await.unwrap();
tx.abort(MSG);
});

let err = body.collect().await.unwrap_err();
assert_eq!(err, MSG);
}
}
6 changes: 6 additions & 0 deletions http-body-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ mod full;
mod limited;
mod stream;

#[cfg(feature = "channel")]
pub mod channel;

mod util;

use self::combinators::{BoxBody, MapErr, MapFrame, UnsyncBoxBody};
Expand All @@ -26,6 +29,9 @@ pub use self::full::Full;
pub use self::limited::{LengthLimitError, Limited};
pub use self::stream::{BodyDataStream, BodyStream, StreamBody};

#[cfg(feature = "channel")]
pub use self::channel::Channel;

/// An extension trait for [`http_body::Body`] adding various combinators and adapters
pub trait BodyExt: http_body::Body {
/// Returns a future that resolves to the next [`Frame`], if any.
Expand Down

0 comments on commit 86fdf00

Please sign in to comment.