mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-16 15:21:24 +01:00
Refactor server listener spawning
This commit is contained in:
parent
86481fd651
commit
99f3e2aecd
1 changed files with 135 additions and 53 deletions
188
src/cli/serve.rs
188
src/cli/serve.rs
|
|
@ -1,22 +1,26 @@
|
|||
use std::{
|
||||
collections::HashSet, future::Future, net::SocketAddr, sync::atomic,
|
||||
time::Duration,
|
||||
collections::HashSet, convert::Infallible, future::Future, net::SocketAddr,
|
||||
sync::atomic, time::Duration,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
||||
response::IntoResponse,
|
||||
routing::{any, get, on, MethodFilter},
|
||||
routing::{any, get, on, IntoMakeService, MethodFilter, Route},
|
||||
Router,
|
||||
};
|
||||
use axum_server::{
|
||||
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
|
||||
accept::Accept,
|
||||
bind,
|
||||
service::SendService,
|
||||
tls_rustls::{RustlsAcceptor, RustlsConfig},
|
||||
Handle as ServerHandle, Server,
|
||||
};
|
||||
use futures_util::FutureExt;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use hyper::body::Incoming;
|
||||
use ruma::api::{
|
||||
client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
|
|
@ -25,8 +29,13 @@ use ruma::api::{
|
|||
federation::discovery::get_server_version,
|
||||
IncomingRequest,
|
||||
};
|
||||
use tokio::{signal, task::JoinSet};
|
||||
use tower::ServiceBuilder;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::TcpStream,
|
||||
signal,
|
||||
task::JoinSet,
|
||||
};
|
||||
use tower::{Layer, Service, ServiceBuilder};
|
||||
use tower_http::{
|
||||
cors::{self, CorsLayer},
|
||||
trace::TraceLayer,
|
||||
|
|
@ -42,7 +51,7 @@ use crate::{
|
|||
server_server::{self, AllowLoopbackRequests, LogRequestError},
|
||||
well_known,
|
||||
},
|
||||
config::{self, Config, ListenComponent, ListenTransport},
|
||||
config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
|
||||
database::KeyValueDatabase,
|
||||
error, observability, services, set_application_state,
|
||||
utils::{
|
||||
|
|
@ -124,6 +133,120 @@ async fn federation_self_test() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct ServerSpawner<'cfg, M> {
|
||||
config: &'cfg Config,
|
||||
middlewares: M,
|
||||
|
||||
tls_config: Option<RustlsConfig>,
|
||||
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
|
||||
handles: Vec<ServerHandle>,
|
||||
}
|
||||
|
||||
impl<'cfg, M> ServerSpawner<'cfg, M>
|
||||
where
|
||||
M: Layer<Route> + Clone + Send + 'static,
|
||||
M::Service: Service<
|
||||
axum::extract::Request,
|
||||
Response = axum::response::Response,
|
||||
Error = Infallible,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
<M::Service as Service<axum::extract::Request>>::Future: Send + 'static,
|
||||
{
|
||||
async fn new(
|
||||
config: &'cfg Config,
|
||||
middlewares: M,
|
||||
) -> Result<Self, error::Serve> {
|
||||
let tls_config = if let Some(tls) = &config.tls {
|
||||
Some(
|
||||
RustlsConfig::from_pem_file(&tls.certs, &tls.key)
|
||||
.await
|
||||
.map_err(|err| error::Serve::LoadCerts {
|
||||
certs: tls.certs.clone(),
|
||||
key: tls.key.clone(),
|
||||
err,
|
||||
})?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
middlewares,
|
||||
tls_config,
|
||||
servers: JoinSet::new(),
|
||||
handles: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a function that transforms a lower-layer acceptor into a TLS
|
||||
/// acceptor.
|
||||
fn tls_acceptor_factory<A>(
|
||||
&self,
|
||||
listen: &ListenConfig,
|
||||
) -> Result<impl FnOnce(A) -> RustlsAcceptor<A>, error::Serve> {
|
||||
let config = self
|
||||
.tls_config
|
||||
.clone()
|
||||
.ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;
|
||||
|
||||
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
|
||||
}
|
||||
|
||||
fn spawn_server_inner<A>(
|
||||
&mut self,
|
||||
listen: ListenConfig,
|
||||
server: Server<A>,
|
||||
app: IntoMakeService<Router>,
|
||||
) where
|
||||
A: Accept<TcpStream, Router> + Clone + Send + Sync + 'static,
|
||||
A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
|
||||
A::Service: SendService<http::Request<Incoming>> + Send,
|
||||
A::Future: Send,
|
||||
{
|
||||
let handle = ServerHandle::new();
|
||||
let server = server.handle(handle.clone()).serve(app);
|
||||
self.servers.spawn(async move {
|
||||
let result = server.await;
|
||||
|
||||
(listen, result)
|
||||
});
|
||||
self.handles.push(handle);
|
||||
}
|
||||
|
||||
fn spawn_server(
|
||||
&mut self,
|
||||
listen: ListenConfig,
|
||||
) -> Result<(), error::Serve> {
|
||||
let app = routes(self.config, &listen.components)
|
||||
.layer(self.middlewares.clone())
|
||||
.into_make_service();
|
||||
|
||||
match listen.transport {
|
||||
ListenTransport::Tcp {
|
||||
address,
|
||||
port,
|
||||
tls,
|
||||
} => {
|
||||
let addr = SocketAddr::from((address, port));
|
||||
let server = bind(addr);
|
||||
|
||||
if tls {
|
||||
let server =
|
||||
server.map(self.tls_acceptor_factory(&listen)?);
|
||||
self.spawn_server_inner(listen, server, app);
|
||||
} else {
|
||||
self.spawn_server_inner(listen, server, app);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn run_server() -> Result<(), error::Serve> {
|
||||
use error::Serve as Error;
|
||||
|
|
@ -200,20 +323,7 @@ async fn run_server() -> Result<(), error::Serve> {
|
|||
.layer(axum::middleware::from_fn(observability::http_metrics_layer))
|
||||
.layer(axum::middleware::from_fn(observability::traceresponse_layer));
|
||||
|
||||
let mut handles = Vec::new();
|
||||
let mut servers = JoinSet::new();
|
||||
|
||||
let tls_config = if let Some(tls) = &config.tls {
|
||||
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|
||||
|err| Error::LoadCerts {
|
||||
certs: tls.certs.clone(),
|
||||
key: tls.key.clone(),
|
||||
err,
|
||||
},
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut spawner = ServerSpawner::new(config, middlewares).await?;
|
||||
|
||||
if config.listen.is_empty() {
|
||||
return Err(Error::NoListeners);
|
||||
|
|
@ -221,38 +331,10 @@ async fn run_server() -> Result<(), error::Serve> {
|
|||
|
||||
for listen in &config.listen {
|
||||
info!(listener = %listen, "Listening for incoming traffic");
|
||||
let app = routes(config, &listen.components)
|
||||
.layer(middlewares.clone())
|
||||
.into_make_service();
|
||||
|
||||
match listen.transport {
|
||||
ListenTransport::Tcp {
|
||||
address,
|
||||
port,
|
||||
tls,
|
||||
} => {
|
||||
let addr = SocketAddr::from((address, port));
|
||||
let handle = ServerHandle::new();
|
||||
handles.push(handle.clone());
|
||||
let server = if tls {
|
||||
let tls_config = tls_config
|
||||
.clone()
|
||||
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
|
||||
bind_rustls(addr, tls_config)
|
||||
.handle(handle)
|
||||
.serve(app)
|
||||
.left_future()
|
||||
} else {
|
||||
bind(addr).handle(handle).serve(app).right_future()
|
||||
};
|
||||
servers.spawn(
|
||||
server.then(|result| async { (listen.clone(), result) }),
|
||||
);
|
||||
}
|
||||
}
|
||||
spawner.spawn_server(listen.clone())?;
|
||||
}
|
||||
|
||||
tokio::spawn(handle_signals(tls_config, handles));
|
||||
tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));
|
||||
|
||||
if config.federation.enable && config.federation.self_test {
|
||||
federation_self_test()
|
||||
|
|
@ -263,7 +345,7 @@ async fn run_server() -> Result<(), error::Serve> {
|
|||
|
||||
set_application_state(ApplicationState::Ready);
|
||||
|
||||
while let Some(result) = servers.join_next().await {
|
||||
while let Some(result) = spawner.servers.join_next().await {
|
||||
let (listen, result) =
|
||||
result.expect("should be able to join server task");
|
||||
result.map_err(|err| Error::Listen(err, listen))?;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue