From 0898a501668a74a40e602e3b6557709c57051bbd Mon Sep 17 00:00:00 2001
From: Martin Berg Alstad
Date: Sun, 30 Jun 2024 20:16:17 +0200
Subject: [PATCH] Added MultipartFile extractors. Moved cfg macro to lib where
 possible. Changed some features, and made some deps optional

---
 Cargo.lock             |  62 +++++++++++++-
 Cargo.toml             |  17 ++--
 src/axum/app.rs        |  12 ++-
 src/axum/extractor.rs  | 187 +++++++++++++++++++++++++++++++++++++++++
 src/axum/load.rs       |   9 +-
 src/axum/mod.rs        |   2 +
 src/axum/response.rs   |   4 +-
 src/axum/router.rs     |   7 +-
 src/io/file.rs         |   4 +-
 src/lib.rs             |   9 +-
 src/nom/combinators.rs |   7 +-
 src/nom/util.rs        |   4 +-
 src/serde/response.rs  |   5 +-
 src/vector/distinct.rs |   4 +-
 src/vector/map.rs      |   3 +-
 src/vector/matrix.rs   |   3 +-
 src/vector/set.rs      |   3 +-
 17 files changed, 287 insertions(+), 55 deletions(-)
 create mode 100644 src/axum/extractor.rs

diff --git a/Cargo.lock b/Cargo.lock
index 43f98df..78f9692 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -47,6 +47,7 @@ dependencies = [
  "matchit",
  "memchr",
  "mime",
+ "multer",
  "percent-encoding",
  "pin-project-lite",
  "rustversion",
@@ -130,6 +131,15 @@ dependencies = [
  "syn",
 ]

+[[package]]
+name = "encoding_rs"
+version = "0.8.34"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -284,12 +294,13 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"

 [[package]]
 name = "lib"
-version = "1.1.1"
+version = "1.3.0"
 dependencies = [
  "axum",
  "derive",
  "nom",
  "serde",
+ "thiserror",
  "tokio",
  "tokio-util",
  "tower",
@@ -354,6 +365,23 @@ dependencies = [
  "windows-sys 0.48.0",
 ]

+[[package]]
+name = "multer"
+version = "3.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b"
+dependencies = [
+ "bytes",
+ "encoding_rs",
+ "futures-util",
+ "http",
+ "httparse",
+ "memchr",
+ "mime",
+ "spin",
+ "version_check",
+]
+
 [[package]]
 name = "nom"
 version = "7.1.3"
@@ -547,6 +575,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

+[[package]]
+name = "spin"
+version = "0.9.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
+
 [[package]]
 name = "syn"
 version = "2.0.67"
@@ -570,6 +604,26 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"

+[[package]]
+name = "thiserror"
+version = "1.0.61"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.61"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "thread_local"
 version = "1.1.8"
@@ -734,6 +788,12 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"

+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
 [[package]]
 name = "wasi"
 version = "0.11.0+wasi-snapshot-preview1"
"wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 8a2cdd7..47be2b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lib" -version = "1.1.1" +version = "1.3.0" edition = "2021" authors = ["Martin Berg Alstad"] @@ -8,15 +8,17 @@ authors = ["Martin Berg Alstad"] [dependencies] # Api -axum = { version = "0.7.5", optional = true } +axum = { version = "0.7.5", optional = true, features = ["multipart"] } tower = { version = "0.4.13", optional = true } tower-http = { version = "0.5.2", optional = true, features = ["trace", "cors", "normalize-path"] } # Async tokio = { version = "1.38.0", optional = true, features = ["fs"] } tokio-util = { version = "0.7.11", optional = true, features = ["io"] } +# Error handling +thiserror = { version = "1.0.61", optional = true } # Logging -tracing = "0.1.40" -tracing-subscriber = "0.3.18" +tracing = { version = "0.1.40", optional = true } +tracing-subscriber = { version = "0.3.18", optional = true } # Parsing nom = { version = "7.1.3", optional = true } # Serialization / Deserialization @@ -25,9 +27,10 @@ serde = { version = "1.0.203", optional = true, features = ["derive"] } derive = { path = "derive", optional = true } [features] -axum = ["dep:axum", "dep:tower", "dep:tower-http"] -tokio = ["dep:tokio", "dep:tokio-util"] -vec = [] +axum = ["dep:axum", "dep:tower", "dep:tower-http", "dep:thiserror", "dep:tracing", "dep:tracing-subscriber"] +tokio = ["dep:tokio"] +io = ["dep:tokio", "dep:tokio-util"] +iter = [] nom = ["dep:nom"] serde = ["dep:serde"] derive = ["dep:derive", "axum", "serde"] diff --git a/src/axum/app.rs b/src/axum/app.rs index b90df91..806d96e 100644 --- a/src/axum/app.rs +++ b/src/axum/app.rs @@ -1,4 +1,3 @@ -#[cfg(feature = "axum")] use { axum::{extract::Request, handler::Handler, Router, ServiceExt}, std::net::Ipv4Addr, @@ -11,12 +10,11 @@ use { }, tracing::{info, Level}, }; -#[cfg(all(feature = "axum", feature = "tokio"))] +#[cfg(feature = "tokio")] use {std::io, std::net::SocketAddr, tokio::net::TcpListener}; // TODO trim trailing slash into macro > let _app = NormalizePathLayer::trim_trailing_slash().layer(create_app!(routes)); #[macro_export] -#[cfg(feature = "axum")] macro_rules! create_app { ($router:expr) => { $router @@ -27,7 +25,6 @@ macro_rules! 
diff --git a/src/axum/app.rs b/src/axum/app.rs
index b90df91..806d96e 100644
--- a/src/axum/app.rs
+++ b/src/axum/app.rs
@@ -1,4 +1,3 @@
-#[cfg(feature = "axum")]
 use {
     axum::{extract::Request, handler::Handler, Router, ServiceExt},
     std::net::Ipv4Addr,
@@ -11,12 +10,11 @@ use {
     },
     tracing::{info, Level},
 };
-#[cfg(all(feature = "axum", feature = "tokio"))]
+#[cfg(feature = "tokio")]
 use {std::io, std::net::SocketAddr, tokio::net::TcpListener};

 // TODO trim trailing slash into macro > let _app = NormalizePathLayer::trim_trailing_slash().layer(create_app!(routes));
 #[macro_export]
-#[cfg(feature = "axum")]
 macro_rules! create_app {
     ($router:expr) => {
         $router
@@ -27,7 +25,6 @@ macro_rules! create_app {
 }

 #[derive(Default)]
-#[cfg(feature = "axum")]
 pub struct AppBuilder {
     router: Router,
     socket: Option<(Ipv4Addr, u16)>,
@@ -36,7 +33,6 @@ pub struct AppBuilder {
     tracing: Option<Level>,
 }

-#[cfg(all(feature = "axum", feature = "tokio"))]
 impl AppBuilder {
     pub fn new() -> Self {
         Self::default()
     }
@@ -81,8 +77,9 @@ impl AppBuilder {
         self
     }

+    #[cfg(feature = "tokio")]
     pub async fn serve(self) -> io::Result<()> {
-        let _ = fmt_trace();
+        let _ = fmt_trace(); // Allowed to fail
         let listener = self.listener().await?;

         if self.normalize_path.unwrap_or(true) {
@@ -95,6 +92,7 @@ impl AppBuilder {
         Ok(())
     }

+    #[cfg(feature = "tokio")]
     async fn listener(&self) -> io::Result<TcpListener> {
         let addr = SocketAddr::from(self.socket.unwrap_or((Ipv4Addr::UNSPECIFIED, 8000)));
         info!("Initializing server on: {addr}");
@@ -124,7 +122,7 @@ fn fmt_trace() -> Result<(), String> {
         .map_err(|error| error.to_string())
 }

-#[cfg(all(test, feature = "axum"))]
+#[cfg(test)]
 mod tests {
     use axum::Router;
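A note on the `serve` hunk above: the server entry points are now gated on the `tokio` feature rather than on `axum` alone. A minimal, hypothetical downstream sketch (it assumes both the `axum` and `tokio` features and uses only the `new` and `serve` methods visible in this diff, so no routes are registered):

    use lib::axum::app::AppBuilder;

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        // Binds 0.0.0.0:8000 by default and serves the builder's router.
        AppBuilder::new().serve().await
    }
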
diff --git a/src/axum/extractor.rs b/src/axum/extractor.rs
new file mode 100644
index 0000000..1828c82
--- /dev/null
+++ b/src/axum/extractor.rs
@@ -0,0 +1,187 @@
+use axum::{
+    async_trait,
+    extract::{
+        multipart::{Field, MultipartError, MultipartRejection},
+        FromRequest, Multipart, Request,
+    },
+    response::IntoResponse,
+};
+use thiserror::Error;
+
+#[derive(PartialEq, Eq, Ord, PartialOrd, Hash, Debug, Clone, Copy)]
+pub enum ContentType {
+    Json,
+    Form,
+    Multipart,
+    Pdf,
+    Html,
+    Unknown,
+}
+
+impl From<&str> for ContentType {
+    fn from(content_type: &str) -> Self {
+        match content_type {
+            "application/json" => ContentType::Json,
+            "application/x-www-form-urlencoded" => ContentType::Form,
+            "multipart/form-data" => ContentType::Multipart,
+            "application/pdf" => ContentType::Pdf,
+            "text/html" => ContentType::Html,
+            _ => ContentType::Unknown,
+        }
+    }
+}
+
+impl From<String> for ContentType {
+    fn from(content_type: String) -> Self {
+        ContentType::from(content_type.as_str())
+    }
+}
+
+impl From<Option<&str>> for ContentType {
+    fn from(content_type: Option<&str>) -> Self {
+        content_type
+            .map(ContentType::from)
+            .unwrap_or(ContentType::Unknown)
+    }
+}
+
+pub struct File {
+    pub filename: String,
+    pub bytes: Vec<u8>,
+    pub content_type: ContentType,
+}
+
+impl File {
+    pub fn new(
+        filename: impl Into<String>,
+        bytes: impl Into<Vec<u8>>,
+        content_type: impl Into<ContentType>,
+    ) -> Self {
+        Self {
+            filename: filename.into(),
+            bytes: bytes.into(),
+            content_type: content_type.into(),
+        }
+    }
+
+    async fn from_field(field: Field<'_>) -> Result<Self, MultipartFileRejection> {
+        let filename = field
+            .file_name()
+            .ok_or(MultipartFileRejection::MissingFilename)?
+            .to_string();
+        let content_type: ContentType = field.content_type().into();
+        let bytes = field.bytes().await?;
+        Ok(File::new(filename, bytes, content_type))
+    }
+}
+
+/// Extractor for a single file from a multipart request.
+/// Expects exactly one file. A file must have a name, bytes and optionally a content type.
+/// This extractor consumes the request and must be placed last in the handler.
+pub struct MultipartFile(pub File);
+
+/// Extractor for multiple files from a multipart request.
+/// Expects at least one file. A file must have a name, bytes and optionally a content type.
+/// This extractor consumes the request and must be placed last in the handler.
+pub struct MultipartFiles(pub Vec<File>);
+
+#[derive(Debug, Error)]
+pub enum MultipartFileRejection {
+    #[error(transparent)]
+    MultipartRejection(#[from] MultipartRejection),
+    #[error("Field error: {0}")]
+    FieldError(String),
+    #[error("No files found")]
+    NoFiles,
+    #[error("Expected one file, got several")]
+    SeveralFiles,
+    #[error("Missing filename")]
+    MissingFilename,
+    #[error("Error in body of multipart: {0}")]
+    BodyError(String),
+}
+
+impl From<MultipartError> for MultipartFileRejection {
+    fn from(error: MultipartError) -> Self {
+        MultipartFileRejection::BodyError(error.body_text())
+    }
+}
+
+impl IntoResponse for MultipartFileRejection {
+    fn into_response(self) -> axum::response::Response {
+        match self {
+            MultipartFileRejection::MultipartRejection(rejection) => rejection.into_response(),
+            MultipartFileRejection::FieldError(error) => {
+                (axum::http::StatusCode::BAD_REQUEST, error).into_response()
+            }
+            MultipartFileRejection::NoFiles => {
+                (axum::http::StatusCode::BAD_REQUEST, "No files found").into_response()
+            }
+            MultipartFileRejection::SeveralFiles => (
+                axum::http::StatusCode::BAD_REQUEST,
+                "Expected one file, got several",
+            )
+                .into_response(),
+            MultipartFileRejection::MissingFilename => {
+                (axum::http::StatusCode::BAD_REQUEST, "Missing filename").into_response()
+            }
+            MultipartFileRejection::BodyError(error) => {
+                (axum::http::StatusCode::BAD_REQUEST, error).into_response()
+            }
+        }
+    }
+}
+
+#[async_trait]
+impl<S> FromRequest<S> for MultipartFile
+where
+    S: Send + Sync,
+{
+    type Rejection = MultipartFileRejection;
+
+    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
+        let mut multipart = Multipart::from_request(req, state).await?;
+        let fields = get_fields(&mut multipart).await?;
+        if fields.len() > 1 {
+            Err(MultipartFileRejection::SeveralFiles)
+        } else {
+            let field = fields
+                .into_iter()
+                .next()
+                .ok_or(MultipartFileRejection::NoFiles)?;
+            Ok(MultipartFile(File::from_field(field).await?))
+        }
+    }
+}
+
+#[async_trait]
+impl<S> FromRequest<S> for MultipartFiles
+where
+    S: Send + Sync,
+{
+    type Rejection = MultipartFileRejection;
+
+    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
+        let mut multipart = Multipart::from_request(req, state).await?;
+        let fields = get_fields(&mut multipart).await?;
+        if fields.is_empty() {
+            Err(MultipartFileRejection::NoFiles)
+        } else {
+            let mut files = vec![];
+            for field in fields.into_iter() {
+                files.push(File::from_field(field).await?);
+            }
+            Ok(MultipartFiles(files))
+        }
+    }
+}
+
+async fn get_fields<'a>(
+    multipart: &'a mut Multipart,
+) -> Result<Vec<Field<'a>>, MultipartFileRejection> {
+    let fields: Vec<Field> = multipart.next_field().await?.into_iter().collect();
+    if fields.is_empty() {
+        Err(MultipartFileRejection::NoFiles)
+    } else {
+        Ok(fields)
+    }
+}
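For orientation, the new extractors drop into ordinary axum handlers. A sketch, assuming the `axum` feature is enabled; the handler and route names (`upload_one`, `upload_many`, `/upload`, `/uploads`) are illustrative only:

    use axum::{routing::post, Router};
    use lib::axum::extractor::{MultipartFile, MultipartFiles};

    // The extractor has already buffered the field's bytes when the handler runs.
    async fn upload_one(MultipartFile(file): MultipartFile) -> String {
        format!("{} ({} bytes)", file.filename, file.bytes.len())
    }

    async fn upload_many(MultipartFiles(files): MultipartFiles) -> String {
        format!("received {} file(s)", files.len())
    }

    fn app() -> Router {
        Router::new()
            .route("/upload", post(upload_one))
            .route("/uploads", post(upload_many))
    }

Since both extractors consume the request body, they have to stay the last extractor in a handler's argument list, as the doc comments above note.
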
diff --git a/src/axum/load.rs b/src/axum/load.rs
index e44543d..e4cd8d9 100644
--- a/src/axum/load.rs
+++ b/src/axum/load.rs
@@ -1,4 +1,4 @@
-#[cfg(all(feature = "tokio", feature = "axum"))]
+#[cfg(feature = "tokio")]
 use {crate::io::file, axum::body::Body, axum::response::Html, std::io};

 /// Load an HTML file from the given file path, relative to the current directory.
@@ -10,7 +10,7 @@ use {crate::io::file, axum::body::Body, axum::response::Html, std::io};
 /// ```
 /// let html = async { lib::axum::load::load_html("openapi.html").await.unwrap() };
 /// ```
-#[cfg(all(feature = "tokio", feature = "axum"))]
+#[cfg(feature = "tokio")]
 pub async fn load_html<Path>(file_path: Path) -> Result<Html<Body>, io::Error>
 where
     Path: AsRef<std::path::Path>,
@@ -18,7 +18,7 @@ where
     load_file(file_path).await.map(Html)
 }

-#[cfg(all(feature = "tokio", feature = "axum"))]
+#[cfg(feature = "tokio")]
 pub async fn load_file<Path>(file_path: Path) -> Result<Body, io::Error>
 where
     Path: AsRef<std::path::Path>,
@@ -38,7 +38,6 @@ where
 /// ```
 // TODO check platform and use correct path separator
 #[macro_export]
-#[cfg(feature = "axum")]
 macro_rules! load_html {
     ($filepath:expr) => {
         axum::response::Html(
@@ -58,7 +57,7 @@ macro_rules! load_html {
     };
 }

-#[cfg(all(test, feature = "axum"))]
+#[cfg(test)]
 mod tests {
     #[test]
     fn test_load_html() {
diff --git a/src/axum/mod.rs b/src/axum/mod.rs
index ace9b3d..561774a 100644
--- a/src/axum/mod.rs
+++ b/src/axum/mod.rs
@@ -1,4 +1,6 @@
 pub mod app;
+pub mod extractor;
 pub mod load;
+#[cfg(feature = "serde")]
 pub mod response;
 pub mod router;
diff --git a/src/axum/response.rs b/src/axum/response.rs
index 4d1a4bc..aedd1c8 100644
--- a/src/axum/response.rs
+++ b/src/axum/response.rs
@@ -1,4 +1,3 @@
-#[cfg(all(feature = "axum", feature = "serde"))]
 use {
     crate::serde::response::BaseResponse,
     axum::{
@@ -8,14 +7,13 @@ use {
     },
     serde::Serialize,
 };

-#[cfg(all(feature = "axum", feature = "serde"))]
 impl<T: Serialize> IntoResponse for BaseResponse<T> {
     fn into_response(self) -> Response {
         Json(self).into_response()
     }
 }

-#[cfg(all(test, feature = "axum", feature = "serde"))]
+#[cfg(test)]
 mod tests {
     use axum::http::header::CONTENT_TYPE;
     use axum::http::{HeaderValue, StatusCode};
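With the `IntoResponse` impl above, a handler can return the wrapper directly once the `axum` and `serde` features are enabled. A small sketch; the `Health` payload type is illustrative:

    use lib::serde::response::BaseResponse;
    use serde::Serialize;

    #[derive(Serialize)]
    struct Health {
        status: String,
    }

    // Serializes to JSON as `version` plus the flattened body fields.
    async fn health() -> BaseResponse<Health> {
        BaseResponse::new("1.3.0", Health { status: "ok".into() })
    }
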
diff --git a/src/axum/router.rs b/src/axum/router.rs
index c719244..d400c3b 100644
--- a/src/axum/router.rs
+++ b/src/axum/router.rs
@@ -18,7 +18,6 @@
 /// ));
 /// ```
 #[macro_export]
-#[cfg(feature = "axum")]
 macro_rules! router {
     ($body:expr) => {
         pub(crate) fn router() -> axum::Router {
@@ -52,7 +51,6 @@ macro_rules! router {
 /// );
 /// ```
 #[macro_export]
-#[cfg(feature = "axum")]
 macro_rules! routes {
     ($($method:ident $route:expr => $func:expr),* $(,)?) => {
         axum::Router::new()
@@ -63,7 +61,6 @@ macro_rules! routes {
 }

 #[macro_export]
-#[cfg(feature = "axum")]
 macro_rules! join_routes {
     ($($route:expr),* $(,)?) => {
         axum::Router::new()$(
@@ -72,7 +69,7 @@ macro_rules! join_routes {
     };
 }

-#[cfg(all(test, feature = "axum"))]
+#[cfg(test)]
 mod tests {
     use axum::extract::State;
     use axum::Router;
@@ -117,7 +114,7 @@ mod tests {

     #[test]
     fn test_routes() {
-        let _router: Router<()> = routes!(
+        let _router: Router = routes!(
             get "/" => index,
             post "/" => || async {}
         );
diff --git a/src/io/file.rs b/src/io/file.rs
index d0cace8..9730e54 100644
--- a/src/io/file.rs
+++ b/src/io/file.rs
@@ -1,7 +1,5 @@
-#[cfg(feature = "tokio")]
 use {std::io::Error, tokio::fs::File, tokio_util::io::ReaderStream};

-#[cfg(feature = "tokio")]
 pub async fn load_file<Path>(file_path: Path) -> Result<ReaderStream<File>, Error>
 where
     Path: AsRef<std::path::Path>,
@@ -9,7 +7,7 @@ where
     File::open(file_path).await.map(ReaderStream::new)
 }

-#[cfg(all(test, feature = "tokio"))]
+#[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/lib.rs b/src/lib.rs
index 291eae0..7e22a9c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,11 +1,16 @@
 #![allow(dead_code)]

+#[cfg(feature = "axum")]
 pub mod axum;
+#[cfg(feature = "io")]
 pub mod io;
+#[cfg(feature = "nom")]
 pub mod nom;
+#[cfg(feature = "serde")]
 pub mod serde;
-mod traits;
+pub mod traits;
+#[cfg(feature = "iter")]
 pub mod vector;

-#[cfg(feature = "derive")]
+#[cfg(all(feature = "derive", feature = "serde"))]
 pub extern crate derive;
diff --git a/src/nom/combinators.rs b/src/nom/combinators.rs
index 225274a..3f61bfd 100644
--- a/src/nom/combinators.rs
+++ b/src/nom/combinators.rs
@@ -1,4 +1,3 @@
-#[cfg(feature = "nom")]
 use {
     nom::{
         bytes::complete::take_while_m_n,
@@ -16,7 +15,6 @@ use {
 /// - Parameters
 ///     - `inner`: The parser to trim
 /// - Returns: A parser that trims leading and trailing whitespace from the input and then returns the value from the inner parser
-#[cfg(feature = "nom")]
 pub fn trim<'a, Parser, R>(inner: Parser) -> impl FnMut(&'a str) -> IResult<&'a str, R>
 where
     Parser: Fn(&'a str) -> IResult<&'a str, R>,
@@ -29,7 +27,6 @@ where
 /// - Parameters
 ///     - `inner`: The parser to run inside the parentheses
 /// - Returns: A parser that parses a parenthesized expression
-#[cfg(feature = "nom")]
 pub fn parenthesized<'a, Parser, R>(inner: Parser) -> impl FnMut(&'a str) -> IResult<&'a str, R>
 where
     Parser: Fn(&'a str) -> IResult<&'a str, R>,
@@ -42,7 +39,6 @@ where
 ///     - `n`: The length of the string to take
 ///     - `predicate`: The predicate to call to validate the input
 /// - Returns: A parser that takes `n` characters from the input
-#[cfg(feature = "nom")]
 pub fn take_where<F, Input>(n: usize, predicate: F) -> impl Fn(Input) -> IResult<Input, Input>
 where
     Input: InputTake + InputIter + InputLength + Slice<RangeFrom<usize>>,
     F: Fn(<Input as InputIter>::Item) -> bool,
 {
     take_while_m_n(n, n, predicate)
 }

-#[cfg(feature = "nom")]
 pub fn exhausted<'a, Parser, R>(inner: Parser) -> impl FnMut(&'a str) -> IResult<&'a str, R>
 where
     Parser: Fn(&'a str) -> IResult<&'a str, R>,
@@ -59,7 +54,7 @@ where
     terminated(inner, eof)
 }

-#[cfg(all(test, feature = "nom"))]
+#[cfg(test)]
 mod tests {
     use nom::bytes::streaming::take_while;
diff --git a/src/nom/util.rs b/src/nom/util.rs
index d368d9d..c75164d 100644
--- a/src/nom/util.rs
+++ b/src/nom/util.rs
@@ -1,10 +1,8 @@
-#[cfg(feature = "nom")]
 use {
     crate::traits::IntoResult,
     nom::{error::Error, IResult},
 };

-#[cfg(feature = "nom")]
 impl<T, R> IntoResult<T> for IResult<R, T> {
     type Error = nom::Err<Error<R>>;
     fn into_result(self) -> Result<T, Self::Error> {
@@ -12,7 +10,7 @@ impl<T, R> IntoResult<T> for IResult<R, T> {
     }
 }

-#[cfg(all(test, feature = "nom"))]
+#[cfg(test)]
 mod tests {
     use nom::character::complete::char as c;
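The doc comments above describe `trim`, `parenthesized`, `take_where` and `exhausted`. A sketch of how they are meant to combine with plain nom parsers, assuming the documented behaviour and the `lib::nom::combinators` module path:

    use lib::nom::combinators::{exhausted, parenthesized, trim};
    use nom::{bytes::complete::tag, IResult};

    // Strips surrounding whitespace, then expects the literal "hello".
    fn spaced_hello(input: &str) -> IResult<&str, &str> {
        trim(tag("hello"))(input)
    }

    // Expects "hello" wrapped in parentheses; trailing input is left untouched.
    fn parens_hello(input: &str) -> IResult<&str, &str> {
        parenthesized(tag("hello"))(input)
    }

    // Succeeds only if the inner parser consumes the whole input.
    fn exact_hello(input: &str) -> IResult<&str, &str> {
        exhausted(tag("hello"))(input)
    }

    fn main() {
        assert_eq!(spaced_hello("  hello  "), Ok(("", "hello")));
        assert_eq!(parens_hello("(hello) rest"), Ok((" rest", "hello")));
        assert!(exact_hello("hello!").is_err());
    }
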
diff --git a/src/serde/response.rs b/src/serde/response.rs
index 8b1a4a9..8e928a8 100644
--- a/src/serde/response.rs
+++ b/src/serde/response.rs
@@ -1,15 +1,12 @@
-#[cfg(feature = "serde")]
 use serde::Serialize;

 #[derive(Serialize)]
-#[cfg(feature = "serde")]
 pub struct BaseResponse<T: Serialize> {
     pub version: String,
     #[serde(flatten)]
     pub body: T, // T must be a struct (or enum?)
 }

-#[cfg(feature = "serde")]
 impl<T: Serialize> BaseResponse<T> {
     pub fn new(version: impl Into<String>, body: T) -> Self {
         Self {
@@ -19,7 +16,7 @@ impl<T: Serialize> BaseResponse<T> {
     }
 }

-#[cfg(all(test, feature = "serde"))]
+#[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/vector/distinct.rs b/src/vector/distinct.rs
index 284f02e..700ffc1 100644
--- a/src/vector/distinct.rs
+++ b/src/vector/distinct.rs
@@ -1,9 +1,7 @@
-#[cfg(feature = "vec")]
 pub trait Distinct {
     fn distinct(&mut self);
 }

-#[cfg(feature = "vec")]
 impl<T: PartialEq + Clone> Distinct for Vec<T> {
     fn distinct(&mut self) {
         *self = self.iter().fold(vec![], |mut acc, x| {
@@ -15,7 +13,7 @@ impl<T: PartialEq + Clone> Distinct for Vec<T> {
     }
 }

-#[cfg(all(test, feature = "vec"))]
+#[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/vector/map.rs b/src/vector/map.rs
index e2d7c08..b0cd573 100644
--- a/src/vector/map.rs
+++ b/src/vector/map.rs
@@ -1,5 +1,4 @@
 #[macro_export]
-#[cfg(feature = "vec")]
 macro_rules! map {
     () => { std::collections::HashMap::new() };
     ($($k:expr => $v:expr),* $(,)?) => {
@@ -13,7 +12,7 @@ macro_rules! map {
     };
 }

-#[cfg(all(test, feature = "vec"))]
+#[cfg(test)]
 mod tests {
     use std::collections::HashMap;
diff --git a/src/vector/matrix.rs b/src/vector/matrix.rs
index 0d61551..83463f4 100644
--- a/src/vector/matrix.rs
+++ b/src/vector/matrix.rs
@@ -1,5 +1,4 @@
 #[macro_export]
-#[cfg(feature = "vec")]
 macro_rules! matrix {
     ($x:expr; $m:expr, $n:expr) => {
         vec![vec![$x; $n]; $m]
@@ -16,7 +15,7 @@ macro_rules! matrix {
     };
 }

-#[cfg(all(test, feature = "vec"))]
+#[cfg(test)]
 mod tests {

     #[test]
diff --git a/src/vector/set.rs b/src/vector/set.rs
index d585766..85b80e9 100644
--- a/src/vector/set.rs
+++ b/src/vector/set.rs
@@ -1,5 +1,4 @@
 #[macro_export]
-#[cfg(feature = "vec")]
 macro_rules! set {
     () => { std::collections::HashSet::new() };
     ($($x:expr),* $(,)?) => {
@@ -13,7 +12,7 @@ macro_rules! set {
     };
 }

-#[cfg(all(test, feature = "vec"))]
+#[cfg(test)]
 mod tests {
     use std::collections::HashSet;
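For completeness, the collection helpers gated by the renamed `iter` feature (previously `vec`) are used as below. A sketch, assuming the macros stay exported at the crate root, that their non-empty arms insert each listed element (those arm bodies fall outside the diff context), and that `Distinct` is reachable as `lib::vector::distinct::Distinct`:

    use lib::vector::distinct::Distinct;
    use lib::{map, matrix, set};

    fn main() {
        let scores = map!("a" => 1, "b" => 2); // HashMap<&str, i32>
        let tags = set!("x", "y", "y");        // HashSet<&str>; duplicates collapse
        let grid = matrix!(0; 2, 3);           // expands to vec![vec![0; 3]; 2]

        let mut values = vec![1, 1, 2, 3, 3];
        values.distinct();                     // keeps the first occurrence of each value

        assert_eq!(scores.len(), 2);
        assert_eq!(tags.len(), 2);
        assert_eq!(grid.len(), 2);
        assert_eq!(values, vec![1, 2, 3]);
    }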