Refactor server listener spawning

This commit is contained in:
Lambda 2024-09-15 15:12:49 +00:00
parent 86481fd651
commit 99f3e2aecd

View file

@ -1,22 +1,26 @@
use std::{
collections::HashSet, future::Future, net::SocketAddr, sync::atomic,
time::Duration,
collections::HashSet, convert::Infallible, future::Future, net::SocketAddr,
sync::atomic, time::Duration,
};
use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
response::IntoResponse,
routing::{any, get, on, MethodFilter},
routing::{any, get, on, IntoMakeService, MethodFilter, Route},
Router,
};
use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
accept::Accept,
bind,
service::SendService,
tls_rustls::{RustlsAcceptor, RustlsConfig},
Handle as ServerHandle, Server,
};
use futures_util::FutureExt;
use http::{
header::{self, HeaderName},
Method, StatusCode, Uri,
};
use hyper::body::Incoming;
use ruma::api::{
client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
@ -25,8 +29,13 @@ use ruma::api::{
federation::discovery::get_server_version,
IncomingRequest,
};
use tokio::{signal, task::JoinSet};
use tower::ServiceBuilder;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
signal,
task::JoinSet,
};
use tower::{Layer, Service, ServiceBuilder};
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
@ -42,7 +51,7 @@ use crate::{
server_server::{self, AllowLoopbackRequests, LogRequestError},
well_known,
},
config::{self, Config, ListenComponent, ListenTransport},
config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
database::KeyValueDatabase,
error, observability, services, set_application_state,
utils::{
@ -124,6 +133,120 @@ async fn federation_self_test() -> Result<()> {
Ok(())
}
struct ServerSpawner<'cfg, M> {
config: &'cfg Config,
middlewares: M,
tls_config: Option<RustlsConfig>,
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
handles: Vec<ServerHandle>,
}
impl<'cfg, M> ServerSpawner<'cfg, M>
where
M: Layer<Route> + Clone + Send + 'static,
M::Service: Service<
axum::extract::Request,
Response = axum::response::Response,
Error = Infallible,
> + Clone
+ Send
+ 'static,
<M::Service as Service<axum::extract::Request>>::Future: Send + 'static,
{
async fn new(
config: &'cfg Config,
middlewares: M,
) -> Result<Self, error::Serve> {
let tls_config = if let Some(tls) = &config.tls {
Some(
RustlsConfig::from_pem_file(&tls.certs, &tls.key)
.await
.map_err(|err| error::Serve::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
})?,
)
} else {
None
};
Ok(Self {
config,
middlewares,
tls_config,
servers: JoinSet::new(),
handles: Vec::new(),
})
}
/// Returns a function that transforms a lower-layer acceptor into a TLS
/// acceptor.
fn tls_acceptor_factory<A>(
&self,
listen: &ListenConfig,
) -> Result<impl FnOnce(A) -> RustlsAcceptor<A>, error::Serve> {
let config = self
.tls_config
.clone()
.ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
}
fn spawn_server_inner<A>(
&mut self,
listen: ListenConfig,
server: Server<A>,
app: IntoMakeService<Router>,
) where
A: Accept<TcpStream, Router> + Clone + Send + Sync + 'static,
A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
A::Service: SendService<http::Request<Incoming>> + Send,
A::Future: Send,
{
let handle = ServerHandle::new();
let server = server.handle(handle.clone()).serve(app);
self.servers.spawn(async move {
let result = server.await;
(listen, result)
});
self.handles.push(handle);
}
fn spawn_server(
&mut self,
listen: ListenConfig,
) -> Result<(), error::Serve> {
let app = routes(self.config, &listen.components)
.layer(self.middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((address, port));
let server = bind(addr);
if tls {
let server =
server.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
} else {
self.spawn_server_inner(listen, server, app);
}
Ok(())
}
}
}
}
#[allow(clippy::too_many_lines)]
async fn run_server() -> Result<(), error::Serve> {
use error::Serve as Error;
@ -200,20 +323,7 @@ async fn run_server() -> Result<(), error::Serve> {
.layer(axum::middleware::from_fn(observability::http_metrics_layer))
.layer(axum::middleware::from_fn(observability::traceresponse_layer));
let mut handles = Vec::new();
let mut servers = JoinSet::new();
let tls_config = if let Some(tls) = &config.tls {
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
},
)?)
} else {
None
};
let mut spawner = ServerSpawner::new(config, middlewares).await?;
if config.listen.is_empty() {
return Err(Error::NoListeners);
@ -221,38 +331,10 @@ async fn run_server() -> Result<(), error::Serve> {
for listen in &config.listen {
info!(listener = %listen, "Listening for incoming traffic");
let app = routes(config, &listen.components)
.layer(middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((address, port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if tls {
let tls_config = tls_config
.clone()
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app)
.left_future()
} else {
bind(addr).handle(handle).serve(app).right_future()
};
servers.spawn(
server.then(|result| async { (listen.clone(), result) }),
);
}
}
spawner.spawn_server(listen.clone())?;
}
tokio::spawn(handle_signals(tls_config, handles));
tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));
if config.federation.enable && config.federation.self_test {
federation_self_test()
@ -263,7 +345,7 @@ async fn run_server() -> Result<(), error::Serve> {
set_application_state(ApplicationState::Ready);
while let Some(result) = servers.join_next().await {
while let Some(result) = spawner.servers.join_next().await {
let (listen, result) =
result.expect("should be able to join server task");
result.map_err(|err| Error::Listen(err, listen))?;