use std::{ borrow::Cow, collections::{BTreeMap, HashSet}, fmt::{self, Display}, net::{IpAddr, Ipv4Addr}, path::{Path, PathBuf}, sync::LazyLock, }; use reqwest::Url; use ruma::{ api::federation::discovery::OldVerifyKey, OwnedServerName, OwnedServerSigningKeyId, RoomVersionId, UInt, }; use serde::Deserialize; use strum::{Display, EnumIter, IntoEnumIterator}; use crate::error; mod env_filter_clone; mod proxy; pub(crate) use env_filter_clone::EnvFilterClone; use proxy::ProxyConfig; /// The default configuration file path pub(crate) static DEFAULT_PATH: LazyLock = LazyLock::new(|| [env!("CARGO_PKG_NAME"), "config.toml"].iter().collect()); #[allow(clippy::struct_excessive_bools)] #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct Config { #[serde(default = "false_fn")] pub(crate) conduit_compat: bool, #[serde(default = "default_listen")] pub(crate) listen: Vec, pub(crate) tls: Option, /// The name of this homeserver /// /// This is the value that will appear e.g. in user IDs and room aliases. pub(crate) server_name: OwnedServerName, pub(crate) server_discovery: ServerDiscovery, pub(crate) database: DatabaseConfig, #[serde(default)] pub(crate) media: MediaConfig, #[serde(default)] pub(crate) federation: FederationConfig, #[serde(default)] pub(crate) cache: CacheConfig, #[serde(default = "default_cleanup_second_interval")] pub(crate) cleanup_second_interval: u32, #[serde(default = "default_max_request_size")] pub(crate) max_request_size: UInt, #[serde(default = "false_fn")] pub(crate) allow_registration: bool, pub(crate) registration_token: Option, #[serde(default = "true_fn")] pub(crate) allow_encryption: bool, #[serde(default = "true_fn")] pub(crate) allow_room_creation: bool, #[serde(default = "default_default_room_version")] pub(crate) default_room_version: RoomVersionId, #[serde(default)] pub(crate) proxy: ProxyConfig, pub(crate) jwt_secret: Option, #[serde(default)] pub(crate) observability: ObservabilityConfig, #[serde(default)] pub(crate) turn: TurnConfig, pub(crate) emergency_password: Option, } #[derive(Debug, Deserialize, Default)] #[serde(deny_unknown_fields)] pub(crate) struct MediaConfig { #[serde(default)] pub(crate) allow_unauthenticated_access: bool, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields, default)] pub(crate) struct CacheConfig { pub(crate) pdu: usize, pub(crate) auth_chain: usize, pub(crate) short_eventid: usize, pub(crate) eventid_short: usize, pub(crate) statekey_short: usize, pub(crate) short_statekey: usize, pub(crate) server_visibility: usize, pub(crate) user_visibility: usize, pub(crate) state_info: usize, pub(crate) roomid_spacechunk: usize, } impl Default for CacheConfig { fn default() -> Self { Self { pdu: 150_000, auth_chain: 100_000, short_eventid: 100_000, eventid_short: 100_000, statekey_short: 100_000, short_statekey: 100_000, server_visibility: 100, user_visibility: 100, state_info: 100, roomid_spacechunk: 200, } } } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct ServerDiscovery { /// Server-server discovery configuration #[serde(default)] pub(crate) server: ServerServerDiscovery, /// Client-server discovery configuration pub(crate) client: ClientServerDiscovery, } /// Server-server discovery configuration #[derive(Debug, Default, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct ServerServerDiscovery { /// The alternative authority to make server-server API requests to pub(crate) authority: Option, } /// Client-server discovery configuration #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct ClientServerDiscovery { /// The base URL to make client-server API requests to pub(crate) base_url: Url, #[serde(default, rename = "advertise_buggy_sliding_sync")] pub(crate) advertise_sliding_sync: bool, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct TlsConfig { pub(crate) certs: String, pub(crate) key: String, } #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, EnumIter, Display, )] #[serde(rename_all = "snake_case")] #[strum(serialize_all = "snake_case")] #[serde(deny_unknown_fields)] pub(crate) enum ListenComponent { Client, Federation, Metrics, WellKnown, } impl ListenComponent { fn all_components() -> HashSet { Self::iter().collect() } } #[derive(Clone, Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] #[serde(deny_unknown_fields)] pub(crate) enum ListenTransport { Tcp { #[serde(default = "default_address")] address: IpAddr, #[serde(default = "default_port")] port: u16, #[serde(default = "false_fn")] tls: bool, #[serde(default = "false_fn")] proxy_protocol: bool, }, } impl Display for ListenTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { ListenTransport::Tcp { address, port, tls, proxy_protocol, } => { let scheme = format!( "{}{}", if proxy_protocol { "proxy+" } else { "" }, if tls { "https" } else { "http" } ); write!(f, "{scheme}://{address}:{port}") } } } } #[derive(Clone, Debug, Deserialize)] // Incompatible with deny_unknown_fields due to serde(flatten). pub(crate) struct ListenConfig { #[serde(default = "ListenComponent::all_components")] pub(crate) components: HashSet, #[serde(flatten)] pub(crate) transport: ListenTransport, } impl Display for ListenConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "{} ({})", self.transport, self.components .iter() .map(ListenComponent::to_string) .collect::>() .join(", ") ) } } #[derive(Copy, Clone, Default, Debug, Deserialize, clap::ValueEnum)] #[serde(deny_unknown_fields)] #[serde(rename_all = "snake_case")] pub(crate) enum LogFormat { /// Multiple lines per event, includes all information Pretty, /// One line per event, includes most information #[default] Full, /// One line per event, includes less information Compact, /// One JSON object per line per event, includes most information Json, } impl Display for LogFormat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { LogFormat::Pretty => write!(f, "pretty"), LogFormat::Full => write!(f, "full"), LogFormat::Compact => write!(f, "compact"), LogFormat::Json => write!(f, "json"), } } } #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct TurnConfig { pub(crate) username: String, pub(crate) password: String, pub(crate) uris: Vec, pub(crate) secret: String, pub(crate) ttl: u64, } impl Default for TurnConfig { fn default() -> Self { Self { username: String::new(), password: String::new(), uris: Vec::new(), secret: String::new(), ttl: 60 * 60 * 24, } } } #[derive(Clone, Copy, Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(rename_all = "lowercase")] pub(crate) enum DatabaseBackend { #[cfg(feature = "rocksdb")] Rocksdb, #[cfg(feature = "sqlite")] Sqlite, } impl Display for DatabaseBackend { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { #[cfg(feature = "rocksdb")] DatabaseBackend::Rocksdb => write!(f, "RocksDB"), #[cfg(feature = "sqlite")] DatabaseBackend::Sqlite => write!(f, "SQLite"), } } } #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct DatabaseConfig { pub(crate) backend: DatabaseBackend, pub(crate) path: String, #[serde(default = "default_db_cache_capacity_mb")] pub(crate) cache_capacity_mb: f64, #[cfg(feature = "rocksdb")] #[serde(default = "default_rocksdb_max_open_files")] pub(crate) rocksdb_max_open_files: i32, } #[derive(Clone, Debug, Default, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct MetricsConfig { pub(crate) enable: bool, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct OtelTraceConfig { pub(crate) enable: bool, pub(crate) filter: EnvFilterClone, pub(crate) endpoint: Option, pub(crate) service_name: String, } impl Default for OtelTraceConfig { fn default() -> Self { Self { enable: false, filter: default_tracing_filter(), endpoint: None, service_name: env!("CARGO_PKG_NAME").to_owned(), } } } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct FlameConfig { pub(crate) enable: bool, pub(crate) filter: EnvFilterClone, pub(crate) filename: String, } impl Default for FlameConfig { fn default() -> Self { Self { enable: false, filter: default_tracing_filter(), filename: "./tracing.folded".to_owned(), } } } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct LogConfig { pub(crate) filter: EnvFilterClone, pub(crate) colors: bool, pub(crate) format: LogFormat, pub(crate) timestamp: bool, } impl Default for LogConfig { fn default() -> Self { Self { filter: default_tracing_filter(), colors: true, format: LogFormat::default(), timestamp: true, } } } #[derive(Debug, Default, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct ObservabilityConfig { /// Prometheus metrics pub(crate) metrics: MetricsConfig, /// OpenTelemetry traces pub(crate) traces: OtelTraceConfig, /// Folded inferno stack traces pub(crate) flame: FlameConfig, /// Logging to stdout pub(crate) logs: LogConfig, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] #[serde(default)] pub(crate) struct FederationConfig { pub(crate) enable: bool, pub(crate) self_test: bool, pub(crate) trusted_servers: Vec, pub(crate) max_fetch_prev_events: u16, pub(crate) max_concurrent_requests: u16, pub(crate) old_verify_keys: BTreeMap, } impl Default for FederationConfig { fn default() -> Self { Self { enable: true, self_test: true, trusted_servers: vec![ OwnedServerName::try_from("matrix.org").unwrap() ], max_fetch_prev_events: 100, max_concurrent_requests: 100, old_verify_keys: BTreeMap::new(), } } } fn false_fn() -> bool { false } fn true_fn() -> bool { true } fn default_listen() -> Vec { vec![ListenConfig { components: ListenComponent::all_components(), transport: ListenTransport::Tcp { address: default_address(), port: default_port(), tls: false, proxy_protocol: false, }, }] } fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } fn default_port() -> u16 { 6167 } fn default_db_cache_capacity_mb() -> f64 { 300.0 } #[cfg(feature = "rocksdb")] fn default_rocksdb_max_open_files() -> i32 { 1000 } fn default_cleanup_second_interval() -> u32 { // every minute 60 } fn default_max_request_size() -> UInt { // Default to 20 MB (20_u32 * 1024 * 1024).into() } pub(crate) fn default_tracing_filter() -> EnvFilterClone { "info,ruma_state_res=warn" .parse() .expect("hardcoded env filter should be valid") } // I know, it's a great name pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } /// Search default locations for a configuration file /// /// If one isn't found, the list of tried paths is returned. fn search() -> Result { use error::ConfigSearch as Error; xdg::BaseDirectories::new()? .find_config_file(&*DEFAULT_PATH) .ok_or(Error::NotFound) } /// Load the configuration from the given path or XDG Base Directories pub(crate) async fn load

(path: Option

) -> Result where P: AsRef, { use error::Config as Error; let path = match path.as_ref().map(AsRef::as_ref) { Some(x) => Cow::Borrowed(x), None => Cow::Owned(search()?), }; let path = path.as_ref(); let config: Config = toml::from_str( &tokio::fs::read_to_string(path) .await .map_err(|e| Error::Read(e, path.to_owned()))?, ) .map_err(|e| Error::Parse(e, path.to_owned()))?; if config.registration_token.as_deref() == Some("") { return Err(Error::RegistrationTokenEmpty); } Ok(config) }