Refactor server listener spawning

This commit is contained in:
Lambda 2024-09-15 15:12:49 +00:00
parent 86481fd651
commit 99f3e2aecd

View file

@ -1,22 +1,26 @@
use std::{ use std::{
collections::HashSet, future::Future, net::SocketAddr, sync::atomic, collections::HashSet, convert::Infallible, future::Future, net::SocketAddr,
time::Duration, sync::atomic, time::Duration,
}; };
use axum::{ use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
response::IntoResponse, response::IntoResponse,
routing::{any, get, on, MethodFilter}, routing::{any, get, on, IntoMakeService, MethodFilter, Route},
Router, Router,
}; };
use axum_server::{ use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, accept::Accept,
bind,
service::SendService,
tls_rustls::{RustlsAcceptor, RustlsConfig},
Handle as ServerHandle, Server,
}; };
use futures_util::FutureExt;
use http::{ use http::{
header::{self, HeaderName}, header::{self, HeaderName},
Method, StatusCode, Uri, Method, StatusCode, Uri,
}; };
use hyper::body::Incoming;
use ruma::api::{ use ruma::api::{
client::{ client::{
error::{Error as RumaError, ErrorBody, ErrorKind}, error::{Error as RumaError, ErrorBody, ErrorKind},
@ -25,8 +29,13 @@ use ruma::api::{
federation::discovery::get_server_version, federation::discovery::get_server_version,
IncomingRequest, IncomingRequest,
}; };
use tokio::{signal, task::JoinSet}; use tokio::{
use tower::ServiceBuilder; io::{AsyncRead, AsyncWrite},
net::TcpStream,
signal,
task::JoinSet,
};
use tower::{Layer, Service, ServiceBuilder};
use tower_http::{ use tower_http::{
cors::{self, CorsLayer}, cors::{self, CorsLayer},
trace::TraceLayer, trace::TraceLayer,
@ -42,7 +51,7 @@ use crate::{
server_server::{self, AllowLoopbackRequests, LogRequestError}, server_server::{self, AllowLoopbackRequests, LogRequestError},
well_known, well_known,
}, },
config::{self, Config, ListenComponent, ListenTransport}, config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
database::KeyValueDatabase, database::KeyValueDatabase,
error, observability, services, set_application_state, error, observability, services, set_application_state,
utils::{ utils::{
@ -124,6 +133,120 @@ async fn federation_self_test() -> Result<()> {
Ok(()) Ok(())
} }
struct ServerSpawner<'cfg, M> {
config: &'cfg Config,
middlewares: M,
tls_config: Option<RustlsConfig>,
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
handles: Vec<ServerHandle>,
}
impl<'cfg, M> ServerSpawner<'cfg, M>
where
M: Layer<Route> + Clone + Send + 'static,
M::Service: Service<
axum::extract::Request,
Response = axum::response::Response,
Error = Infallible,
> + Clone
+ Send
+ 'static,
<M::Service as Service<axum::extract::Request>>::Future: Send + 'static,
{
async fn new(
config: &'cfg Config,
middlewares: M,
) -> Result<Self, error::Serve> {
let tls_config = if let Some(tls) = &config.tls {
Some(
RustlsConfig::from_pem_file(&tls.certs, &tls.key)
.await
.map_err(|err| error::Serve::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
})?,
)
} else {
None
};
Ok(Self {
config,
middlewares,
tls_config,
servers: JoinSet::new(),
handles: Vec::new(),
})
}
/// Returns a function that transforms a lower-layer acceptor into a TLS
/// acceptor.
fn tls_acceptor_factory<A>(
&self,
listen: &ListenConfig,
) -> Result<impl FnOnce(A) -> RustlsAcceptor<A>, error::Serve> {
let config = self
.tls_config
.clone()
.ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
}
fn spawn_server_inner<A>(
&mut self,
listen: ListenConfig,
server: Server<A>,
app: IntoMakeService<Router>,
) where
A: Accept<TcpStream, Router> + Clone + Send + Sync + 'static,
A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
A::Service: SendService<http::Request<Incoming>> + Send,
A::Future: Send,
{
let handle = ServerHandle::new();
let server = server.handle(handle.clone()).serve(app);
self.servers.spawn(async move {
let result = server.await;
(listen, result)
});
self.handles.push(handle);
}
fn spawn_server(
&mut self,
listen: ListenConfig,
) -> Result<(), error::Serve> {
let app = routes(self.config, &listen.components)
.layer(self.middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((address, port));
let server = bind(addr);
if tls {
let server =
server.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
} else {
self.spawn_server_inner(listen, server, app);
}
Ok(())
}
}
}
}
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
async fn run_server() -> Result<(), error::Serve> { async fn run_server() -> Result<(), error::Serve> {
use error::Serve as Error; use error::Serve as Error;
@ -200,20 +323,7 @@ async fn run_server() -> Result<(), error::Serve> {
.layer(axum::middleware::from_fn(observability::http_metrics_layer)) .layer(axum::middleware::from_fn(observability::http_metrics_layer))
.layer(axum::middleware::from_fn(observability::traceresponse_layer)); .layer(axum::middleware::from_fn(observability::traceresponse_layer));
let mut handles = Vec::new(); let mut spawner = ServerSpawner::new(config, middlewares).await?;
let mut servers = JoinSet::new();
let tls_config = if let Some(tls) = &config.tls {
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
},
)?)
} else {
None
};
if config.listen.is_empty() { if config.listen.is_empty() {
return Err(Error::NoListeners); return Err(Error::NoListeners);
@ -221,38 +331,10 @@ async fn run_server() -> Result<(), error::Serve> {
for listen in &config.listen { for listen in &config.listen {
info!(listener = %listen, "Listening for incoming traffic"); info!(listener = %listen, "Listening for incoming traffic");
let app = routes(config, &listen.components) spawner.spawn_server(listen.clone())?;
.layer(middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((address, port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if tls {
let tls_config = tls_config
.clone()
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app)
.left_future()
} else {
bind(addr).handle(handle).serve(app).right_future()
};
servers.spawn(
server.then(|result| async { (listen.clone(), result) }),
);
}
}
} }
tokio::spawn(handle_signals(tls_config, handles)); tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));
if config.federation.enable && config.federation.self_test { if config.federation.enable && config.federation.self_test {
federation_self_test() federation_self_test()
@ -263,7 +345,7 @@ async fn run_server() -> Result<(), error::Serve> {
set_application_state(ApplicationState::Ready); set_application_state(ApplicationState::Ready);
while let Some(result) = servers.join_next().await { while let Some(result) = spawner.servers.join_next().await {
let (listen, result) = let (listen, result) =
result.expect("should be able to join server task"); result.expect("should be able to join server task");
result.map_err(|err| Error::Listen(err, listen))?; result.map_err(|err| Error::Listen(err, listen))?;