diff --git a/nix/modules/default/default.nix b/nix/modules/default/default.nix index 9c7d35a8..9118cfea 100644 --- a/nix/modules/default/default.nix +++ b/nix/modules/default/default.nix @@ -29,13 +29,6 @@ in type = types.submodule { freeformType = format.type; options = { - address = lib.mkOption { - type = types.nonEmptyStr; - description = '' - The local IP address to bind to. - ''; - default = "::1"; - }; conduit_compat = lib.mkOption { type = types.bool; description = '' @@ -56,12 +49,18 @@ in then "/var/lib/matrix-conduit" else "/var/lib/grapevine"; }; - port = lib.mkOption { - type = types.port; + listen = lib.mkOption { + type = types.listOf format.type; description = '' - The local port to bind to. + List of places to listen for incoming connections. ''; - default = 6167; + default = [ + { + type = "tcp"; + address = "::1"; + port = 6167; + } + ]; }; }; }; diff --git a/src/config.rs b/src/config.rs index 88399a46..bf5de8f9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -25,10 +25,8 @@ pub(crate) static DEFAULT_PATH: Lazy = pub(crate) struct Config { #[serde(default = "false_fn")] pub(crate) conduit_compat: bool, - #[serde(default = "default_address")] - pub(crate) address: IpAddr, - #[serde(default = "default_port")] - pub(crate) port: u16, + #[serde(default = "default_listen")] + pub(crate) listen: Vec, pub(crate) tls: Option, pub(crate) server_name: OwnedServerName, @@ -98,6 +96,19 @@ pub(crate) struct TlsConfig { pub(crate) key: String, } +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum ListenConfig { + Tcp { + #[serde(default = "default_address")] + address: IpAddr, + #[serde(default = "default_port")] + port: u16, + #[serde(default = "false_fn")] + tls: bool, + }, +} + fn false_fn() -> bool { false } @@ -106,6 +117,14 @@ fn true_fn() -> bool { true } +fn default_listen() -> Vec { + vec![ListenConfig::Tcp { + address: default_address(), + port: default_port(), + tls: false, + }] +} + fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } diff --git a/src/error.rs b/src/error.rs index bf09b836..c63f743b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -104,6 +104,15 @@ pub(crate) enum ConfigSearch { #[allow(missing_docs)] #[derive(Error, Debug)] pub(crate) enum Serve { + #[error("no listeners were specified in the configuration file")] + NoListeners, + + #[error( + "listener requested TLS, but no TLS cert was specified in the \ + configuration file. Please set 'tls.certs' and 'tls.key'" + )] + NoTlsCerts, + #[error("failed to read TLS cert and key files at {certs:?} and {key:?}")] LoadCerts { certs: String, diff --git a/src/main.rs b/src/main.rs index 96fdbeef..9f732360 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ use axum::{ use axum_server::{ bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, }; +use futures_util::FutureExt; use http::{ header::{self, HeaderName}, Method, StatusCode, Uri, @@ -26,14 +27,14 @@ use ruma::api::{ }, IncomingRequest, }; -use tokio::signal; +use tokio::{signal, task::JoinSet}; use tower::ServiceBuilder; use tower_http::{ cors::{self, CorsLayer}, trace::TraceLayer, ServiceBuilderExt as _, }; -use tracing::{debug, info, info_span, warn, Instrument}; +use tracing::{debug, error, info, info_span, warn, Instrument}; mod api; mod args; @@ -46,7 +47,7 @@ mod utils; pub(crate) use api::ruma_wrapper::{Ar, Ra}; use api::{client_server, server_server}; -pub(crate) use config::Config; +pub(crate) use config::{Config, ListenConfig}; pub(crate) use database::KeyValueDatabase; pub(crate) use service::{pdu::PduEvent, Services}; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] @@ -135,7 +136,6 @@ async fn run_server() -> Result<(), error::Serve> { use error::Serve as Error; let config = &services().globals.config; - let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -184,36 +184,60 @@ async fn run_server() -> Result<(), error::Serve> { .layer(axum::middleware::from_fn(observability::http_metrics_layer)); let app = routes(config).layer(middlewares).into_make_service(); - let handle = ServerHandle::new(); + let mut handles = Vec::new(); + let mut servers = JoinSet::new(); - tokio::spawn(shutdown_signal(handle.clone())); + let tls_config = if let Some(tls) = &config.tls { + Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err( + |err| Error::LoadCerts { + certs: tls.certs.clone(), + key: tls.key.clone(), + err, + }, + )?) + } else { + None + }; - match &config.tls { - Some(tls) => { - let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key) - .await - .map_err(|err| Error::LoadCerts { - certs: tls.certs.clone(), - key: tls.key.clone(), - err, - })?; - let server = bind_rustls(addr, conf).handle(handle).serve(app); + if config.listen.is_empty() { + return Err(Error::NoListeners); + } - #[cfg(feature = "systemd")] - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) - .expect("should be able to notify systemd"); - - server.await.map_err(Error::Listen)?; + for listen in &config.listen { + match listen { + ListenConfig::Tcp { + address, + port, + tls, + } => { + let addr = SocketAddr::from((*address, *port)); + let handle = ServerHandle::new(); + handles.push(handle.clone()); + let server = if *tls { + let tls_config = + tls_config.clone().ok_or(Error::NoTlsCerts)?; + bind_rustls(addr, tls_config) + .handle(handle) + .serve(app.clone()) + .left_future() + } else { + bind(addr).handle(handle).serve(app.clone()).right_future() + }; + servers.spawn(server); + } } - None => { - let server = bind(addr).handle(handle).serve(app); + } - #[cfg(feature = "systemd")] - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) - .expect("should be able to notify systemd"); + #[cfg(feature = "systemd")] + sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) + .expect("should be able to notify systemd"); - server.await.map_err(Error::Listen)?; - } + tokio::spawn(shutdown_signal(handles)); + + while let Some(result) = servers.join_next().await { + result + .expect("should be able to join server task") + .map_err(Error::Listen)?; } Ok(()) @@ -456,7 +480,7 @@ fn routes(config: &Config) -> Router { } } -async fn shutdown_signal(handle: ServerHandle) { +async fn shutdown_signal(handles: Vec) { let ctrl_c = async { signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); }; @@ -480,10 +504,13 @@ async fn shutdown_signal(handle: ServerHandle) { } warn!("Received {}, shutting down...", sig); - handle.graceful_shutdown(Some(Duration::from_secs(30))); services().globals.shutdown(); + for handle in handles { + handle.graceful_shutdown(Some(Duration::from_secs(30))); + } + #[cfg(feature = "systemd")] sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]) .expect("should be able to notify systemd");