diff --git a/Cargo.lock b/Cargo.lock index 199a5971..6e0714f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,9 +210,8 @@ dependencies = [ [[package]] name = "axum-server" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab" +version = "0.7.2+grapevine-1" +source = "git+https://gitlab.computer.surgery/matrix/thirdparty/axum-server.git?rev=v0.7.2%2Bgrapevine-1#1f9b20296494792a1f09ab14689f3b2954b4f782" dependencies = [ "arc-swap", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 65761015..d8fb1488 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ argon2 = "0.5.3" async-trait = "0.1.88" axum = { version = "0.7.9", default-features = false, features = ["form", "http1", "http2", "json", "matched-path", "tokio", "tracing"] } axum-extra = { version = "0.9.5", features = ["typed-header"] } -axum-server = { version = "0.7.2", features = ["tls-rustls-no-provider"] } +axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-server.git", rev = "v0.7.2+grapevine-1", version = "0.7.2", features = ["tls-rustls-no-provider"] } base64 = "0.22.1" bytes = "1.10.1" clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } diff --git a/book/changelog.md b/book/changelog.md index ae9c8a6a..23ae40d6 100644 --- a/book/changelog.md +++ b/book/changelog.md @@ -329,3 +329,5 @@ This will be the first release of Grapevine since it was forked from Conduit ([!158](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/158)) 27. Grapevine now sends a User-Agent header on outbound requests ([!189](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/189)) +28. Added the ability to listen on Unix sockets + ([!187](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/187)) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index d08807b6..e1bc79c7 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -1,12 +1,13 @@ use std::{ - collections::HashSet, convert::Infallible, future::Future, net::SocketAddr, - sync::atomic, time::Duration, + collections::HashSet, convert::Infallible, future::Future, + net::SocketAddr as IpSocketAddr, + os::unix::net::SocketAddr as UnixSocketAddr, sync::atomic, time::Duration, }; use axum::{ extract::{ - connect_info::IntoMakeServiceWithConnectInfo, ConnectInfo, - DefaultBodyLimit, FromRequestParts, MatchedPath, + connect_info::{Connected, IntoMakeServiceWithConnectInfo}, + ConnectInfo, DefaultBodyLimit, FromRequestParts, MatchedPath, }, middleware::AddExtension, response::IntoResponse, @@ -18,7 +19,7 @@ use axum_server::{ bind, service::SendService, tls_rustls::{RustlsAcceptor, RustlsConfig}, - Handle as ServerHandle, Server, + Address, Server, }; use http::{ header::{self, HeaderName}, @@ -34,9 +35,9 @@ use ruma::api::{ federation::discovery::get_server_version, IncomingRequest, }; +use strum::Display; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpStream, signal, task::JoinSet, }; @@ -84,6 +85,9 @@ pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> { .map_err(Error::DatabaseError)?, )); + // This struct will remove old Unix sockets once it's dropped. + let _clean_up_socks = CleanUpUnixSockets(config.listen.clone()); + Services::new(db, config, Some(reload_handles)) .map_err(Error::InitializeServices)? .install(); @@ -104,6 +108,29 @@ pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> { Ok(()) } +struct CleanUpUnixSockets(Vec); + +impl Drop for CleanUpUnixSockets { + fn drop(&mut self) { + // Remove old Unix sockets + for listen in &self.0 { + if let ListenTransport::Unix { + path, + .. + } = &listen.transport + { + info!( + path = path.display().to_string(), + "Removing Unix socket" + ); + if let Err(error) = std::fs::remove_file(path) { + warn!(%error, "Couldn't remove Unix socket"); + } + } + } + } +} + #[tracing::instrument] async fn federation_self_test() -> Result<()> { let response = server_server::send_request( @@ -128,6 +155,48 @@ async fn federation_self_test() -> Result<()> { Ok(()) } +// A trait we'll implement on `axum_server::Handle` in order to be able to +// shutdown handles regardless of their generics. +trait ServerHandle: Send { + fn shutdown(&self, timeout: Option); +} + +impl ServerHandle for axum_server::Handle { + fn shutdown(&self, timeout: Option) { + self.graceful_shutdown(timeout); + } +} + +/// This type is needed to allow us to find out where incoming connections came +/// from. Before Unix socket support, we could simply use `IpSocketAddr` here, +/// but this is no longer possible. +#[derive(Clone, Display)] +enum AddrConnectInfo { + #[strum(to_string = "{0}")] + Ip(IpSocketAddr), + + #[strum(to_string = "[unix socket]")] + UnixSocket, + + #[strum(to_string = "[unknown]")] + Unknown, +} + +impl Connected for AddrConnectInfo { + fn connect_info(target: IpSocketAddr) -> Self { + Self::Ip(target) + } +} + +impl Connected for AddrConnectInfo { + fn connect_info(_target: UnixSocketAddr) -> Self { + // The `UnixSocketAddr` we get here is one that we can't recover the + // path from (`as_pathname` returns `None`), so there's no point + // in saving it (we only use all this for logging). + Self::UnixSocket + } +} + struct ServerSpawner<'cfg, M> { config: &'cfg Config, middlewares: M, @@ -135,7 +204,7 @@ struct ServerSpawner<'cfg, M> { tls_config: Option, proxy_config: ProxyAcceptorConfig, servers: JoinSet<(ListenConfig, std::io::Result<()>)>, - handles: Vec, + handles: Vec>, } impl<'cfg, M> ServerSpawner<'cfg, M> @@ -202,14 +271,20 @@ where |inner| ProxyAcceptor::new(inner, config) } - fn spawn_server_inner( + fn spawn_server_inner( &mut self, listen: ListenConfig, - server: Server, - app: IntoMakeServiceWithConnectInfo, + server: Server, + app: IntoMakeServiceWithConnectInfo, ) where - A: Accept>> - + Clone + AddrConnectInfo: Connected, + Addr: Address + Send + 'static, + Addr::Stream: Send, + Addr::Listener: Send, + A: Accept< + Addr::Stream, + AddExtension>, + > + Clone + Send + Sync + 'static, @@ -217,14 +292,14 @@ where A::Service: SendService> + Send, A::Future: Send, { - let handle = ServerHandle::new(); + let handle = axum_server::Handle::new(); let server = server.handle(handle.clone()).serve(app); self.servers.spawn(async move { let result = server.await; (listen, result) }); - self.handles.push(handle); + self.handles.push(Box::new(handle)); } fn spawn_server( @@ -233,16 +308,16 @@ where ) -> Result<(), error::Serve> { let app = routes(self.config, &listen.components) .layer(self.middlewares.clone()) - .into_make_service_with_connect_info::(); + .into_make_service_with_connect_info::(); - match listen.transport { + match &listen.transport { ListenTransport::Tcp { address, port, tls, proxy_protocol, } => { - let addr = SocketAddr::from((address, port)); + let addr = IpSocketAddr::from((*address, *port)); let server = bind(addr); match (tls, proxy_protocol) { @@ -266,6 +341,30 @@ where } } + Ok(()) + } + ListenTransport::Unix { + path, + proxy_protocol, + } => { + let addr = match UnixSocketAddr::from_pathname(path) { + Ok(addr) => addr, + Err(e) => { + // We can't use `map_err` here, as that would move + // `listen` into a closure, preventing us from using it + // later. + return Err(error::Serve::Listen(e, listen)); + } + }; + let server = bind(addr); + + if *proxy_protocol { + let server = server.map(self.proxy_acceptor_factory()); + self.spawn_server_inner(listen, server, app); + } else { + self.spawn_server_inner(listen, server, app); + } + Ok(()) } } @@ -303,21 +402,23 @@ async fn run_server() -> Result<(), error::Serve> { || { request .extensions() - .get::>() - .map(|&ConnectInfo(addr)| addr) + .get::>() + .map(|ConnectInfo(addr)| addr.clone()) }, - |h| h.proxied_address().map(|addr| addr.source), - ); + |h| { + h.proxied_address().map(|addr| { + AddrConnectInfo::Ip(addr.source) + }) + }, + ) + .unwrap_or(AddrConnectInfo::Unknown); tracing::info_span!( "http_request", otel.name = format!("{method} {endpoint}"), %method, %endpoint, - source_address = source_address.map_or( - "[unknown]".to_owned(), - |a| a.to_string() - ), + %source_address, ) }) .on_request( @@ -720,7 +821,7 @@ async fn reload_tls_config( async fn handle_signals( tls_config: Option, - handles: Vec, + handles: Vec>, ) { #[cfg(unix)] async fn wait_signal(sig: signal::unix::SignalKind) { @@ -769,7 +870,7 @@ async fn handle_signals( services().globals.shutdown(); for handle in handles { - handle.graceful_shutdown(Some(Duration::from_secs(30))); + handle.shutdown(Some(Duration::from_secs(30))); } set_application_state(ApplicationState::Stopping); diff --git a/src/config.rs b/src/config.rs index 9ad6e1ec..b597e4a9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -196,11 +196,16 @@ pub(crate) enum ListenTransport { #[serde(default = "false_fn")] proxy_protocol: bool, }, + Unix { + path: PathBuf, + #[serde(default = "false_fn")] + proxy_protocol: bool, + }, } impl Display for ListenTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { + match self { ListenTransport::Tcp { address, port, @@ -209,12 +214,12 @@ impl Display for ListenTransport { } => { let scheme = format!( "{}{}", - if proxy_protocol { + if *proxy_protocol { "proxy+" } else { "" }, - if tls { + if *tls { "https" } else { "http" @@ -222,6 +227,21 @@ impl Display for ListenTransport { ); write!(f, "{scheme}://{address}:{port}") } + ListenTransport::Unix { + path, + proxy_protocol, + } => { + write!( + f, + "{}http+unix://{}", + if *proxy_protocol { + "proxy+" + } else { + "" + }, + path.display() + ) + } } } } diff --git a/tests/integrations/check_config.rs b/tests/integrations/check_config.rs index 7db366e8..119acd03 100644 --- a/tests/integrations/check_config.rs +++ b/tests/integrations/check_config.rs @@ -111,3 +111,9 @@ make_snapshot_test!( "A config with the database path inside the media path fails", "database-in-media.toml", ); + +make_snapshot_test!( + unix_socket, + "A config listening to a Unix socket is valid", + "unix-socket.toml", +); diff --git a/tests/integrations/fixtures/check_config/unix-socket.toml b/tests/integrations/fixtures/check_config/unix-socket.toml new file mode 100644 index 00000000..fe088d5d --- /dev/null +++ b/tests/integrations/fixtures/check_config/unix-socket.toml @@ -0,0 +1,14 @@ +server_name = "example.com" +listen = [{ type = "unix", path = "/tmp/grapevine.sock" }] + +[server_discovery] +client.base_url = "https://matrix.example.com" + +[database] +backend = "rocksdb" +path = "tests/integrations/fixtures/check_config/dirs/a" + +[media.backend] +type = "filesystem" +path = "tests/integrations/fixtures/check_config/dirs/b" + diff --git a/tests/integrations/snapshots/integrations__check_config__unix_socket@status_code.snap b/tests/integrations/snapshots/integrations__check_config__unix_socket@status_code.snap new file mode 100644 index 00000000..2866558c --- /dev/null +++ b/tests/integrations/snapshots/integrations__check_config__unix_socket@status_code.snap @@ -0,0 +1,7 @@ +--- +source: tests/integrations/check_config.rs +description: A config listening to a Unix socket is valid +--- +Some( + 0, +) diff --git a/tests/integrations/snapshots/integrations__check_config__unix_socket@stderr.snap b/tests/integrations/snapshots/integrations__check_config__unix_socket@stderr.snap new file mode 100644 index 00000000..93e08914 --- /dev/null +++ b/tests/integrations/snapshots/integrations__check_config__unix_socket@stderr.snap @@ -0,0 +1,14 @@ +--- +source: tests/integrations/check_config.rs +description: A config listening to a Unix socket is valid +--- +[ + { + "fields": { + "message": "Configuration looks good" + }, + "level": "INFO", + "target": "grapevine::cli::check_config", + "timestamp": "[timestamp]" + } +] diff --git a/tests/integrations/snapshots/integrations__check_config__unix_socket@stdout.snap b/tests/integrations/snapshots/integrations__check_config__unix_socket@stdout.snap new file mode 100644 index 00000000..0b9549b9 --- /dev/null +++ b/tests/integrations/snapshots/integrations__check_config__unix_socket@stdout.snap @@ -0,0 +1,5 @@ +--- +source: tests/integrations/check_config.rs +description: A config listening to a Unix socket is valid +--- +