This commit is contained in:
Martin Berg Alstad 2024-08-27 00:21:25 +02:00
parent cd99466266
commit ce770e9c6f
6 changed files with 81 additions and 38 deletions

View File

@ -24,7 +24,7 @@ homepage = { workspace = true }
axum = { version = "0.7", optional = true, features = ["multipart"] } axum = { version = "0.7", optional = true, features = ["multipart"] }
tower = { version = "0.5", optional = true } tower = { version = "0.5", optional = true }
tower-http = { version = "0.5", optional = true, features = ["trace", "cors", "normalize-path"] } tower-http = { version = "0.5", optional = true, features = ["trace", "cors", "normalize-path"] }
mime = { version = "0.3.17", optional = true } mime = { version = "0.3", optional = true }
# Async # Async
tokio = { version = "1.39", optional = true, features = ["fs"] } tokio = { version = "1.39", optional = true, features = ["fs"] }
tokio-util = { version = "0.7", optional = true, features = ["io"] } tokio-util = { version = "0.7", optional = true, features = ["io"] }

View File

@ -40,10 +40,13 @@ pub struct AppBuilder {
} }
impl AppBuilder { impl AppBuilder {
/// Creates a new app builder with default options.
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
/// Creates the builder from the given router.
/// Only the routes and layers will be used.
pub fn from_router(router: Router) -> Self { pub fn from_router(router: Router) -> Self {
Self { Self {
router, router,
@ -51,11 +54,13 @@ impl AppBuilder {
} }
} }
/// Adds a route to the previously added routes
pub fn route(mut self, route: Router) -> Self { pub fn route(mut self, route: Router) -> Self {
self.router = self.router.merge(route); self.router = self.router.merge(route);
self self
} }
/// Adds multiple routes to the previously added routes
pub fn routes(mut self, routes: impl IntoIterator<Item = Router>) -> Self { pub fn routes(mut self, routes: impl IntoIterator<Item = Router>) -> Self {
self.router = routes.into_iter().fold(self.router, Router::merge); self.router = routes.into_iter().fold(self.router, Router::merge);
self self
@ -74,12 +79,14 @@ impl AppBuilder {
self self
} }
/// Sets the socket for the server.
pub fn socket<IP: Into<IpAddr>>(mut self, socket: impl Into<(IP, u16)>) -> Self { pub fn socket<IP: Into<IpAddr>>(mut self, socket: impl Into<(IP, u16)>) -> Self {
let (ip, port) = socket.into(); let (ip, port) = socket.into();
self.socket = Some((ip.into(), port)); self.socket = Some((ip.into(), port));
self self
} }
/// Sets the port for the server.
pub fn port(mut self, port: u16) -> Self { pub fn port(mut self, port: u16) -> Self {
self.socket = if let Some((ip, _)) = self.socket { self.socket = if let Some((ip, _)) = self.socket {
Some((ip, port)) Some((ip, port))
@ -89,6 +96,7 @@ impl AppBuilder {
self self
} }
/// Sets the fallback handler.
pub fn fallback<H, T>(mut self, fallback: H) -> Self pub fn fallback<H, T>(mut self, fallback: H) -> Self
where where
H: Handler<T, ()>, H: Handler<T, ()>,
@ -98,16 +106,19 @@ impl AppBuilder {
self self
} }
/// Sets the cors layer.
pub fn cors(mut self, cors: CorsLayer) -> Self { pub fn cors(mut self, cors: CorsLayer) -> Self {
self.cors = Some(cors); self.cors = Some(cors);
self self
} }
/// Sets the normalize path option. Default is true.
pub fn normalize_path(mut self, normalize_path: bool) -> Self { pub fn normalize_path(mut self, normalize_path: bool) -> Self {
self.normalize_path = Some(normalize_path); self.normalize_path = Some(normalize_path);
self self
} }
/// Sets the trace layer.
pub fn tracing(mut self, tracing: TraceLayer<HttpMakeClassifier>) -> Self { pub fn tracing(mut self, tracing: TraceLayer<HttpMakeClassifier>) -> Self {
self.tracing = Some(tracing); self.tracing = Some(tracing);
self self
@ -168,44 +179,37 @@ fn fmt_trace() -> Result<(), String> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use axum::Router;
use super::*; use super::*;
use axum::Router;
use std::time::Duration;
use tokio::time::sleep;
mod tokio_tests { #[tokio::test]
use std::time::Duration; async fn test_app_builder_serve() {
let handler = tokio::spawn(async {
AppBuilder::new().serve().await.unwrap();
});
sleep(Duration::from_millis(250)).await;
handler.abort();
}
use tokio::time::sleep; #[tokio::test]
async fn test_app_builder_all() {
use super::*; let handler = tokio::spawn(async {
AppBuilder::new()
#[tokio::test] .socket((Ipv4Addr::LOCALHOST, 8080))
async fn test_app_builder_serve() { .routes([Router::new()])
let handler = tokio::spawn(async { .fallback(|| async { "Fallback" })
AppBuilder::new().serve().await.unwrap(); .cors(CorsLayer::new())
}); .normalize_path(true)
sleep(Duration::from_millis(250)).await; .tracing(TraceLayer::new_for_http())
handler.abort(); .layer(TraceLayer::new_for_http())
} .serve()
.await
#[tokio::test] .unwrap();
async fn test_app_builder_all() { });
let handler = tokio::spawn(async { sleep(Duration::from_millis(250)).await;
AppBuilder::new() handler.abort();
.socket((Ipv4Addr::LOCALHOST, 8080))
.routes([Router::new()])
.fallback(|| async { "Fallback" })
.cors(CorsLayer::new())
.normalize_path(true)
.tracing(TraceLayer::new_for_http())
.layer(TraceLayer::new_for_http())
.serve()
.await
.unwrap();
});
sleep(Duration::from_millis(250)).await;
handler.abort();
}
} }
#[test] #[test]

View File

@ -10,6 +10,7 @@ use mime::Mime;
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
/// A file extracted from a multipart request.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct File { pub struct File {
pub filename: String, pub filename: String,
@ -18,6 +19,7 @@ pub struct File {
} }
impl File { impl File {
/// Creates a new file with the given filename, bytes and content type.
pub fn new( pub fn new(
filename: impl Into<String>, filename: impl Into<String>,
bytes: impl Into<Vec<u8>>, bytes: impl Into<Vec<u8>>,
@ -30,7 +32,8 @@ impl File {
} }
} }
async fn from_field(field: Field<'_>) -> Result<Self, MultipartFileRejection> { /// Creates a new file from a field in a multipart request.
pub async fn from_field(field: Field<'_>) -> Result<Self, MultipartFileRejection> {
let filename = field let filename = field
.file_name() .file_name()
.ok_or(MultipartFileRejection::MissingFilename)? .ok_or(MultipartFileRejection::MissingFilename)?
@ -54,6 +57,7 @@ pub struct MultipartFile(pub File);
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct MultipartFiles(pub Vec<File>); pub struct MultipartFiles(pub Vec<File>);
/// Rejection type for multipart file extractors.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum MultipartFileRejection { pub enum MultipartFileRejection {
#[error(transparent)] #[error(transparent)]
@ -113,6 +117,19 @@ where
{ {
type Rejection = MultipartFileRejection; type Rejection = MultipartFileRejection;
/// Extracts a single file from a multipart request.
/// Expects exactly one file. A file must have a name, bytes and optionally a content type.
/// This extractor consumes the request and must ble placed last in the handler.
/// # Example
/// ```
/// use std::str::from_utf8;
/// use axum::response::Html;
/// use lib::axum::extractor::MultipartFile;
///
/// async fn upload_file(MultipartFile(file): MultipartFile) -> Html<String> {
/// Html(String::from_utf8(file.bytes).unwrap())
/// }
/// ```
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let multipart = Multipart::from_request(req, state).await?; let multipart = Multipart::from_request(req, state).await?;
let files = get_files(multipart).await?; let files = get_files(multipart).await?;
@ -132,6 +149,24 @@ where
{ {
type Rejection = MultipartFileRejection; type Rejection = MultipartFileRejection;
/// Extracts multiple files from a multipart request.
/// Expects at least one file. A file must have a name, bytes and optionally a content type.
/// This extractor consumes the request and must ble placed last in the handler.
/// # Example
/// ```
/// use axum::response::Html;
/// use lib::axum::extractor::MultipartFiles;
/// use std::str::from_utf8;
///
/// async fn upload_files(MultipartFiles(files): MultipartFiles) -> Html<String> {
/// let content = files
/// .iter()
/// .map(|file| String::from_utf8(file.bytes.clone()).unwrap())
/// .collect::<Vec<String>>()
/// .join("<br>");
/// Html(content)
/// }
/// ```
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let multipart = Multipart::from_request(req, state).await?; let multipart = Multipart::from_request(req, state).await?;
let files = get_files(multipart).await?; let files = get_files(multipart).await?;

View File

@ -18,6 +18,7 @@ mod tests {
use axum::http::header::CONTENT_TYPE; use axum::http::header::CONTENT_TYPE;
use axum::http::{HeaderValue, StatusCode}; use axum::http::{HeaderValue, StatusCode};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use mime::APPLICATION_JSON;
use serde::Serialize; use serde::Serialize;
use crate::serde::response::BaseResponse; use crate::serde::response::BaseResponse;
@ -39,7 +40,7 @@ mod tests {
assert_eq!(json_response.status(), StatusCode::OK); assert_eq!(json_response.status(), StatusCode::OK);
assert_eq!( assert_eq!(
json_response.headers().get(CONTENT_TYPE), json_response.headers().get(CONTENT_TYPE),
Some(&HeaderValue::from_static("application/json")) Some(&HeaderValue::from_static(APPLICATION_JSON.as_ref()))
); );
} }

View File

@ -76,6 +76,7 @@ macro_rules! routes {
}; };
} }
/// Merges the given routers into a single router.
#[macro_export] #[macro_export]
macro_rules! join_routes { macro_rules! join_routes {
($($route:expr),* $(,)?) => { ($($route:expr),* $(,)?) => {

View File

@ -3,11 +3,13 @@ use derive_more::{Constructor, From};
use into_response_derive::IntoResponse; use into_response_derive::IntoResponse;
use serde::Serialize; use serde::Serialize;
/// Wrapper for a vector of items.
#[derive(Debug, Clone, PartialEq, Default, Serialize, From, Constructor)] #[derive(Debug, Clone, PartialEq, Default, Serialize, From, Constructor)]
pub struct Array<T: Serialize> { pub struct Array<T: Serialize> {
pub data: Vec<T>, pub data: Vec<T>,
} }
/// Wrapper for a count.
#[derive( #[derive(
Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, IntoResponse, From, Constructor, Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, IntoResponse, From, Constructor,
)] )]