use { axum::{ extract::Request, handler::Handler, response::IntoResponse, routing::Route, Router, ServiceExt, }, std::{ convert::Infallible, io, net::{IpAddr, Ipv4Addr, SocketAddr}, }, tokio::net::TcpListener, tower::{layer::Layer, Service}, tower_http::{ cors::CorsLayer, normalize_path::NormalizePathLayer, trace, trace::{HttpMakeClassifier, TraceLayer}, }, tracing::{info, Level}, }; // TODO trim trailing slash into macro > let _app = NormalizePathLayer::trim_trailing_slash().layer(create_app!(routes)); #[macro_export] macro_rules! create_app { ($router:expr) => { $router }; ($router:expr, $($layer:expr),* $(,)?) => { $router$(.layer($layer))* }; } #[derive(Default)] pub struct AppBuilder { router: Router, socket: Option<(IpAddr, u16)>, cors: Option, normalize_path: Option, tracing: Option>, } impl AppBuilder { /// Creates a new app builder with default options. pub fn new() -> Self { Self::default() } /// Creates the builder from the given router. /// Only the routes and layers will be used. pub fn from_router(router: Router) -> Self { Self { router, ..Self::default() } } /// Adds a route to the previously added routes pub fn route(mut self, route: Router) -> Self { self.router = self.router.merge(route); self } /// Adds multiple routes to the previously added routes pub fn routes(mut self, routes: impl IntoIterator) -> Self { self.router = routes.into_iter().fold(self.router, Router::merge); self } /// Adds a layer to the previously added routes pub fn layer(mut self, layer: L) -> Self where L: Layer + Clone + Send + 'static, L::Service: Service + Clone + Send + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { self.router = self.router.layer(layer); self } /// Sets the socket for the server. pub fn socket>(mut self, socket: impl Into<(IP, u16)>) -> Self { let (ip, port) = socket.into(); self.socket = Some((ip.into(), port)); self } /// Sets the port for the server. pub fn port(mut self, port: u16) -> Self { self.socket = if let Some((ip, _)) = self.socket { Some((ip, port)) } else { Some((Ipv4Addr::UNSPECIFIED.into(), port)) }; self } /// Sets the fallback handler. pub fn fallback(mut self, fallback: H) -> Self where H: Handler, T: 'static, { self.router = self.router.fallback(fallback); self } /// Sets the cors layer. pub fn cors(mut self, cors: CorsLayer) -> Self { self.cors = Some(cors); self } /// Sets the normalize path option. Default is true. pub fn normalize_path(mut self, normalize_path: bool) -> Self { self.normalize_path = Some(normalize_path); self } /// Sets the trace layer. pub fn tracing(mut self, tracing: TraceLayer) -> Self { self.tracing = Some(tracing); self } /// Creates the app with the given options. /// This method is useful for testing purposes. /// Options used for configuring the listener will be lost. pub fn build(self) -> Router { let mut app = self.router; if let Some(cors) = self.cors { app = app.layer(cors); } app.layer( self.tracing.unwrap_or( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) .on_response(trace::DefaultOnResponse::new().level(Level::INFO)), ), ) } /// Build the app and start the server /// # Default Options /// - IP == 0.0.0.0 /// - Port == 8000 /// - Cors == None /// - Normalize Path == true /// - Tracing == Default compact pub async fn serve(self) -> io::Result<()> { let _ = fmt_trace(); // Allowed to fail let listener = self.listener().await?; if self.normalize_path.unwrap_or(true) { let app = NormalizePathLayer::trim_trailing_slash().layer(self.build()); axum::serve(listener, ServiceExt::::into_make_service(app)).await?; } else { let app = self.build(); axum::serve(listener, app.into_make_service()).await?; }; Ok(()) } async fn listener(&self) -> io::Result { let addr = SocketAddr::from(self.socket.unwrap_or((Ipv4Addr::UNSPECIFIED.into(), 8000))); info!("Initializing server on: {addr}"); TcpListener::bind(&addr).await } } fn fmt_trace() -> Result<(), String> { tracing_subscriber::fmt() .with_target(false) .compact() .try_init() .map_err(|error| error.to_string()) } #[cfg(test)] mod tests { use super::*; use axum::Router; use std::time::Duration; use tokio::time::sleep; #[tokio::test] async fn test_app_builder_serve() { let handler = tokio::spawn(async { AppBuilder::new().serve().await.unwrap(); }); sleep(Duration::from_millis(250)).await; handler.abort(); } #[tokio::test] async fn test_app_builder_all() { let handler = tokio::spawn(async { AppBuilder::new() .socket((Ipv4Addr::LOCALHOST, 8080)) .routes([Router::new()]) .fallback(|| async { "Fallback" }) .cors(CorsLayer::new()) .normalize_path(true) .tracing(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http()) .serve() .await .unwrap(); }); sleep(Duration::from_millis(250)).await; handler.abort(); } #[test] fn test_create_app_router_only() { let _app: Router<()> = create_app!(Router::new()); } }