From 3247c64cd89c349008febb0a619778646a64522e Mon Sep 17 00:00:00 2001 From: Lambda Date: Sun, 15 Sep 2024 15:55:47 +0000 Subject: [PATCH] Add support for HAProxy proxy protocol for listeners --- Cargo.lock | 12 +++ Cargo.toml | 2 + book/changelog.md | 3 + src/cli/serve.rs | 39 +++++++-- src/config.rs | 30 +++++-- src/utils.rs | 1 + src/utils/proxy_protocol.rs | 161 ++++++++++++++++++++++++++++++++++++ 7 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 src/utils/proxy_protocol.rs diff --git a/Cargo.lock b/Cargo.lock index caff42d9..65aa7f4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -827,7 +827,9 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "phf", + "pin-project-lite", "prometheus", + "proxy-header", "rand", "regex", "reqwest", @@ -1856,6 +1858,16 @@ version = "2.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" +[[package]] +name = "proxy-header" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1493f63ddddfba840c3169e997c2905d09538ace72d64e84af6324c6e0e065" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "quick-error" version = "1.2.3" diff --git a/Cargo.toml b/Cargo.toml index b406358b..eebcbd2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,7 +113,9 @@ opentelemetry-prometheus = "0.17.0" opentelemetry_sdk = { version = "0.24.0", features = ["rt-tokio"] } parking_lot = { version = "0.12.3", optional = true } phf = { version = "0.11.2", features = ["macros"] } +pin-project-lite = "0.2.14" prometheus = "0.13.4" +proxy-header = { version = "0.1.2", features = ["tokio"] } rand = "0.8.5" regex = "1.10.6" reqwest = { version = "0.12.7", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } diff --git a/book/changelog.md b/book/changelog.md index cb2e70af..a8951562 100644 --- a/book/changelog.md +++ b/book/changelog.md @@ -269,3 +269,6 @@ This will be the first release of Grapevine since it was forked from Conduit ([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97)) 22. Added a federation self-test, perfomed automatically on startup. ([!106](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/106)) +23. Added support for HAProxy [proxy protocol](http://www.haproxy.org/download/3.0/doc/proxy-protocol.txt) + listeners. + ([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97)) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index e0612576..aa507522 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -57,6 +57,7 @@ use crate::{ utils::{ self, error::{Error, Result}, + proxy_protocol::{ProxyAcceptor, ProxyAcceptorConfig}, }, ApplicationState, Services, }; @@ -138,6 +139,7 @@ struct ServerSpawner<'cfg, M> { middlewares: M, tls_config: Option, + proxy_config: ProxyAcceptorConfig, servers: JoinSet<(ListenConfig, std::io::Result<()>)>, handles: Vec, } @@ -172,10 +174,13 @@ where None }; + let proxy_config = ProxyAcceptorConfig::default(); + Ok(Self { config, middlewares, tls_config, + proxy_config, servers: JoinSet::new(), handles: Vec::new(), }) @@ -195,6 +200,14 @@ where Ok(|inner| RustlsAcceptor::new(config).acceptor(inner)) } + /// Returns a function that transforms a lower-layer acceptor into a Proxy + /// Protocol acceptor. + fn proxy_acceptor_factory(&self) -> impl FnOnce(A) -> ProxyAcceptor { + let config = self.proxy_config.clone(); + + |inner| ProxyAcceptor::new(inner, config) + } + fn spawn_server_inner( &mut self, listen: ListenConfig, @@ -229,16 +242,30 @@ where address, port, tls, + proxy_protocol, } => { let addr = SocketAddr::from((address, port)); let server = bind(addr); - if tls { - let server = - server.map(self.tls_acceptor_factory(&listen)?); - self.spawn_server_inner(listen, server, app); - } else { - self.spawn_server_inner(listen, server, app); + match (tls, proxy_protocol) { + (false, false) => { + self.spawn_server_inner(listen, server, app); + } + (false, true) => { + let server = server.map(self.proxy_acceptor_factory()); + self.spawn_server_inner(listen, server, app); + } + (true, false) => { + let server = + server.map(self.tls_acceptor_factory(&listen)?); + self.spawn_server_inner(listen, server, app); + } + (true, true) => { + let server = server + .map(self.proxy_acceptor_factory()) + .map(self.tls_acceptor_factory(&listen)?); + self.spawn_server_inner(listen, server, app); + } } Ok(()) diff --git a/src/config.rs b/src/config.rs index bf08c7c1..4c3d5a44 100644 --- a/src/config.rs +++ b/src/config.rs @@ -137,22 +137,35 @@ pub(crate) enum ListenTransport { port: u16, #[serde(default = "false_fn")] tls: bool, + #[serde(default = "false_fn")] + proxy_protocol: bool, }, } impl Display for ListenTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { + match *self { ListenTransport::Tcp { address, port, - tls: false, - } => write!(f, "http://{address}:{port}"), - ListenTransport::Tcp { - address, - port, - tls: true, - } => write!(f, "https://{address}:{port}"), + tls, + proxy_protocol, + } => { + let scheme = format!( + "{}{}", + if proxy_protocol { + "proxy+" + } else { + "" + }, + if tls { + "https" + } else { + "http" + } + ); + write!(f, "{scheme}://{address}:{port}") + } } } } @@ -379,6 +392,7 @@ fn default_listen() -> Vec { address: default_address(), port: default_port(), tls: false, + proxy_protocol: false, }, }] } diff --git a/src/utils.rs b/src/utils.rs index 1eca4b4b..45211399 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -19,6 +19,7 @@ use crate::{Error, Result}; pub(crate) mod error; pub(crate) mod on_demand_hashmap; +pub(crate) mod proxy_protocol; pub(crate) mod room_version; // Hopefully we have a better chat protocol in 530 years diff --git a/src/utils/proxy_protocol.rs b/src/utils/proxy_protocol.rs new file mode 100644 index 00000000..53021755 --- /dev/null +++ b/src/utils/proxy_protocol.rs @@ -0,0 +1,161 @@ +use std::{ + future::Future, io::ErrorKind, pin::Pin, task::Poll, time::Duration, +}; + +use axum::{middleware::AddExtension, Extension}; +use axum_server::accept::Accept; +use pin_project_lite::pin_project; +use proxy_header::{io::ProxiedStream, ParseConfig, ProxyHeader}; +use tokio::{ + io::AsyncRead, + time::{timeout, Timeout}, +}; +use tower::Layer; +use tracing::warn; + +#[derive(Debug, Clone)] +pub(crate) struct ProxyAcceptorConfig { + pub(crate) header_timeout: Duration, + pub(crate) parse_config: ParseConfig, +} + +impl Default for ProxyAcceptorConfig { + fn default() -> Self { + Self { + header_timeout: Duration::from_secs(5), + parse_config: ParseConfig::default(), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ProxyAcceptor { + inner: A, + config: ProxyAcceptorConfig, +} + +impl ProxyAcceptor { + pub(crate) fn new(inner: A, config: ProxyAcceptorConfig) -> Self { + Self { + inner, + config, + } + } +} + +impl Accept for ProxyAcceptor +where + A: Accept, + A::Stream: AsyncRead + Unpin + Send + 'static, +{ + type Future = AcceptorFuture; + type Service = AddExtension>; + type Stream = ProxiedStream; + + fn accept(&self, stream: I, service: S) -> Self::Future { + let inner_future = self.inner.accept(stream, service); + let config = self.config.clone(); + + AcceptorFuture::new(inner_future, config) + } +} + +/// Future returned by [`ProxiedStream::create_from_tokio()`]. +type StreamFuture = + Pin>> + Send>>; + +pin_project! { + #[project = AcceptorFutureProj] + pub(crate) enum AcceptorFuture { + Inner { + #[pin] + inner_future: F, + config: ProxyAcceptorConfig, + }, + Proxy { + #[pin] + stream_future: Timeout>, + service: Option, + }, + } +} + +impl AcceptorFuture { + fn new(inner_future: F, config: ProxyAcceptorConfig) -> Self { + Self::Inner { + inner_future, + config, + } + } +} + +impl Future for AcceptorFuture +where + F: Future>, + I: AsyncRead + Unpin + Send + 'static, +{ + type Output = std::io::Result<( + ProxiedStream, + AddExtension>, + )>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + loop { + match self.as_mut().project() { + AcceptorFutureProj::Inner { + inner_future, + config, + } => { + let Poll::Ready((stream, service)) = + inner_future.poll(cx)? + else { + return Poll::Pending; + }; + + let stream_future: StreamFuture = + Box::pin(ProxiedStream::create_from_tokio( + stream, + config.parse_config, + )); + let stream_future = + timeout(config.header_timeout, stream_future); + + self.set(AcceptorFuture::Proxy { + stream_future, + service: Some(service), + }); + } + AcceptorFutureProj::Proxy { + stream_future, + service, + } => { + let Poll::Ready(ret) = stream_future.poll(cx) else { + return Poll::Pending; + }; + + let stream = ret + .map_err(|e| { + warn!( + "Timed out waiting for HAProxy protocol header" + ); + std::io::Error::new(ErrorKind::TimedOut, e) + })? + .inspect_err(|error| { + warn!(%error, "Failed to parse HAProxy protocol header"); + })?; + + let service = + Extension(stream.proxy_header().clone().into_owned()) + .layer(service.take().expect( + "future should not be polled after ready", + )); + + return Poll::Ready(Ok((stream, service))); + } + } + } + } +}