Add support for HAProxy proxy protocol for listeners

This commit is contained in:
Lambda 2024-09-15 15:55:47 +00:00
parent 99f3e2aecd
commit 3247c64cd8
7 changed files with 234 additions and 14 deletions

12
Cargo.lock generated
View file

@ -827,7 +827,9 @@ dependencies = [
"opentelemetry_sdk",
"parking_lot",
"phf",
"pin-project-lite",
"prometheus",
"proxy-header",
"rand",
"regex",
"reqwest",
@ -1856,6 +1858,16 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "proxy-header"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc1493f63ddddfba840c3169e997c2905d09538ace72d64e84af6324c6e0e065"
dependencies = [
"pin-project-lite",
"tokio",
]
[[package]]
name = "quick-error"
version = "1.2.3"

View file

@ -113,7 +113,9 @@ opentelemetry-prometheus = "0.17.0"
opentelemetry_sdk = { version = "0.24.0", features = ["rt-tokio"] }
parking_lot = { version = "0.12.3", optional = true }
phf = { version = "0.11.2", features = ["macros"] }
pin-project-lite = "0.2.14"
prometheus = "0.13.4"
proxy-header = { version = "0.1.2", features = ["tokio"] }
rand = "0.8.5"
regex = "1.10.6"
reqwest = { version = "0.12.7", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }

View file

@ -269,3 +269,6 @@ This will be the first release of Grapevine since it was forked from Conduit
([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97))
22. Added a federation self-test, perfomed automatically on startup.
([!106](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/106))
23. Added support for HAProxy [proxy protocol](http://www.haproxy.org/download/3.0/doc/proxy-protocol.txt)
listeners.
([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97))

View file

@ -57,6 +57,7 @@ use crate::{
utils::{
self,
error::{Error, Result},
proxy_protocol::{ProxyAcceptor, ProxyAcceptorConfig},
},
ApplicationState, Services,
};
@ -138,6 +139,7 @@ struct ServerSpawner<'cfg, M> {
middlewares: M,
tls_config: Option<RustlsConfig>,
proxy_config: ProxyAcceptorConfig,
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
handles: Vec<ServerHandle>,
}
@ -172,10 +174,13 @@ where
None
};
let proxy_config = ProxyAcceptorConfig::default();
Ok(Self {
config,
middlewares,
tls_config,
proxy_config,
servers: JoinSet::new(),
handles: Vec::new(),
})
@ -195,6 +200,14 @@ where
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
}
/// Returns a function that transforms a lower-layer acceptor into a Proxy
/// Protocol acceptor.
fn proxy_acceptor_factory<A>(&self) -> impl FnOnce(A) -> ProxyAcceptor<A> {
let config = self.proxy_config.clone();
|inner| ProxyAcceptor::new(inner, config)
}
fn spawn_server_inner<A>(
&mut self,
listen: ListenConfig,
@ -229,16 +242,30 @@ where
address,
port,
tls,
proxy_protocol,
} => {
let addr = SocketAddr::from((address, port));
let server = bind(addr);
if tls {
let server =
server.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
} else {
self.spawn_server_inner(listen, server, app);
match (tls, proxy_protocol) {
(false, false) => {
self.spawn_server_inner(listen, server, app);
}
(false, true) => {
let server = server.map(self.proxy_acceptor_factory());
self.spawn_server_inner(listen, server, app);
}
(true, false) => {
let server =
server.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
}
(true, true) => {
let server = server
.map(self.proxy_acceptor_factory())
.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
}
}
Ok(())

View file

@ -137,22 +137,35 @@ pub(crate) enum ListenTransport {
port: u16,
#[serde(default = "false_fn")]
tls: bool,
#[serde(default = "false_fn")]
proxy_protocol: bool,
},
}
impl Display for ListenTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
match *self {
ListenTransport::Tcp {
address,
port,
tls: false,
} => write!(f, "http://{address}:{port}"),
ListenTransport::Tcp {
address,
port,
tls: true,
} => write!(f, "https://{address}:{port}"),
tls,
proxy_protocol,
} => {
let scheme = format!(
"{}{}",
if proxy_protocol {
"proxy+"
} else {
""
},
if tls {
"https"
} else {
"http"
}
);
write!(f, "{scheme}://{address}:{port}")
}
}
}
}
@ -379,6 +392,7 @@ fn default_listen() -> Vec<ListenConfig> {
address: default_address(),
port: default_port(),
tls: false,
proxy_protocol: false,
},
}]
}

