Add support for HAProxy proxy protocol for listeners

This commit is contained in:
Lambda 2024-09-15 15:55:47 +00:00
parent 99f3e2aecd
commit 3247c64cd8
7 changed files with 234 additions and 14 deletions

161
src/utils/proxy_protocol.rs Normal file
View file

@ -0,0 +1,161 @@
use std::{
future::Future, io::ErrorKind, pin::Pin, task::Poll, time::Duration,
};
use axum::{middleware::AddExtension, Extension};
use axum_server::accept::Accept;
use pin_project_lite::pin_project;
use proxy_header::{io::ProxiedStream, ParseConfig, ProxyHeader};
use tokio::{
io::AsyncRead,
time::{timeout, Timeout},
};
use tower::Layer;
use tracing::warn;
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptorConfig {
pub(crate) header_timeout: Duration,
pub(crate) parse_config: ParseConfig,
}
impl Default for ProxyAcceptorConfig {
fn default() -> Self {
Self {
header_timeout: Duration::from_secs(5),
parse_config: ParseConfig::default(),
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct ProxyAcceptor<A> {
inner: A,
config: ProxyAcceptorConfig,
}
impl<A> ProxyAcceptor<A> {
pub(crate) fn new(inner: A, config: ProxyAcceptorConfig) -> Self {
Self {
inner,
config,
}
}
}
impl<A, I, S> Accept<I, S> for ProxyAcceptor<A>
where
A: Accept<I, S>,
A::Stream: AsyncRead + Unpin + Send + 'static,
{
type Future = AcceptorFuture<A::Future, A::Stream, A::Service>;
type Service = AddExtension<A::Service, ProxyHeader<'static>>;
type Stream = ProxiedStream<A::Stream>;
fn accept(&self, stream: I, service: S) -> Self::Future {
let inner_future = self.inner.accept(stream, service);
let config = self.config.clone();
AcceptorFuture::new(inner_future, config)
}
}
/// Future returned by [`ProxiedStream::create_from_tokio()`].
type StreamFuture<I> =
Pin<Box<dyn Future<Output = std::io::Result<ProxiedStream<I>>> + Send>>;
pin_project! {
#[project = AcceptorFutureProj]
pub(crate) enum AcceptorFuture<F, I, S> {
Inner {
#[pin]
inner_future: F,
config: ProxyAcceptorConfig,
},
Proxy {
#[pin]
stream_future: Timeout<StreamFuture<I>>,
service: Option<S>,
},
}
}
impl<F, I, S> AcceptorFuture<F, I, S> {
fn new(inner_future: F, config: ProxyAcceptorConfig) -> Self {
Self::Inner {
inner_future,
config,
}
}
}
impl<F, I, S> Future for AcceptorFuture<F, I, S>
where
F: Future<Output = std::io::Result<(I, S)>>,
I: AsyncRead + Unpin + Send + 'static,
{
type Output = std::io::Result<(
ProxiedStream<I>,
AddExtension<S, ProxyHeader<'static>>,
)>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
loop {
match self.as_mut().project() {
AcceptorFutureProj::Inner {
inner_future,
config,
} => {
let Poll::Ready((stream, service)) =
inner_future.poll(cx)?
else {
return Poll::Pending;
};
let stream_future: StreamFuture<I> =
Box::pin(ProxiedStream::create_from_tokio(
stream,
config.parse_config,
));
let stream_future =
timeout(config.header_timeout, stream_future);
self.set(AcceptorFuture::Proxy {
stream_future,
service: Some(service),
});
}
AcceptorFutureProj::Proxy {
stream_future,
service,
} => {
let Poll::Ready(ret) = stream_future.poll(cx) else {
return Poll::Pending;
};
let stream = ret
.map_err(|e| {
warn!(
"Timed out waiting for HAProxy protocol header"
);
std::io::Error::new(ErrorKind::TimedOut, e)
})?
.inspect_err(|error| {
warn!(%error, "Failed to parse HAProxy protocol header");
})?;
let service =
Extension(stream.proxy_header().clone().into_owned())
.layer(service.take().expect(
"future should not be polled after ready",
));
return Poll::Ready(Ok((stream, service)));
}
}
}
}
}