allow listening on multiple ports in config

This is a config compatibility break.

The ability to listen on multiple ports, including both TLS and non-TLS,
is necessary for running complement against grapevine.
This commit is contained in:
Benjamin Lee 2024-06-07 19:49:59 -07:00
parent b7ad00ef6e
commit f7d7952f9b
No known key found for this signature in database
GPG key ID: FB9624E2885D55A4
4 changed files with 99 additions and 45 deletions

View file

@ -15,6 +15,7 @@ use axum::{
use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
};
use futures_util::FutureExt;
use http::{
header::{self, HeaderName},
Method, StatusCode, Uri,
@ -26,14 +27,14 @@ use ruma::api::{
},
IncomingRequest,
};
use tokio::signal;
use tokio::{signal, task::JoinSet};
use tower::ServiceBuilder;
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
ServiceBuilderExt as _,
};
use tracing::{debug, info, info_span, warn, Instrument};
use tracing::{debug, error, info, info_span, warn, Instrument};
mod api;
mod args;
@ -46,7 +47,7 @@ mod utils;
pub(crate) use api::ruma_wrapper::{Ar, Ra};
use api::{client_server, server_server};
pub(crate) use config::Config;
pub(crate) use config::{Config, ListenConfig};
pub(crate) use database::KeyValueDatabase;
pub(crate) use service::{pdu::PduEvent, Services};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
@ -135,7 +136,6 @@ async fn run_server() -> Result<(), error::Serve> {
use error::Serve as Error;
let config = &services().globals.config;
let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");
@ -184,36 +184,60 @@ async fn run_server() -> Result<(), error::Serve> {
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
let app = routes(config).layer(middlewares).into_make_service();
let handle = ServerHandle::new();
let mut handles = Vec::new();
let mut servers = JoinSet::new();
tokio::spawn(shutdown_signal(handle.clone()));
let tls_config = if let Some(tls) = &config.tls {
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
},
)?)
} else {
None
};
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key)
.await
.map_err(|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
})?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
if config.listen.is_empty() {
return Err(Error::NoListeners);
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
server.await.map_err(Error::Listen)?;
for listen in &config.listen {
match listen {
ListenConfig::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((*address, *port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if *tls {
let tls_config =
tls_config.clone().ok_or(Error::NoTlsCerts)?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.clone())
.left_future()
} else {
bind(addr).handle(handle).serve(app.clone()).right_future()
};
servers.spawn(server);
}
}
None => {
let server = bind(addr).handle(handle).serve(app);
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
server.await.map_err(Error::Listen)?;
}
tokio::spawn(shutdown_signal(handles));
while let Some(result) = servers.join_next().await {
result
.expect("should be able to join server task")
.map_err(Error::Listen)?;
}
Ok(())
@ -456,7 +480,7 @@ fn routes(config: &Config) -> Router {
}
}
async fn shutdown_signal(handle: ServerHandle) {
async fn shutdown_signal(handles: Vec<ServerHandle>) {
let ctrl_c = async {
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
};
@ -480,10 +504,13 @@ async fn shutdown_signal(handle: ServerHandle) {
}
warn!("Received {}, shutting down...", sig);
handle.graceful_shutdown(Some(Duration::from_secs(30)));
services().globals.shutdown();
for handle in handles {
handle.graceful_shutdown(Some(Duration::from_secs(30)));
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
.expect("should be able to notify systemd");