Allow configuring served components per listener

This commit is contained in:
Lambda 2024-09-21 15:55:17 +00:00 committed by Charles Hall
parent d62d0e2f0e
commit 084d862e51
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
3 changed files with 89 additions and 27 deletions

View file

@ -246,3 +246,5 @@ This will be the first release of Grapevine since it was forked from Conduit
18. Added admin commands to delete media 18. Added admin commands to delete media
([!99](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/99), ([!99](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/99),
[!102](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/102)) [!102](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/102))
19. Allow configuring the served API components per listener.
([!109](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/109))

View file

@ -1,4 +1,7 @@
use std::{future::Future, net::SocketAddr, sync::atomic, time::Duration}; use std::{
collections::HashSet, future::Future, net::SocketAddr, sync::atomic,
time::Duration,
};
use axum::{ use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
@ -38,7 +41,7 @@ use crate::{
server_server, well_known, server_server, well_known,
}, },
config, config,
config::{Config, ListenConfig}, config::{Config, ListenComponent, ListenTransport},
database::KeyValueDatabase, database::KeyValueDatabase,
error, observability, services, utils, error, observability, services, utils,
utils::error::{Error, Result}, utils::error::{Error, Result},
@ -148,7 +151,6 @@ async fn run_server() -> Result<(), error::Serve> {
)) ))
.layer(axum::middleware::from_fn(observability::http_metrics_layer)); .layer(axum::middleware::from_fn(observability::http_metrics_layer));
let app = routes(config).layer(middlewares).into_make_service();
let mut handles = Vec::new(); let mut handles = Vec::new();
let mut servers = JoinSet::new(); let mut servers = JoinSet::new();
@ -170,25 +172,29 @@ async fn run_server() -> Result<(), error::Serve> {
for listen in &config.listen { for listen in &config.listen {
info!(listener = %listen, "Listening for incoming traffic"); info!(listener = %listen, "Listening for incoming traffic");
match listen { let app = routes(config, &listen.components)
ListenConfig::Tcp { .layer(middlewares.clone())
.into_make_service();
match listen.transport {
ListenTransport::Tcp {
address, address,
port, port,
tls, tls,
} => { } => {
let addr = SocketAddr::from((*address, *port)); let addr = SocketAddr::from((address, port));
let handle = ServerHandle::new(); let handle = ServerHandle::new();
handles.push(handle.clone()); handles.push(handle.clone());
let server = if *tls { let server = if tls {
let tls_config = tls_config let tls_config = tls_config
.clone() .clone()
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?; .ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
bind_rustls(addr, tls_config) bind_rustls(addr, tls_config)
.handle(handle) .handle(handle)
.serve(app.clone()) .serve(app)
.left_future() .left_future()
} else { } else {
bind(addr).handle(handle).serve(app.clone()).right_future() bind(addr).handle(handle).serve(app).right_future()
}; };
servers.spawn( servers.spawn(
server.then(|result| async { (listen.clone(), result) }), server.then(|result| async { (listen.clone(), result) }),
@ -493,15 +499,24 @@ fn well_known_routes() -> Router {
.route("/.well-known/matrix/server", get(well_known::server)) .route("/.well-known/matrix/server", get(well_known::server))
} }
fn routes(config: &Config) -> Router { fn routes(config: &Config, components: &HashSet<ListenComponent>) -> Router {
Router::new() let mut router = Router::new();
.merge(client_routes()) for &component in components {
.merge(federation_routes(config)) router = router.merge(match component {
.merge(legacy_media_routes(config)) ListenComponent::Client => client_routes(),
.merge(well_known_routes()) ListenComponent::Federation => federation_routes(config),
.merge(metrics_routes(config)) ListenComponent::Metrics => metrics_routes(config),
.route("/", get(it_works)) ListenComponent::WellKnown => well_known_routes(),
.fallback(not_found) });
}
if components.contains(&ListenComponent::Client)
|| components.contains(&ListenComponent::Federation)
{
router = router.merge(legacy_media_routes(config));
}
router.route("/", get(it_works)).fallback(not_found)
} }
async fn shutdown_signal(handles: Vec<ServerHandle>) { async fn shutdown_signal(handles: Vec<ServerHandle>) {

View file

@ -1,6 +1,6 @@
use std::{ use std::{
borrow::Cow, borrow::Cow,
collections::BTreeMap, collections::{BTreeMap, HashSet},
fmt::{self, Display}, fmt::{self, Display},
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -13,6 +13,7 @@ use ruma::{
OwnedServerSigningKeyId, RoomVersionId, OwnedServerSigningKeyId, RoomVersionId,
}; };
use serde::Deserialize; use serde::Deserialize;
use strum::{Display, EnumIter, IntoEnumIterator};
use crate::error; use crate::error;
@ -108,9 +109,27 @@ pub(crate) struct TlsConfig {
pub(crate) key: String, pub(crate) key: String,
} }
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub(crate) enum ListenComponent {
Client,
Federation,
Metrics,
WellKnown,
}
impl ListenComponent {
fn all_components() -> HashSet<Self> {
Self::iter().collect()
}
}
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum ListenConfig { pub(crate) enum ListenTransport {
Tcp { Tcp {
#[serde(default = "default_address")] #[serde(default = "default_address")]
address: IpAddr, address: IpAddr,
@ -121,15 +140,15 @@ pub(crate) enum ListenConfig {
}, },
} }
impl Display for ListenConfig { impl Display for ListenTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
ListenConfig::Tcp { ListenTransport::Tcp {
address, address,
port, port,
tls: false, tls: false,
} => write!(f, "http://{address}:{port}"), } => write!(f, "http://{address}:{port}"),
ListenConfig::Tcp { ListenTransport::Tcp {
address, address,
port, port,
tls: true, tls: true,
@ -138,6 +157,29 @@ impl Display for ListenConfig {
} }
} }
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct ListenConfig {
#[serde(default = "ListenComponent::all_components")]
pub(crate) components: HashSet<ListenComponent>,
#[serde(flatten)]
pub(crate) transport: ListenTransport,
}
impl Display for ListenConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{} ({})",
self.transport,
self.components
.iter()
.map(ListenComponent::to_string)
.collect::<Vec<_>>()
.join(", ")
)
}
}
#[derive(Copy, Clone, Default, Debug, Deserialize)] #[derive(Copy, Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub(crate) enum LogFormat { pub(crate) enum LogFormat {
@ -315,10 +357,13 @@ fn true_fn() -> bool {
} }
fn default_listen() -> Vec<ListenConfig> { fn default_listen() -> Vec<ListenConfig> {
vec![ListenConfig::Tcp { vec![ListenConfig {
address: default_address(), components: ListenComponent::all_components(),
port: default_port(), transport: ListenTransport::Tcp {
tls: false, address: default_address(),
port: default_port(),
tls: false,
},
}] }]
} }