grapevine/src/config.rs
Charles Hall 153e3e4c93
make database config a sum type
This way we don't need to construct the entire configuration to load a
database or database engine.

The other advantage is that it allows having options that are unique to
each database backend.

The one thing I don't like about this is `DatabaseConfig::path`, whose
existence implies all databases will have a file path, which is not
true for out-of-process databases. The only thing this is really used
for is creating the media directory. I think we should restructure the
configuration in the future to resolve this.
2024-10-22 10:36:04 -07:00

479 lines
12 KiB
Rust

use std::{
borrow::Cow,
collections::{BTreeMap, HashSet},
fmt::{self, Display},
net::{IpAddr, Ipv4Addr},
path::{Path, PathBuf},
};
use once_cell::sync::Lazy;
use reqwest::Url;
use ruma::{
api::federation::discovery::OldVerifyKey, OwnedServerName,
OwnedServerSigningKeyId, RoomVersionId,
};
use serde::Deserialize;
use strum::{Display, EnumIter, IntoEnumIterator};
use crate::error;
mod env_filter_clone;
mod proxy;
pub(crate) use env_filter_clone::EnvFilterClone;
use proxy::ProxyConfig;
/// The default configuration file path
pub(crate) static DEFAULT_PATH: Lazy<PathBuf> =
Lazy::new(|| [env!("CARGO_PKG_NAME"), "config.toml"].iter().collect());
/// Top-level server configuration, deserialized from the TOML configuration
/// file by [`load`]
///
/// Fields with a `#[serde(default ...)]` attribute may be omitted from the
/// file and then take the documented default. `Option` fields are `None` when
/// omitted. All remaining fields are required.
// The struct is a flat collection of independent settings, several of which
// are plain on/off switches, hence the lint exception below.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Deserialize)]
pub(crate) struct Config {
/// Defaults to `false`
// NOTE(review): presumably toggles compatibility behaviour with Conduit;
// confirm against the code that reads this flag.
#[serde(default = "false_fn")]
pub(crate) conduit_compat: bool,
/// The listeners of this server
///
/// Defaults to the single listener produced by `default_listen`.
#[serde(default = "default_listen")]
pub(crate) listen: Vec<ListenConfig>,
/// TLS certificate and key settings, if configured
pub(crate) tls: Option<TlsConfig>,
/// The name of this homeserver
///
/// This is the value that will appear e.g. in user IDs and room aliases.
pub(crate) server_name: OwnedServerName,
/// Server-server and client-server discovery configuration
pub(crate) server_discovery: ServerDiscovery,
/// The database backend and its backend-specific options
pub(crate) database: DatabaseConfig,
/// Federation settings; defaults to `FederationConfig::default()`
#[serde(default)]
pub(crate) federation: FederationConfig,
/// Defaults to `1.0`
// NOTE(review): presumably a multiplier applied to cache capacities; confirm.
#[serde(default = "default_cache_capacity_modifier")]
pub(crate) cache_capacity_modifier: f64,
/// Defaults to `150_000`
#[serde(default = "default_pdu_cache_capacity")]
pub(crate) pdu_cache_capacity: usize,
/// Defaults to `60` (the default function notes "every minute")
#[serde(default = "default_cleanup_second_interval")]
pub(crate) cleanup_second_interval: u32,
/// Defaults to `20 * 1024 * 1024`
// NOTE(review): presumably measured in bytes; confirm at the use site.
#[serde(default = "default_max_request_size")]
pub(crate) max_request_size: u32,
/// Defaults to `false`
#[serde(default = "false_fn")]
pub(crate) allow_registration: bool,
/// Optional registration token
///
/// [`load`] rejects a configuration in which this is set to the empty
/// string.
pub(crate) registration_token: Option<String>,
/// Defaults to `true`
#[serde(default = "true_fn")]
pub(crate) allow_encryption: bool,
/// Defaults to `true`
#[serde(default = "true_fn")]
pub(crate) allow_room_creation: bool,
/// Defaults to `false`
#[serde(default = "false_fn")]
pub(crate) serve_media_unauthenticated: bool,
/// Defaults to the value of `default_default_room_version`
#[serde(default = "default_default_room_version")]
pub(crate) default_room_version: RoomVersionId,
/// Proxy settings; defaults to `ProxyConfig::default()`
#[serde(default)]
pub(crate) proxy: ProxyConfig,
/// Optional JWT secret
// NOTE(review): how this secret is consumed is not visible in this file.
pub(crate) jwt_secret: Option<String>,
/// Metrics, tracing and logging settings; defaults to
/// `ObservabilityConfig::default()`
#[serde(default)]
pub(crate) observability: ObservabilityConfig,
/// TURN settings; defaults to `TurnConfig::default()`
#[serde(default)]
pub(crate) turn: TurnConfig,
/// Optional emergency password
// NOTE(review): how this password is consumed is not visible in this file.
pub(crate) emergency_password: Option<String>,
}
/// Discovery configuration for both the server-server and the client-server
/// API
#[derive(Debug, Deserialize)]
pub(crate) struct ServerDiscovery {
/// Server-server discovery configuration
///
/// Optional; defaults to `ServerServerDiscovery::default()`.
#[serde(default)]
pub(crate) server: ServerServerDiscovery,
/// Client-server discovery configuration
///
/// Required: this field has no serde default.
pub(crate) client: ClientServerDiscovery,
}
/// Server-server discovery configuration
///
/// The derived `Default` leaves `authority` unset.
#[derive(Debug, Default, Deserialize)]
pub(crate) struct ServerServerDiscovery {
/// The alternative authority to make server-server API requests to
///
/// `None` when not configured.
pub(crate) authority: Option<OwnedServerName>,
}
/// Client-server discovery configuration
#[derive(Debug, Deserialize)]
pub(crate) struct ClientServerDiscovery {
/// The base URL to make client-server API requests to
pub(crate) base_url: Url,
/// Read from the `advertise_buggy_sliding_sync` configuration key; defaults
/// to `false`
// NOTE(review): presumably controls whether sliding sync support is
// advertised to clients; confirm at the use site.
#[serde(default, rename = "advertise_buggy_sliding_sync")]
pub(crate) advertise_sliding_sync: bool,
}
/// TLS certificate and key settings
///
/// Both fields are required when the `tls` table is present.
// NOTE(review): the values are plain strings; presumably file paths to the
// certificate chain and private key, but that is not shown here. Confirm
// against the code that consumes them.
#[derive(Debug, Deserialize)]
pub(crate) struct TlsConfig {
/// The certificate(s) to use
pub(crate) certs: String,
/// The private key to use
pub(crate) key: String,
}
/// A component that a listener can serve
///
/// Both the configuration representation (serde) and the `Display` output
/// (strum) use the snake_case variant name, e.g. `well_known`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub(crate) enum ListenComponent {
/// Configured as `client`
Client,
/// Configured as `federation`
Federation,
/// Configured as `metrics`
Metrics,
/// Configured as `well_known`
WellKnown,
}
impl ListenComponent {
    /// Build the set that contains every [`ListenComponent`] variant
    ///
    /// This is the serde default for `ListenConfig::components` and is also
    /// used by `default_listen`.
    fn all_components() -> HashSet<Self> {
        let mut every = HashSet::new();
        every.extend(Self::iter());
        every
    }
}
/// The transport a listener uses
///
/// Deserialized as an internally tagged enum: the `type` key selects the
/// variant by its snake_case name (`tcp`).
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum ListenTransport {
/// A TCP listener
Tcp {
/// IP address of the listener; defaults to the IPv4 loopback address
#[serde(default = "default_address")]
address: IpAddr,
/// Port of the listener; defaults to `6167`
#[serde(default = "default_port")]
port: u16,
/// Whether the listener uses TLS; defaults to `false`
#[serde(default = "false_fn")]
tls: bool,
},
}
impl Display for ListenTransport {
    /// Render the transport as `<scheme>://<address>:<port>`
    ///
    /// The scheme is `https` when `tls` is set and `http` otherwise. The
    /// address and port are formatted through [`std::net::SocketAddr`], so an
    /// IPv6 address is wrapped in brackets (`http://[::1]:6167`) rather than
    /// running into the port separator (`http://::1:6167`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::net::SocketAddr;

