Skip to content

Commit

Permalink
Implement OptionalFromRequest for Multipart
Browse files Browse the repository at this point in the history
  • Loading branch information
mcginty committed Feb 12, 2025
1 parent a192480 commit f6e4fef
Showing 1 changed file with 55 additions and 4 deletions.
59 changes: 55 additions & 4 deletions axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{FromRequest, Request};
use crate::body::Bytes;
use axum_core::{
__composite_rejection as composite_rejection, __define_rejection as define_rejection,
extract::OptionalFromRequest,
response::{IntoResponse, Response},
RequestExt,
};
Expand Down Expand Up @@ -71,13 +72,37 @@ where
type Rejection = MultipartRejection;

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let boundary = content_type_str(req.headers())
.and_then(|content_type| multer::parse_boundary(content_type).ok())
.ok_or(InvalidBoundary)?;
let stream = req.with_limited_body().into_body();
let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
Ok(Self { inner: multipart })
}
}

impl<S> OptionalFromRequest<S> for Multipart
where
S: Send + Sync,
{
type Rejection = MultipartRejection;

async fn from_request(req: Request, _state: &S) -> Result<Option<Self>, Self::Rejection> {
let Some(content_type) = content_type_str(req.headers()) else {
return Ok(None);
};
match multer::parse_boundary(content_type) {
Ok(boundary) => {
let stream = req.with_limited_body().into_body();
let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
Ok(Some(Self { inner: multipart }))
}
Err(multer::Error::NoMultipart) => Ok(None),
Err(_) => Err(MultipartRejection::InvalidBoundary(InvalidBoundary)),
}
}
}

impl Multipart {
/// Yields the next [`Field`] if available.
pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
Expand Down Expand Up @@ -282,9 +307,8 @@ impl IntoResponse for MultipartError {
}
}

fn parse_boundary(headers: &HeaderMap) -> Option<String> {
let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
multer::parse_boundary(content_type).ok()
fn content_type_str(headers: &HeaderMap) -> Option<&str> {
headers.get(CONTENT_TYPE)?.to_str().ok()
}

composite_rejection! {
Expand Down Expand Up @@ -378,4 +402,31 @@ mod tests {
let res = client.post("/").multipart(form).await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[crate::test]
async fn optional_multipart() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

async fn handle(multipart: Option<Multipart>) -> Result<StatusCode, MultipartError> {
if let Some(mut multipart) = multipart {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(StatusCode::OK)
} else {
Ok(StatusCode::NO_CONTENT)
}
}

let app = Router::new().route("/", post(handle));
let client = TestClient::new(app);
let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

let res = client.post("/").multipart(form).await;
assert_eq!(res.status(), StatusCode::OK);

let res = client.post("/").await;
assert_eq!(res.status(), StatusCode::NO_CONTENT);
}
}

0 comments on commit f6e4fef

Please sign in to comment.