allow listening on multiple ports in config

This is a config compatibility break.

The ability to listen on multiple ports, including both TLS and non-TLS,
is necessary for running complement against grapevine.
This commit is contained in:
Benjamin Lee 2024-06-07 19:49:59 -07:00
parent b7ad00ef6e
commit f7d7952f9b
No known key found for this signature in database
GPG key ID: FB9624E2885D55A4
4 changed files with 99 additions and 45 deletions

View file

@ -25,10 +25,8 @@ pub(crate) static DEFAULT_PATH: Lazy<PathBuf> =
pub(crate) struct Config {
#[serde(default = "false_fn")]
pub(crate) conduit_compat: bool,
#[serde(default = "default_address")]
pub(crate) address: IpAddr,
#[serde(default = "default_port")]
pub(crate) port: u16,
#[serde(default = "default_listen")]
pub(crate) listen: Vec<ListenConfig>,
pub(crate) tls: Option<TlsConfig>,
pub(crate) server_name: OwnedServerName,
@ -98,6 +96,19 @@ pub(crate) struct TlsConfig {
pub(crate) key: String,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum ListenConfig {
Tcp {
#[serde(default = "default_address")]
address: IpAddr,
#[serde(default = "default_port")]
port: u16,
#[serde(default = "false_fn")]
tls: bool,
},
}
fn false_fn() -> bool {
false
}
@ -106,6 +117,14 @@ fn true_fn() -> bool {
true
}
fn default_listen() -> Vec<ListenConfig> {
vec![ListenConfig::Tcp {
address: default_address(),
port: default_port(),
tls: false,
}]
}
fn default_address() -> IpAddr {
Ipv4Addr::LOCALHOST.into()
}

View file

@ -104,6 +104,15 @@ pub(crate) enum ConfigSearch {
#[allow(missing_docs)]
#[derive(Error, Debug)]
pub(crate) enum Serve {
#[error("no listeners were specified in the configuration file")]
NoListeners,
#[error(
"listener requested TLS, but no TLS cert was specified in the \
configuration file. Please set 'tls.certs' and 'tls.key'"
)]
NoTlsCerts,
#[error("failed to read TLS cert and key files at {certs:?} and {key:?}")]
LoadCerts {
certs: String,

View file

@ -15,6 +15,7 @@ use axum::{
use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
};
use futures_util::FutureExt;
use http::{
header::{self, HeaderName},
Method, StatusCode, Uri,
@ -26,14 +27,14 @@ use ruma::api::{
},
IncomingRequest,
};
use tokio::signal;
use tokio::{signal, task::JoinSet};
use tower::ServiceBuilder;
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
ServiceBuilderExt as _,
};
use tracing::{debug, info, info_span, warn, Instrument};
use tracing::{debug, error, info, info_span, warn, Instrument};
mod api;
mod args;
@ -46,7 +47,7 @@ mod utils;
pub(crate) use api::ruma_wrapper::{Ar, Ra};
use api::{client_server, server_server};
pub(crate) use config::Config;
pub(crate) use config::{Config, ListenConfig};
pub(crate) use database::KeyValueDatabase;
pub(crate) use service::{pdu::PduEvent, Services};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
@ -135,7 +136,6 @@ async fn run_server() -> Result<(), error::Serve> {
use error::Serve as Error;
let config = &services().globals.config;
let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");
@ -184,36 +184,60 @@ async fn run_server() -> Result<(), error::Serve> {
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
let app = routes(config).layer(middlewares).into_make_service();
let handle = ServerHandle::new();
let mut handles = Vec::new();
let mut servers = JoinSet::new();
tokio::spawn(shutdown_signal(handle.clone()));
let tls_config = if let Some(tls) = &config.tls {
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
},
)?)
} else {
None
};
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key)
.await
.map_err(|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
})?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
if config.listen.is_empty() {
return Err(Error::NoListeners);
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
server.await.map_err(Error::Listen)?;
for listen in &config.listen {
match listen {
ListenConfig::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((*address, *port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if *tls {
let tls_config =
tls_config.clone().ok_or(Error::NoTlsCerts)?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.clone())
.left_future()
} else {
bind(addr).handle(handle).serve(app.clone()).right_future()
};
servers.spawn(server);
}
}
None => {
let server = bind(addr).handle(handle).serve(app);
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
server.await.map_err(Error::Listen)?;
}
tokio::spawn(shutdown_signal(handles));
while let Some(result) = servers.join_next().await {
result
.expect("should be able to join server task")
.map_err(Error::Listen)?;
}
Ok(())
@ -456,7 +480,7 @@ fn routes(config: &Config) -> Router {
}
}
async fn shutdown_signal(handle: ServerHandle) {
async fn shutdown_signal(handles: Vec<ServerHandle>) {
let ctrl_c = async {
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
};
@ -480,10 +504,13 @@ async fn shutdown_signal(handle: ServerHandle) {
}
warn!("Received {}, shutting down...", sig);
handle.graceful_shutdown(Some(Duration::from_secs(30)));
services().globals.shutdown();
for handle in handles {
handle.graceful_shutdown(Some(Duration::from_secs(30)));
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
.expect("should be able to notify systemd");