mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-17 07:41:23 +01:00
Refactor server listener spawning
This commit is contained in:
parent
86481fd651
commit
99f3e2aecd
1 changed files with 135 additions and 53 deletions
188
src/cli/serve.rs
188
src/cli/serve.rs
|
|
@ -1,22 +1,26 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashSet, future::Future, net::SocketAddr, sync::atomic,
|
collections::HashSet, convert::Infallible, future::Future, net::SocketAddr,
|
||||||
time::Duration,
|
sync::atomic, time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{any, get, on, MethodFilter},
|
routing::{any, get, on, IntoMakeService, MethodFilter, Route},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use axum_server::{
|
use axum_server::{
|
||||||
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
|
accept::Accept,
|
||||||
|
bind,
|
||||||
|
service::SendService,
|
||||||
|
tls_rustls::{RustlsAcceptor, RustlsConfig},
|
||||||
|
Handle as ServerHandle, Server,
|
||||||
};
|
};
|
||||||
use futures_util::FutureExt;
|
|
||||||
use http::{
|
use http::{
|
||||||
header::{self, HeaderName},
|
header::{self, HeaderName},
|
||||||
Method, StatusCode, Uri,
|
Method, StatusCode, Uri,
|
||||||
};
|
};
|
||||||
|
use hyper::body::Incoming;
|
||||||
use ruma::api::{
|
use ruma::api::{
|
||||||
client::{
|
client::{
|
||||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||||
|
|
@ -25,8 +29,13 @@ use ruma::api::{
|
||||||
federation::discovery::get_server_version,
|
federation::discovery::get_server_version,
|
||||||
IncomingRequest,
|
IncomingRequest,
|
||||||
};
|
};
|
||||||
use tokio::{signal, task::JoinSet};
|
use tokio::{
|
||||||
use tower::ServiceBuilder;
|
io::{AsyncRead, AsyncWrite},
|
||||||
|
net::TcpStream,
|
||||||
|
signal,
|
||||||
|
task::JoinSet,
|
||||||
|
};
|
||||||
|
use tower::{Layer, Service, ServiceBuilder};
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
cors::{self, CorsLayer},
|
cors::{self, CorsLayer},
|
||||||
trace::TraceLayer,
|
trace::TraceLayer,
|
||||||
|
|
@ -42,7 +51,7 @@ use crate::{
|
||||||
server_server::{self, AllowLoopbackRequests, LogRequestError},
|
server_server::{self, AllowLoopbackRequests, LogRequestError},
|
||||||
well_known,
|
well_known,
|
||||||
},
|
},
|
||||||
config::{self, Config, ListenComponent, ListenTransport},
|
config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
|
||||||
database::KeyValueDatabase,
|
database::KeyValueDatabase,
|
||||||
error, observability, services, set_application_state,
|
error, observability, services, set_application_state,
|
||||||
utils::{
|
utils::{
|
||||||
|
|
@ -124,6 +133,120 @@ async fn federation_self_test() -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ServerSpawner<'cfg, M> {
|
||||||
|
config: &'cfg Config,
|
||||||
|
middlewares: M,
|
||||||
|
|
||||||
|
tls_config: Option<RustlsConfig>,
|
||||||
|
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
|
||||||
|
handles: Vec<ServerHandle>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'cfg, M> ServerSpawner<'cfg, M>
|
||||||
|
where
|
||||||
|
M: Layer<Route> + Clone + Send + 'static,
|
||||||
|
M::Service: Service<
|
||||||
|
axum::extract::Request,
|
||||||
|
Response = axum::response::Response,
|
||||||
|
Error = Infallible,
|
||||||
|
> + Clone
|
||||||
|
+ Send
|
||||||
|
+ 'static,
|
||||||
|
<M::Service as Service<axum::extract::Request>>::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
async fn new(
|
||||||
|
config: &'cfg Config,
|
||||||
|
middlewares: M,
|
||||||
|
) -> Result<Self, error::Serve> {
|
||||||
|
let tls_config = if let Some(tls) = &config.tls {
|
||||||
|
Some(
|
||||||
|
RustlsConfig::from_pem_file(&tls.certs, &tls.key)
|
||||||
|
.await
|
||||||
|
.map_err(|err| error::Serve::LoadCerts {
|
||||||
|
certs: tls.certs.clone(),
|
||||||
|
key: tls.key.clone(),
|
||||||
|
err,
|
||||||
|
})?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
middlewares,
|
||||||
|
tls_config,
|
||||||
|
servers: JoinSet::new(),
|
||||||
|
handles: Vec::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a function that transforms a lower-layer acceptor into a TLS
|
||||||
|
/// acceptor.
|
||||||
|
fn tls_acceptor_factory<A>(
|
||||||
|
&self,
|
||||||
|
listen: &ListenConfig,
|
||||||
|
) -> Result<impl FnOnce(A) -> RustlsAcceptor<A>, error::Serve> {
|
||||||
|
let config = self
|
||||||
|
.tls_config
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;
|
||||||
|
|
||||||
|
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_server_inner<A>(
|
||||||
|
&mut self,
|
||||||
|
listen: ListenConfig,
|
||||||
|
server: Server<A>,
|
||||||
|
app: IntoMakeService<Router>,
|
||||||
|
) where
|
||||||
|
A: Accept<TcpStream, Router> + Clone + Send + Sync + 'static,
|
||||||
|
A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
|
||||||
|
A::Service: SendService<http::Request<Incoming>> + Send,
|
||||||
|
A::Future: Send,
|
||||||
|
{
|
||||||
|
let handle = ServerHandle::new();
|
||||||
|
let server = server.handle(handle.clone()).serve(app);
|
||||||
|
self.servers.spawn(async move {
|
||||||
|
let result = server.await;
|
||||||
|
|
||||||
|
(listen, result)
|
||||||
|
});
|
||||||
|
self.handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_server(
|
||||||
|
&mut self,
|
||||||
|
listen: ListenConfig,
|
||||||
|
) -> Result<(), error::Serve> {
|
||||||
|
let app = routes(self.config, &listen.components)
|
||||||
|
.layer(self.middlewares.clone())
|
||||||
|
.into_make_service();
|
||||||
|
|
||||||
|
match listen.transport {
|
||||||
|
ListenTransport::Tcp {
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
tls,
|
||||||
|
} => {
|
||||||
|
let addr = SocketAddr::from((address, port));
|
||||||
|
let server = bind(addr);
|
||||||
|
|
||||||
|
if tls {
|
||||||
|
let server =
|
||||||
|
server.map(self.tls_acceptor_factory(&listen)?);
|
||||||
|
self.spawn_server_inner(listen, server, app);
|
||||||
|
} else {
|
||||||
|
self.spawn_server_inner(listen, server, app);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
async fn run_server() -> Result<(), error::Serve> {
|
async fn run_server() -> Result<(), error::Serve> {
|
||||||
use error::Serve as Error;
|
use error::Serve as Error;
|
||||||
|
|
@ -200,20 +323,7 @@ async fn run_server() -> Result<(), error::Serve> {
|
||||||
.layer(axum::middleware::from_fn(observability::http_metrics_layer))
|
.layer(axum::middleware::from_fn(observability::http_metrics_layer))
|
||||||
.layer(axum::middleware::from_fn(observability::traceresponse_layer));
|
.layer(axum::middleware::from_fn(observability::traceresponse_layer));
|
||||||
|
|
||||||
let mut handles = Vec::new();
|
let mut spawner = ServerSpawner::new(config, middlewares).await?;
|
||||||
let mut servers = JoinSet::new();
|
|
||||||
|
|
||||||
let tls_config = if let Some(tls) = &config.tls {
|
|
||||||
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|
|
||||||
|err| Error::LoadCerts {
|
|
||||||
certs: tls.certs.clone(),
|
|
||||||
key: tls.key.clone(),
|
|
||||||
err,
|
|
||||||
},
|
|
||||||
)?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if config.listen.is_empty() {
|
if config.listen.is_empty() {
|
||||||
return Err(Error::NoListeners);
|
return Err(Error::NoListeners);
|
||||||
|
|
@ -221,38 +331,10 @@ async fn run_server() -> Result<(), error::Serve> {
|
||||||
|
|
||||||
for listen in &config.listen {
|
for listen in &config.listen {
|
||||||
info!(listener = %listen, "Listening for incoming traffic");
|
info!(listener = %listen, "Listening for incoming traffic");
|
||||||
let app = routes(config, &listen.components)
|
spawner.spawn_server(listen.clone())?;
|
||||||
.layer(middlewares.clone())
|
|
||||||
.into_make_service();
|
|
||||||
|
|
||||||
match listen.transport {
|
|
||||||
ListenTransport::Tcp {
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
tls,
|
|
||||||
} => {
|
|
||||||
let addr = SocketAddr::from((address, port));
|
|
||||||
let handle = ServerHandle::new();
|
|
||||||
handles.push(handle.clone());
|
|
||||||
let server = if tls {
|
|
||||||
let tls_config = tls_config
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
|
|
||||||
bind_rustls(addr, tls_config)
|
|
||||||
.handle(handle)
|
|
||||||
.serve(app)
|
|
||||||
.left_future()
|
|
||||||
} else {
|
|
||||||
bind(addr).handle(handle).serve(app).right_future()
|
|
||||||
};
|
|
||||||
servers.spawn(
|
|
||||||
server.then(|result| async { (listen.clone(), result) }),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::spawn(handle_signals(tls_config, handles));
|
tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));
|
||||||
|
|
||||||
if config.federation.enable && config.federation.self_test {
|
if config.federation.enable && config.federation.self_test {
|
||||||
federation_self_test()
|
federation_self_test()
|
||||||
|
|
@ -263,7 +345,7 @@ async fn run_server() -> Result<(), error::Serve> {
|
||||||
|
|
||||||
set_application_state(ApplicationState::Ready);
|
set_application_state(ApplicationState::Ready);
|
||||||
|
|
||||||
while let Some(result) = servers.join_next().await {
|
while let Some(result) = spawner.servers.join_next().await {
|
||||||
let (listen, result) =
|
let (listen, result) =
|
||||||
result.expect("should be able to join server task");
|
result.expect("should be able to join server task");
|
||||||
result.map_err(|err| Error::Listen(err, listen))?;
|
result.map_err(|err| Error::Listen(err, listen))?;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue