diff --git a/Cargo.toml b/Cargo.toml index 7904770e..b406358b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,7 +144,7 @@ trust-dns-resolver = "0.23.2" xdg = "2.5.2" [target.'cfg(unix)'.dependencies] -nix = { version = "0.29", features = ["resource"] } +nix = { version = "0.29", features = ["resource", "time"] } [features] default = ["rocksdb", "sqlite", "systemd"] diff --git a/book/changelog.md b/book/changelog.md index 377e9ce8..66506096 100644 --- a/book/changelog.md +++ b/book/changelog.md @@ -251,3 +251,6 @@ This will be the first release of Grapevine since it was forked from Conduit 20. Include the [`traceresponse` header](https://w3c.github.io/trace-context/#traceresponse-header) if OpenTelemetry Tracing is in use. ([!112](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/112)) +21. Sending SIGHUP to the grapevine process now reloads TLS certificates from + disk. + ([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97)) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index 38832e6b..bfde80cf 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -228,7 +228,7 @@ async fn run_server() -> Result<(), error::Serve> { set_application_state(ApplicationState::Ready); - tokio::spawn(shutdown_signal(handles)); + tokio::spawn(handle_signals(tls_config, handles)); while let Some(result) = servers.join_next().await { let (listen, result) = @@ -540,28 +540,72 @@ fn routes(config: &Config, components: &HashSet) -> Router { router.route("/", get(it_works)).fallback(not_found) } -async fn shutdown_signal(handles: Vec) { - let ctrl_c = async { - signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); - }; +async fn reload_tls_config( + tls_config: &RustlsConfig, +) -> Result<(), error::Serve> { + let config = services() + .globals + .config + .tls + .as_ref() + .expect("TLS config should exist if TLS listener exists"); + tls_config.reload_from_pem_file(&config.certs, &config.key).await.map_err( + |err| error::Serve::LoadCerts { + certs: config.certs.clone(), + key: config.key.clone(), + err, + }, + )?; + + Ok(()) +} + +async fn handle_signals( + tls_config: Option, + handles: Vec, +) { #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) + async fn wait_signal(sig: signal::unix::SignalKind) { + signal::unix::signal(sig) .expect("failed to install signal handler") .recv() .await; + } + + #[cfg(unix)] + let terminate = || wait_signal(signal::unix::SignalKind::terminate()); + #[cfg(not(unix))] + let terminate = || std::future::pending::<()>(); + + #[cfg(unix)] + let sighup = || wait_signal(signal::unix::SignalKind::hangup()); + #[cfg(not(unix))] + let sighup = || std::future::pending::<()>(); + + let ctrl_c = || async { + signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); }; - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); + let sig = loop { + tokio::select! { + () = sighup() => { + info!("Received reload request"); - let sig: &str; + set_application_state(ApplicationState::Reloading); - tokio::select! { - () = ctrl_c => { sig = "Ctrl+C"; }, - () = terminate => { sig = "SIGTERM"; }, - } + if let Some(tls_config) = tls_config.as_ref() { + if let Err(error) = reload_tls_config(tls_config).await { + error!(?error, "Failed to reload TLS config"); + } + } + + set_application_state(ApplicationState::Ready); + }, + () = terminate() => { break "SIGTERM"; }, + () = ctrl_c() => { break "Ctrl+C"; }, + } + }; warn!(signal = %sig, "Shutting down due to signal"); diff --git a/src/main.rs b/src/main.rs index e0365fc3..dfef8616 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,6 +44,7 @@ fn version() -> String { #[derive(Debug, Clone, Copy)] enum ApplicationState { Ready, + Reloading, Stopping, } @@ -61,6 +62,21 @@ fn set_application_state(state: ApplicationState) { match state { ApplicationState::Ready => notify(&[NotifyState::Ready]), + ApplicationState::Reloading => { + let timespec = nix::time::clock_gettime( + nix::time::ClockId::CLOCK_MONOTONIC, + ) + .expect("CLOCK_MONOTONIC should be usable"); + let monotonic_usec = + timespec.tv_sec() * 1_000_000 + timespec.tv_nsec() / 1000; + + notify(&[ + NotifyState::Reloading, + NotifyState::Custom(&format!( + "MONOTONIC_USEC={monotonic_usec}", + )), + ]); + } ApplicationState::Stopping => notify(&[NotifyState::Stopping]), }; }