View file

@ -19,6 +19,7 @@ use crate::{Error, Result};
pub(crate) mod error;
pub(crate) mod on_demand_hashmap;
pub(crate) mod proxy_protocol;
pub(crate) mod room_version;
// Hopefully we have a better chat protocol in 530 years

161
src/utils/proxy_protocol.rs Normal file
View file

@ -0,0 +1,161 @@
use std::{
future::Future, io::ErrorKind, pin::Pin, task::Poll, time::Duration,
};
use axum::{middleware::AddExtension, Extension};
use axum_server::accept::Accept;
use pin_project_lite::pin_project;
use proxy_header::{io::ProxiedStream, ParseConfig, ProxyHeader};
use tokio::{
io::AsyncRead,
time::{timeout, Timeout},
};
use tower::Layer;
use tracing::warn;
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptorConfig {
pub(crate) header_timeout: Duration,
pub(crate) parse_config: ParseConfig,
}
impl Default for ProxyAcceptorConfig {
fn default() -> Self {
Self {
header_timeout: Duration::from_secs(5),
parse_config: ParseConfig::default(),
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptor<A> {
inner: A,
config: ProxyAcceptorConfig,
}
impl<A> ProxyAcceptor<A> {
pub(crate) fn new(inner: A, config: ProxyAcceptorConfig) -> Self {
Self {
inner,
config,
}
}
}
impl<A, I, S> Accept<I, S> for ProxyAcceptor<A>
where
A: Accept<I, S>,
A::Stream: AsyncRead + Unpin + Send + 'static,
{
type Future = AcceptorFuture<A::Future, A::Stream, A::Service>;
type Service = AddExtension<A::Service, ProxyHeader<'static>>;
type Stream = ProxiedStream<A::Stream>;
fn accept(&self, stream: I, service: S) -> Self::Future {
let inner_future = self.inner.accept(stream, service);
let config = self.config.clone();
AcceptorFuture::new(inner_future, config)
}
}
/// Future returned by [`ProxiedStream::create_from_tokio()`].
type StreamFuture<I> =
Pin<Box<dyn Future<Output = std::io::Result<ProxiedStream<I>>> + Send>>;
pin_project! {
#[project = AcceptorFutureProj]
pub(crate) enum AcceptorFuture<F, I, S> {
Inner {
#[pin]
inner_future: F,
config: ProxyAcceptorConfig,
},
Proxy {
#[pin]
stream_future: Timeout<StreamFuture<I>>,
service: Option<S>,
},
}
}
impl<F, I, S> AcceptorFuture<F, I, S> {
fn new(inner_future: F, config: ProxyAcceptorConfig) -> Self {
Self::Inner {
inner_future,
config,
}
}
}
impl<F, I, S> Future for AcceptorFuture<F, I, S>
where
F: Future<Output = std::io::Result<(I, S)>>,
I: AsyncRead + Unpin + Send + 'static,
{
type Output = std::io::Result<(
ProxiedStream<I>,
AddExtension<S, ProxyHeader<'static>>,
)>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
loop {
match self.as_mut().project() {
AcceptorFutureProj::Inner {
inner_future,
config,
} => {
let Poll::Ready((stream, service)) =
inner_future.poll(cx)?
else {
return Poll::Pending;
};
let stream_future: StreamFuture<I> =
Box::pin(ProxiedStream::create_from_tokio(
stream,
config.parse_config,
));
let stream_future =
timeout(config.header_timeout, stream_future);
self.set(AcceptorFuture::Proxy {
stream_future,
service: Some(service),
});
}
AcceptorFutureProj::Proxy {
stream_future,
service,
} => {
let Poll::Ready(ret) = stream_future.poll(cx) else {
return Poll::Pending;
};
let stream = ret
.map_err(|e| {
warn!(
"Timed out waiting for HAProxy protocol header"
);
std::io::Error::new(ErrorKind::TimedOut, e)
})?
.inspect_err(|error| {
warn!(%error, "Failed to parse HAProxy protocol header");
})?;
let service =
Extension(stream.proxy_header().clone().into_owned())
.layer(service.take().expect(
"future should not be polled after ready",
));
return Poll::Ready(Ok((stream, service)));
}
}
}
}
}