Allow configuring served components per listener

This commit is contained in:
Lambda 2024-09-21 15:55:17 +00:00 committed by Charles Hall
parent d62d0e2f0e
commit 084d862e51
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
3 changed files with 89 additions and 27 deletions

View file

@ -1,4 +1,7 @@
use std::{future::Future, net::SocketAddr, sync::atomic, time::Duration};
use std::{
collections::HashSet, future::Future, net::SocketAddr, sync::atomic,
time::Duration,
};
use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
@ -38,7 +41,7 @@ use crate::{
server_server, well_known,
},
config,
config::{Config, ListenConfig},
config::{Config, ListenComponent, ListenTransport},
database::KeyValueDatabase,
error, observability, services, utils,
utils::error::{Error, Result},
@ -148,7 +151,6 @@ async fn run_server() -> Result<(), error::Serve> {
))
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
let app = routes(config).layer(middlewares).into_make_service();
let mut handles = Vec::new();
let mut servers = JoinSet::new();
@ -170,25 +172,29 @@ async fn run_server() -> Result<(), error::Serve> {
for listen in &config.listen {
info!(listener = %listen, "Listening for incoming traffic");
match listen {
ListenConfig::Tcp {
let app = routes(config, &listen.components)
.layer(middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((*address, *port));
let addr = SocketAddr::from((address, port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if *tls {
let server = if tls {
let tls_config = tls_config
.clone()
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.clone())
.serve(app)
.left_future()
} else {
bind(addr).handle(handle).serve(app.clone()).right_future()
bind(addr).handle(handle).serve(app).right_future()
};
servers.spawn(
server.then(|result| async { (listen.clone(), result) }),
@ -493,15 +499,24 @@ fn well_known_routes() -> Router {
.route("/.well-known/matrix/server", get(well_known::server))
}
fn routes(config: &Config) -> Router {
Router::new()
.merge(client_routes())
.merge(federation_routes(config))
.merge(legacy_media_routes(config))
.merge(well_known_routes())
.merge(metrics_routes(config))
.route("/", get(it_works))
.fallback(not_found)
fn routes(config: &Config, components: &HashSet<ListenComponent>) -> Router {
let mut router = Router::new();
for &component in components {
router = router.merge(match component {
ListenComponent::Client => client_routes(),
ListenComponent::Federation => federation_routes(config),
ListenComponent::Metrics => metrics_routes(config),
ListenComponent::WellKnown => well_known_routes(),
});
}
if components.contains(&ListenComponent::Client)
|| components.contains(&ListenComponent::Federation)
{
router = router.merge(legacy_media_routes(config));
}
router.route("/", get(it_works)).fallback(not_found)
}
async fn shutdown_signal(handles: Vec<ServerHandle>) {