Add support for HAProxy proxy protocol for listeners

This commit is contained in:
Lambda 2024-09-15 15:55:47 +00:00
parent 99f3e2aecd
commit 3247c64cd8
7 changed files with 234 additions and 14 deletions

12
Cargo.lock generated
View file

@ -827,7 +827,9 @@ dependencies = [
"opentelemetry_sdk", "opentelemetry_sdk",
"parking_lot", "parking_lot",
"phf", "phf",
"pin-project-lite",
"prometheus", "prometheus",
"proxy-header",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -1856,6 +1858,16 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "proxy-header"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc1493f63ddddfba840c3169e997c2905d09538ace72d64e84af6324c6e0e065"
dependencies = [
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "1.2.3" version = "1.2.3"

View file

@ -113,7 +113,9 @@ opentelemetry-prometheus = "0.17.0"
opentelemetry_sdk = { version = "0.24.0", features = ["rt-tokio"] } opentelemetry_sdk = { version = "0.24.0", features = ["rt-tokio"] }
parking_lot = { version = "0.12.3", optional = true } parking_lot = { version = "0.12.3", optional = true }
phf = { version = "0.11.2", features = ["macros"] } phf = { version = "0.11.2", features = ["macros"] }
pin-project-lite = "0.2.14"
prometheus = "0.13.4" prometheus = "0.13.4"
proxy-header = { version = "0.1.2", features = ["tokio"] }
rand = "0.8.5" rand = "0.8.5"
regex = "1.10.6" regex = "1.10.6"
reqwest = { version = "0.12.7", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } reqwest = { version = "0.12.7", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }

View file

@ -269,3 +269,6 @@ This will be the first release of Grapevine since it was forked from Conduit
([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97)) ([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97))
22. Added a federation self-test, perfomed automatically on startup. 22. Added a federation self-test, perfomed automatically on startup.
([!106](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/106)) ([!106](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/106))
23. Added support for HAProxy [proxy protocol](http://www.haproxy.org/download/3.0/doc/proxy-protocol.txt)
listeners.
([!97](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/97))

View file

@ -57,6 +57,7 @@ use crate::{
utils::{ utils::{
self, self,
error::{Error, Result}, error::{Error, Result},
proxy_protocol::{ProxyAcceptor, ProxyAcceptorConfig},
}, },
ApplicationState, Services, ApplicationState, Services,
}; };
@ -138,6 +139,7 @@ struct ServerSpawner<'cfg, M> {
middlewares: M, middlewares: M,
tls_config: Option<RustlsConfig>, tls_config: Option<RustlsConfig>,
proxy_config: ProxyAcceptorConfig,
servers: JoinSet<(ListenConfig, std::io::Result<()>)>, servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
handles: Vec<ServerHandle>, handles: Vec<ServerHandle>,
} }
@ -172,10 +174,13 @@ where
None None
}; };
let proxy_config = ProxyAcceptorConfig::default();
Ok(Self { Ok(Self {
config, config,
middlewares, middlewares,
tls_config, tls_config,
proxy_config,
servers: JoinSet::new(), servers: JoinSet::new(),
handles: Vec::new(), handles: Vec::new(),
}) })
@ -195,6 +200,14 @@ where
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner)) Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
} }
/// Returns a function that transforms a lower-layer acceptor into a Proxy
/// Protocol acceptor.
fn proxy_acceptor_factory<A>(&self) -> impl FnOnce(A) -> ProxyAcceptor<A> {
let config = self.proxy_config.clone();
|inner| ProxyAcceptor::new(inner, config)
}
fn spawn_server_inner<A>( fn spawn_server_inner<A>(
&mut self, &mut self,
listen: ListenConfig, listen: ListenConfig,
@ -229,16 +242,30 @@ where
address, address,
port, port,
tls, tls,
proxy_protocol,
} => { } => {
let addr = SocketAddr::from((address, port)); let addr = SocketAddr::from((address, port));
let server = bind(addr); let server = bind(addr);
if tls { match (tls, proxy_protocol) {
let server = (false, false) => {
server.map(self.tls_acceptor_factory(&listen)?); self.spawn_server_inner(listen, server, app);
self.spawn_server_inner(listen, server, app); }
} else { (false, true) => {
self.spawn_server_inner(listen, server, app); let server = server.map(self.proxy_acceptor_factory());
self.spawn_server_inner(listen, server, app);
}
(true, false) => {
let server =
server.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
}
(true, true) => {
let server = server
.map(self.proxy_acceptor_factory())
.map(self.tls_acceptor_factory(&listen)?);
self.spawn_server_inner(listen, server, app);
}
} }
Ok(()) Ok(())

View file

@ -137,22 +137,35 @@ pub(crate) enum ListenTransport {
port: u16, port: u16,
#[serde(default = "false_fn")] #[serde(default = "false_fn")]
tls: bool, tls: bool,
#[serde(default = "false_fn")]
proxy_protocol: bool,
}, },
} }
impl Display for ListenTransport { impl Display for ListenTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match *self {
ListenTransport::Tcp { ListenTransport::Tcp {
address, address,
port, port,
tls: false, tls,
} => write!(f, "http://{address}:{port}"), proxy_protocol,
ListenTransport::Tcp { } => {
address, let scheme = format!(
port, "{}{}",
tls: true, if proxy_protocol {
} => write!(f, "https://{address}:{port}"), "proxy+"
} else {
""
},
if tls {
"https"
} else {
"http"
}
);
write!(f, "{scheme}://{address}:{port}")
}
} }
} }
} }
@ -379,6 +392,7 @@ fn default_listen() -> Vec<ListenConfig> {
address: default_address(), address: default_address(),
port: default_port(), port: default_port(),
tls: false, tls: false,
proxy_protocol: false,
}, },
}] }]
} }

