From 084d862e5175df9350a8ba2f67f2c28c4de43d75 Mon Sep 17 00:00:00 2001 From: Lambda Date: Sat, 21 Sep 2024 15:55:17 +0000 Subject: [PATCH] Allow configuring served components per listener --- book/changelog.md | 2 ++ src/cli/serve.rs | 51 ++++++++++++++++++++++++-------------- src/config.rs | 63 ++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 89 insertions(+), 27 deletions(-) diff --git a/book/changelog.md b/book/changelog.md index b06790d9..26ed1d67 100644 --- a/book/changelog.md +++ b/book/changelog.md @@ -246,3 +246,5 @@ This will be the first release of Grapevine since it was forked from Conduit 18. Added admin commands to delete media ([!99](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/99), [!102](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/102)) +19. Allow configuring the served API components per listener. + ([!109](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/109)) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index 95ae9b44..8c9a2168 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -1,4 +1,7 @@ -use std::{future::Future, net::SocketAddr, sync::atomic, time::Duration}; +use std::{ + collections::HashSet, future::Future, net::SocketAddr, sync::atomic, + time::Duration, +}; use axum::{ extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, @@ -38,7 +41,7 @@ use crate::{ server_server, well_known, }, config, - config::{Config, ListenConfig}, + config::{Config, ListenComponent, ListenTransport}, database::KeyValueDatabase, error, observability, services, utils, utils::error::{Error, Result}, @@ -148,7 +151,6 @@ async fn run_server() -> Result<(), error::Serve> { )) .layer(axum::middleware::from_fn(observability::http_metrics_layer)); - let app = routes(config).layer(middlewares).into_make_service(); let mut handles = Vec::new(); let mut servers = JoinSet::new(); @@ -170,25 +172,29 @@ async fn run_server() -> Result<(), error::Serve> { for listen in &config.listen { info!(listener = %listen, "Listening for incoming traffic"); - match listen { - ListenConfig::Tcp { + let app = routes(config, &listen.components) + .layer(middlewares.clone()) + .into_make_service(); + + match listen.transport { + ListenTransport::Tcp { address, port, tls, } => { - let addr = SocketAddr::from((*address, *port)); + let addr = SocketAddr::from((address, port)); let handle = ServerHandle::new(); handles.push(handle.clone()); - let server = if *tls { + let server = if tls { let tls_config = tls_config .clone() .ok_or_else(|| Error::NoTlsCerts(listen.clone()))?; bind_rustls(addr, tls_config) .handle(handle) - .serve(app.clone()) + .serve(app) .left_future() } else { - bind(addr).handle(handle).serve(app.clone()).right_future() + bind(addr).handle(handle).serve(app).right_future() }; servers.spawn( server.then(|result| async { (listen.clone(), result) }), @@ -493,15 +499,24 @@ fn well_known_routes() -> Router { .route("/.well-known/matrix/server", get(well_known::server)) } -fn routes(config: &Config) -> Router { - Router::new() - .merge(client_routes()) - .merge(federation_routes(config)) - .merge(legacy_media_routes(config)) - .merge(well_known_routes()) - .merge(metrics_routes(config)) - .route("/", get(it_works)) - .fallback(not_found) +fn routes(config: &Config, components: &HashSet) -> Router { + let mut router = Router::new(); + for &component in components { + router = router.merge(match component { + ListenComponent::Client => client_routes(), + ListenComponent::Federation => federation_routes(config), + ListenComponent::Metrics => metrics_routes(config), + ListenComponent::WellKnown => well_known_routes(), + }); + } + + if components.contains(&ListenComponent::Client) + || components.contains(&ListenComponent::Federation) + { + router = router.merge(legacy_media_routes(config)); + } + + router.route("/", get(it_works)).fallback(not_found) } async fn shutdown_signal(handles: Vec) { diff --git a/src/config.rs b/src/config.rs index 197be925..394313e0 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use std::{ borrow::Cow, - collections::BTreeMap, + collections::{BTreeMap, HashSet}, fmt::{self, Display}, net::{IpAddr, Ipv4Addr}, path::{Path, PathBuf}, @@ -13,6 +13,7 @@ use ruma::{ OwnedServerSigningKeyId, RoomVersionId, }; use serde::Deserialize; +use strum::{Display, EnumIter, IntoEnumIterator}; use crate::error; @@ -108,9 +109,27 @@ pub(crate) struct TlsConfig { pub(crate) key: String, } +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, EnumIter, Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub(crate) enum ListenComponent { + Client, + Federation, + Metrics, + WellKnown, +} + +impl ListenComponent { + fn all_components() -> HashSet { + Self::iter().collect() + } +} + #[derive(Clone, Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] -pub(crate) enum ListenConfig { +pub(crate) enum ListenTransport { Tcp { #[serde(default = "default_address")] address: IpAddr, @@ -121,15 +140,15 @@ pub(crate) enum ListenConfig { }, } -impl Display for ListenConfig { +impl Display for ListenTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ListenConfig::Tcp { + ListenTransport::Tcp { address, port, tls: false, } => write!(f, "http://{address}:{port}"), - ListenConfig::Tcp { + ListenTransport::Tcp { address, port, tls: true, @@ -138,6 +157,29 @@ impl Display for ListenConfig { } } +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct ListenConfig { + #[serde(default = "ListenComponent::all_components")] + pub(crate) components: HashSet, + #[serde(flatten)] + pub(crate) transport: ListenTransport, +} + +impl Display for ListenConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} ({})", + self.transport, + self.components + .iter() + .map(ListenComponent::to_string) + .collect::>() + .join(", ") + ) + } +} + #[derive(Copy, Clone, Default, Debug, Deserialize)] #[serde(rename_all = "snake_case")] pub(crate) enum LogFormat { @@ -315,10 +357,13 @@ fn true_fn() -> bool { } fn default_listen() -> Vec { - vec![ListenConfig::Tcp { - address: default_address(), - port: default_port(), - tls: false, + vec![ListenConfig { + components: ListenComponent::all_components(), + transport: ListenTransport::Tcp { + address: default_address(), + port: default_port(), + tls: false, + }, }] }