From 99f3e2aecd01d80639115b2e7fc6fa8357719235 Mon Sep 17 00:00:00 2001 From: Lambda Date: Sun, 15 Sep 2024 15:12:49 +0000 Subject: [PATCH] Refactor server listener spawning --- src/cli/serve.rs | 188 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 135 insertions(+), 53 deletions(-) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index 8a2b57b9..e0612576 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -1,22 +1,26 @@ use std::{ - collections::HashSet, future::Future, net::SocketAddr, sync::atomic, - time::Duration, + collections::HashSet, convert::Infallible, future::Future, net::SocketAddr, + sync::atomic, time::Duration, }; use axum::{ extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, response::IntoResponse, - routing::{any, get, on, MethodFilter}, + routing::{any, get, on, IntoMakeService, MethodFilter, Route}, Router, }; use axum_server::{ - bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, + accept::Accept, + bind, + service::SendService, + tls_rustls::{RustlsAcceptor, RustlsConfig}, + Handle as ServerHandle, Server, }; -use futures_util::FutureExt; use http::{ header::{self, HeaderName}, Method, StatusCode, Uri, }; +use hyper::body::Incoming; use ruma::api::{ client::{ error::{Error as RumaError, ErrorBody, ErrorKind}, @@ -25,8 +29,13 @@ use ruma::api::{ federation::discovery::get_server_version, IncomingRequest, }; -use tokio::{signal, task::JoinSet}; -use tower::ServiceBuilder; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + signal, + task::JoinSet, +}; +use tower::{Layer, Service, ServiceBuilder}; use tower_http::{ cors::{self, CorsLayer}, trace::TraceLayer, @@ -42,7 +51,7 @@ use crate::{ server_server::{self, AllowLoopbackRequests, LogRequestError}, well_known, }, - config::{self, Config, ListenComponent, ListenTransport}, + config::{self, Config, ListenComponent, ListenConfig, ListenTransport}, database::KeyValueDatabase, error, observability, services, set_application_state, utils::{ @@ -124,6 +133,120 @@ async fn federation_self_test() -> Result<()> { Ok(()) } +struct ServerSpawner<'cfg, M> { + config: &'cfg Config, + middlewares: M, + + tls_config: Option, + servers: JoinSet<(ListenConfig, std::io::Result<()>)>, + handles: Vec, +} + +impl<'cfg, M> ServerSpawner<'cfg, M> +where + M: Layer + Clone + Send + 'static, + M::Service: Service< + axum::extract::Request, + Response = axum::response::Response, + Error = Infallible, + > + Clone + + Send + + 'static, + >::Future: Send + 'static, +{ + async fn new( + config: &'cfg Config, + middlewares: M, + ) -> Result { + let tls_config = if let Some(tls) = &config.tls { + Some( + RustlsConfig::from_pem_file(&tls.certs, &tls.key) + .await + .map_err(|err| error::Serve::LoadCerts { + certs: tls.certs.clone(), + key: tls.key.clone(), + err, + })?, + ) + } else { + None + }; + + Ok(Self { + config, + middlewares, + tls_config, + servers: JoinSet::new(), + handles: Vec::new(), + }) + } + + /// Returns a function that transforms a lower-layer acceptor into a TLS + /// acceptor. + fn tls_acceptor_factory( + &self, + listen: &ListenConfig, + ) -> Result RustlsAcceptor, error::Serve> { + let config = self + .tls_config + .clone() + .ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?; + + Ok(|inner| RustlsAcceptor::new(config).acceptor(inner)) + } + + fn spawn_server_inner( + &mut self, + listen: ListenConfig, + server: Server, + app: IntoMakeService, + ) where + A: Accept + Clone + Send + Sync + 'static, + A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync, + A::Service: SendService> + Send, + A::Future: Send, + { + let handle = ServerHandle::new(); + let server = server.handle(handle.clone()).serve(app); + self.servers.spawn(async move { + let result = server.await; + + (listen, result) + }); + self.handles.push(handle); + } + + fn spawn_server( + &mut self, + listen: ListenConfig, + ) -> Result<(), error::Serve> { + let app = routes(self.config, &listen.components) + .layer(self.middlewares.clone()) + .into_make_service(); + + match listen.transport { + ListenTransport::Tcp { + address, + port, + tls, + } => { + let addr = SocketAddr::from((address, port)); + let server = bind(addr); + + if tls { + let server = + server.map(self.tls_acceptor_factory(&listen)?); + self.spawn_server_inner(listen, server, app); + } else { + self.spawn_server_inner(listen, server, app); + } + + Ok(()) + } + } + } +} + #[allow(clippy::too_many_lines)] async fn run_server() -> Result<(), error::Serve> { use error::Serve as Error; @@ -200,20 +323,7 @@ async fn run_server() -> Result<(), error::Serve> { .layer(axum::middleware::from_fn(observability::http_metrics_layer)) .layer(axum::middleware::from_fn(observability::traceresponse_layer)); - let mut handles = Vec::new(); - let mut servers = JoinSet::new(); - - let tls_config = if let Some(tls) = &config.tls { - Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err( - |err| Error::LoadCerts { - certs: tls.certs.clone(), - key: tls.key.clone(), - err, - }, - )?) - } else { - None - }; + let mut spawner = ServerSpawner::new(config, middlewares).await?; if config.listen.is_empty() { return Err(Error::NoListeners); @@ -221,38 +331,10 @@ async fn run_server() -> Result<(), error::Serve> { for listen in &config.listen { info!(listener = %listen, "Listening for incoming traffic"); - let app = routes(config, &listen.components) - .layer(middlewares.clone()) - .into_make_service(); - - match listen.transport { - ListenTransport::Tcp { - address, - port, - tls, - } => { - let addr = SocketAddr::from((address, port)); - let handle = ServerHandle::new(); - handles.push(handle.clone()); - let server = if tls { - let tls_config = tls_config - .clone() - .ok_or_else(|| Error::NoTlsCerts(listen.clone()))?; - bind_rustls(addr, tls_config) - .handle(handle) - .serve(app) - .left_future() - } else { - bind(addr).handle(handle).serve(app).right_future() - }; - servers.spawn( - server.then(|result| async { (listen.clone(), result) }), - ); - } - } + spawner.spawn_server(listen.clone())?; } - tokio::spawn(handle_signals(tls_config, handles)); + tokio::spawn(handle_signals(spawner.tls_config, spawner.handles)); if config.federation.enable && config.federation.self_test { federation_self_test() @@ -263,7 +345,7 @@ async fn run_server() -> Result<(), error::Serve> { set_application_state(ApplicationState::Ready); - while let Some(result) = servers.join_next().await { + while let Some(result) = spawner.servers.join_next().await { let (listen, result) = result.expect("should be able to join server task"); result.map_err(|err| Error::Listen(err, listen))?;