        match *self {
            ListenTransport::Tcp {
                address,
                port,
                tls,
            } => {
                let scheme = if tls { "https" } else { "http" };
                write!(f, "{scheme}://{}", SocketAddr::new(address, port))
            }
        }
    }
}
/// A single listener: which components it serves and over which transport
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct ListenConfig {
/// The components served by this listener
///
/// Defaults to every [`ListenComponent`] variant.
#[serde(default = "ListenComponent::all_components")]
pub(crate) components: HashSet<ListenComponent>,
/// The transport of this listener
///
/// Flattened: the transport's keys (`type` and the variant's fields) sit in
/// the same table as `components`.
#[serde(flatten)]
pub(crate) transport: ListenTransport,
}
impl Display for ListenConfig {
    /// Render the listener as `<transport> (<component>, <component>, ...)`
    ///
    /// `components` is a `HashSet`, whose iteration order is unspecified, so
    /// the component names are sorted before joining. This keeps the output
    /// identical across runs for the same configuration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut names: Vec<String> = self
            .components
            .iter()
            .map(ListenComponent::to_string)
            .collect();
        names.sort_unstable();

        write!(f, "{} ({})", self.transport, names.join(", "))
    }
}
/// The formatter used for log output
///
/// Configured by the snake_case variant name (`pretty`, `full`, `compact`,
/// `json`). `Full` is the default.
#[derive(Copy, Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum LogFormat {
/// Use the [`tracing_subscriber::fmt::format::Pretty`] formatter
Pretty,
/// Use the [`tracing_subscriber::fmt::format::Full`] formatter
#[default]
Full,
/// Use the [`tracing_subscriber::fmt::format::Compact`] formatter
Compact,
/// Use the [`tracing_subscriber::fmt::format::Json`] formatter
Json,
}
/// TURN server settings
///
/// Because of the container-level `#[serde(default)]`, every key is optional
/// and missing keys take their value from `TurnConfig::default()`.
// NOTE(review): how `username`/`password` and `secret` interact is not
// visible in this file; confirm at the use site.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub(crate) struct TurnConfig {
/// Defaults to the empty string
pub(crate) username: String,
/// Defaults to the empty string
pub(crate) password: String,
/// Defaults to an empty list
pub(crate) uris: Vec<String>,
/// Defaults to the empty string
pub(crate) secret: String,
/// Defaults to `60 * 60 * 24`
// NOTE(review): presumably seconds, which would make the default one day;
// confirm.
pub(crate) ttl: u64,
}
impl Default for TurnConfig {
    /// Empty credentials, no URIs, and a `ttl` of `60 * 60 * 24`
    fn default() -> Self {
        const DEFAULT_TTL: u64 = 60 * 60 * 24;

        Self {
            username: String::default(),
            password: String::default(),
            uris: Vec::default(),
            secret: String::default(),
            ttl: DEFAULT_TTL,
        }
    }
}
/// Options for the RocksDB database backend
///
/// Only available when the `rocksdb` feature is enabled.
#[cfg(feature = "rocksdb")]
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct RocksdbConfig {
/// The database path (required)
pub(crate) path: PathBuf,
/// Defaults to `1000`
#[serde(default = "default_rocksdb_max_open_files")]
pub(crate) max_open_files: i32,
/// Cache capacity in bytes; defaults to `300 * 1024 * 1024`
#[serde(default = "default_rocksdb_cache_capacity_bytes")]
pub(crate) cache_capacity_bytes: usize,
}
/// Options for the SQLite database backend
///
/// Only available when the `sqlite` feature is enabled.
#[cfg(feature = "sqlite")]
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct SqliteConfig {
/// The database path (required)
pub(crate) path: PathBuf,
/// Cache capacity in kilobytes; defaults to `300 * 1024`
#[serde(default = "default_sqlite_cache_capacity_kilobytes")]
pub(crate) cache_capacity_kilobytes: u32,
}
/// The database backend together with its backend-specific options
///
/// Deserialized as an internally tagged enum: the `backend` key selects the
/// variant by its lowercase name (`rocksdb` or `sqlite`), and the remaining
/// keys of the table are the options of that backend. A variant only exists
/// when the corresponding feature is enabled.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
pub(crate) enum DatabaseConfig {
/// The RocksDB backend, selected with `backend = "rocksdb"`
#[cfg(feature = "rocksdb")]
Rocksdb(RocksdbConfig),
/// The SQLite backend, selected with `backend = "sqlite"`
#[cfg(feature = "sqlite")]
Sqlite(SqliteConfig),
}
impl DatabaseConfig {
    /// The `path` option of whichever backend is configured
    pub(crate) fn path(&self) -> &Path {
        match self {
            #[cfg(feature = "rocksdb")]
            Self::Rocksdb(rocksdb) => rocksdb.path.as_path(),
            #[cfg(feature = "sqlite")]
            Self::Sqlite(sqlite) => sqlite.path.as_path(),
        }
    }
}
impl Display for DatabaseConfig {
    /// Write the human-readable name of the configured backend (`RocksDB` or
    /// `SQLite`); the backend's options are not included
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            #[cfg(feature = "rocksdb")]
            Self::Rocksdb(_) => f.write_str("RocksDB"),
            #[cfg(feature = "sqlite")]
            Self::Sqlite(_) => f.write_str("SQLite"),
        }
    }
}
/// Metrics settings
///
/// Every key is optional; missing keys take their value from the derived
/// `Default`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct MetricsConfig {
/// Defaults to `false`
pub(crate) enable: bool,
}
/// Trace export settings
///
/// Every key is optional; missing keys take their value from
/// `OtelTraceConfig::default()`.
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct OtelTraceConfig {
/// Defaults to `false`
pub(crate) enable: bool,
/// Defaults to the filter returned by `default_tracing_filter`
pub(crate) filter: EnvFilterClone,
/// Defaults to `None`
// NOTE(review): what endpoint is used when this is `None` is decided
// outside this file.
pub(crate) endpoint: Option<String>,
/// Defaults to the package name (`CARGO_PKG_NAME`)
pub(crate) service_name: String,
}
impl Default for OtelTraceConfig {
    /// Disabled, using the default tracing filter, no explicit endpoint, and
    /// the package name as the service name
    fn default() -> Self {
        let service_name = String::from(env!("CARGO_PKG_NAME"));

        Self {
            enable: false,
            filter: default_tracing_filter(),
            endpoint: None,
            service_name,
        }
    }
}
/// Settings for the folded stack trace output
///
/// Every key is optional; missing keys take their value from
/// `FlameConfig::default()`.
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct FlameConfig {
/// Defaults to `false`
pub(crate) enable: bool,
/// Defaults to the filter returned by `default_tracing_filter`
pub(crate) filter: EnvFilterClone,
/// Defaults to `./tracing.folded`
pub(crate) filename: String,
}
impl Default for FlameConfig {
    /// Disabled, using the default tracing filter and the file name
    /// `./tracing.folded`
    fn default() -> Self {
        let filename = String::from("./tracing.folded");

        Self {
            enable: false,
            filter: default_tracing_filter(),
            filename,
        }
    }
}
/// Logging settings
///
/// Every key is optional; missing keys take their value from
/// `LogConfig::default()`.
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct LogConfig {
/// Defaults to the filter returned by `default_tracing_filter`
pub(crate) filter: EnvFilterClone,
/// Defaults to `true`
pub(crate) colors: bool,
/// Defaults to `LogFormat::default()`
pub(crate) format: LogFormat,
/// Defaults to `true`
pub(crate) timestamp: bool,
}
impl Default for LogConfig {
    /// The default tracing filter and log format, with colors and timestamps
    /// both switched on
    fn default() -> Self {
        let filter = default_tracing_filter();
        let format = LogFormat::default();

        Self {
            filter,
            colors: true,
            format,
            timestamp: true,
        }
    }
}
/// Settings for metrics, traces, flame graphs and logs
///
/// Every section is optional; a missing section takes the `Default` of its
/// type.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct ObservabilityConfig {
/// Prometheus metrics
pub(crate) metrics: MetricsConfig,
/// OpenTelemetry traces
pub(crate) traces: OtelTraceConfig,
/// Folded inferno stack traces
pub(crate) flame: FlameConfig,
/// Logging to stdout
pub(crate) logs: LogConfig,
}
/// Federation settings
///
/// Every key is optional; missing keys take their value from
/// `FederationConfig::default()`.
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct FederationConfig {
/// Defaults to `true`
pub(crate) enable: bool,
/// Defaults to `true`
// NOTE(review): what the self test checks is not visible in this file.
pub(crate) self_test: bool,
/// Defaults to a list containing only `matrix.org`
pub(crate) trusted_servers: Vec<OwnedServerName>,
/// Defaults to `100`
pub(crate) max_fetch_prev_events: u16,
/// Defaults to `100`
pub(crate) max_concurrent_requests: u16,
/// Defaults to an empty map
pub(crate) old_verify_keys: BTreeMap<OwnedServerSigningKeyId, OldVerifyKey>,
}
impl Default for FederationConfig {
    /// Federation and the self test enabled, `matrix.org` as the only trusted
    /// server, limits of `100` for both `max_fetch_prev_events` and
    /// `max_concurrent_requests`, and no old verify keys
    ///
    /// # Panics
    ///
    /// Panics if the hardcoded server name fails to parse, which would be a
    /// programming error.
    fn default() -> Self {
        // Same convention as the hardcoded tracing filter in this file: state
        // the invariant instead of using a bare `unwrap`.
        let matrix_org = OwnedServerName::try_from("matrix.org")
            .expect("hardcoded server name should be valid");

        Self {
            enable: true,
            self_test: true,
            trusted_servers: vec![matrix_org],
            max_fetch_prev_events: 100,
            max_concurrent_requests: 100,
            old_verify_keys: BTreeMap::new(),
        }
    }
}
/// Returns `false`
///
/// Exists so that `#[serde(default = "false_fn")]` can name a function.
fn false_fn() -> bool {
false
}
/// Returns `true`
///
/// Exists so that `#[serde(default = "true_fn")]` can name a function.
fn true_fn() -> bool {
true
}
/// Default for `Config::listen`: a single plain (non-TLS) TCP listener on the
/// default address and port that serves every component
fn default_listen() -> Vec<ListenConfig> {
    let components = ListenComponent::all_components();
    let transport = ListenTransport::Tcp {
        address: default_address(),
        port: default_port(),
        tls: false,
    };

    vec![ListenConfig {
        components,
        transport,
    }]
}
/// Default listener address: the IPv4 loopback address
fn default_address() -> IpAddr {
    IpAddr::V4(Ipv4Addr::LOCALHOST)
}
/// Default listener port: `6167`
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 6167;
    DEFAULT_PORT
}
/// Default for `Config::cache_capacity_modifier`: `1.0`
fn default_cache_capacity_modifier() -> f64 {
    const DEFAULT_MODIFIER: f64 = 1.0;
    DEFAULT_MODIFIER
}
/// Default for `RocksdbConfig::max_open_files`: `1000`
#[cfg(feature = "rocksdb")]
fn default_rocksdb_max_open_files() -> i32 {
1000
}
/// Default for `RocksdbConfig::cache_capacity_bytes`: 300 MiB, expressed in
/// bytes
#[cfg(feature = "rocksdb")]
fn default_rocksdb_cache_capacity_bytes() -> usize {
300 * 1024 * 1024
}
/// Default for `SqliteConfig::cache_capacity_kilobytes`: `300 * 1024`
// NOTE(review): with the unit implied by the field name this is about 300 MB,
// matching the RocksDB default; confirm how the consumer interprets it.
#[cfg(feature = "sqlite")]
fn default_sqlite_cache_capacity_kilobytes() -> u32 {
300 * 1024
}
/// Default for `Config::pdu_cache_capacity`: `150_000`
fn default_pdu_cache_capacity() -> usize {
    const DEFAULT_CAPACITY: usize = 150_000;
    DEFAULT_CAPACITY
}
/// Default for `Config::cleanup_second_interval`: once every minute
fn default_cleanup_second_interval() -> u32 {
    const SECONDS_PER_MINUTE: u32 = 60;
    SECONDS_PER_MINUTE
}
/// Default for `Config::max_request_size`: `20 * 1024 * 1024`
fn default_max_request_size() -> u32 {
    const MEBI: u32 = 1024 * 1024;
    20 * MEBI
}
fn default_tracing_filter() -> EnvFilterClone {
"info,ruma_state_res=warn"
.parse()
.expect("hardcoded env filter should be valid")
}
/// Default for `Config::default_room_version`: room version 10
///
/// The doubled "default" is intentional: this is the default value of the
/// `default_room_version` option.
// I know, it's a great name
pub(crate) fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10
}
/// Search default locations for a configuration file
///
/// If one isn't found, the list of tried paths is returned.
fn search() -> Result<PathBuf, error::ConfigSearch> {
use error::ConfigSearch as Error;
xdg::BaseDirectories::new()?
.find_config_file(&*DEFAULT_PATH)
.ok_or(Error::NotFound)
}
/// Read and parse the configuration file
///
/// When `path` is `Some`, that file is used; otherwise the file is located
/// with [`search`].
///
/// # Errors
///
/// * a failed [`search`], converted into the returned error type,
/// * `Error::Read` (with the path) if the file cannot be read,
/// * `Error::Parse` (with the path) if the contents do not deserialize into
///   [`Config`],
/// * `Error::RegistrationTokenEmpty` if `registration_token` is set to the
///   empty string.
pub(crate) async fn load<P>(path: Option<P>) -> Result<Config, error::Config>
where
    P: AsRef<Path>,
{
    use error::Config as Error;

    // Borrow the caller's path when there is one; only the searched path
    // needs to be owned.
    let resolved: Cow<'_, Path> = match &path {
        Some(given) => Cow::Borrowed(given.as_ref()),
        None => Cow::Owned(search()?),
    };
    let path: &Path = &resolved;

    let contents = tokio::fs::read_to_string(path)
        .await
        .map_err(|e| Error::Read(e, path.to_owned()))?;

    let config: Config = toml::from_str(&contents)
        .map_err(|e| Error::Parse(e, path.to_owned()))?;

    let token_is_empty =
        matches!(config.registration_token.as_deref(), Some(""));
    if token_is_empty {
        Err(Error::RegistrationTokenEmpty)
    } else {
        Ok(config)
    }
}