View file

@ -19,6 +19,7 @@ use crate::{Error, Result};
pub(crate) mod error; pub(crate) mod error;
pub(crate) mod on_demand_hashmap; pub(crate) mod on_demand_hashmap;
pub(crate) mod proxy_protocol;
pub(crate) mod room_version; pub(crate) mod room_version;
// Hopefully we have a better chat protocol in 530 years // Hopefully we have a better chat protocol in 530 years

161
src/utils/proxy_protocol.rs Normal file
View file

@ -0,0 +1,161 @@
use std::{
future::Future, io::ErrorKind, pin::Pin, task::Poll, time::Duration,
};
use axum::{middleware::AddExtension, Extension};
use axum_server::accept::Accept;
use pin_project_lite::pin_project;
use proxy_header::{io::ProxiedStream, ParseConfig, ProxyHeader};
use tokio::{
io::AsyncRead,
time::{timeout, Timeout},
};
use tower::Layer;
use tracing::warn;
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptorConfig {
pub(crate) header_timeout: Duration,
pub(crate) parse_config: ParseConfig,
}
impl Default for ProxyAcceptorConfig {
fn default() -> Self {
Self {
header_timeout: Duration::from_secs(5),
parse_config: ParseConfig::default(),
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptor<A> {
inner: A,
config: ProxyAcceptorConfig,
}
impl<A> ProxyAcceptor<A> {
pub(crate) fn new(inner: A, config: ProxyAcceptorConfig) -> Self {
Self {
inner,
config,
}
}
}
impl<A, I, S> Accept<I, S> for ProxyAcceptor<A>
where
A: Accept<I, S>,
A::Stream: AsyncRead + Unpin + Send + 'static,
{
type Future = AcceptorFuture<A::Future, A::Stream, A::Service>;
type Service = AddExtension<A::Service, ProxyHeader<'static>>;
type Stream = ProxiedStream<A::Stream>;
fn accept(&self, stream: I, service: S) -> Self::Future {
let inner_future = self.inner.accept(stream, service);
let config = self.config.clone();
AcceptorFuture::new(inner_future, config)
}
}
/// Future returned by [`ProxiedStream::create_from_tokio()`].
type StreamFuture<I> =
Pin<Box<dyn Future<Output = std::io::Result<ProxiedStream<I>>> + Send>>;
pin_project! {
#[project = AcceptorFutureProj]
pub(crate) enum AcceptorFuture<F, I, S> {
Inner {
#[pin]
inner_future: F,
config: ProxyAcceptorConfig,
},
Proxy {
#[pin]
stream_future: Timeout<StreamFuture<I>>,
service: Option<S>,
},
}
}
impl<F, I, S> AcceptorFuture<F, I, S> {
fn new(inner_future: F, config: ProxyAcceptorConfig) -> Self {
Self::Inner {
inner_future,
config,
}
}
}
impl<F, I, S> Future for AcceptorFuture<F, I, S>
where
F: Future<Output = std::io::Result<(I, S)>>,
I: AsyncRead + Unpin + Send + 'static,
{
type Output = std::io::Result<(
ProxiedStream<I>,
AddExtension<S, ProxyHeader<'static>>,
)>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
loop {
match self.as_mut().project() {
AcceptorFutureProj::Inner {
inner_future,
config,
} => {
let Poll::Ready((stream, service)) =
inner_future.poll(cx)?
else {
return Poll::Pending;
};
let stream_future: StreamFuture<I> =
Box::pin(ProxiedStream::create_from_tokio(
stream,
config.parse_config,
));
let stream_future =
timeout(config.header_timeout, stream_future);
self.set(AcceptorFuture::Proxy {
stream_future,
service: Some(service),
});
}
AcceptorFutureProj::Proxy {
stream_future,
service,
} => {
let Poll::Ready(ret) = stream_future.poll(cx) else {
return Poll::Pending;
};
let stream = ret
.map_err(|e| {
warn!(
"Timed out waiting for HAProxy protocol header"
);
std::io::Error::new(ErrorKind::TimedOut, e)
})?
.inspect_err(|error| {
warn!(%error, "Failed to parse HAProxy protocol header");
})?;
let service =
Extension(stream.proxy_header().clone().into_owned())
.layer(service.take().expect(
"future should not be polled after ready",
));
return Poll::Ready(Ok((stream, service)));
}
}
}
}
}