use std::{ borrow::Cow, collections::{BTreeMap, HashSet}, fmt::{self, Display}, net::{IpAddr, Ipv4Addr}, path::{Path, PathBuf}, }; use once_cell::sync::Lazy; use reqwest::Url; use ruma::{ api::federation::discovery::OldVerifyKey, OwnedServerName, OwnedServerSigningKeyId, RoomVersionId, }; use serde::Deserialize; use strum::{Display, EnumIter, IntoEnumIterator}; use crate::error; mod env_filter_clone; mod proxy; pub(crate) use env_filter_clone::EnvFilterClone; use proxy::ProxyConfig; /// The default configuration file path pub(crate) static DEFAULT_PATH: Lazy = Lazy::new(|| [env!("CARGO_PKG_NAME"), "config.toml"].iter().collect()); #[allow(clippy::struct_excessive_bools)] #[derive(Debug, Deserialize)] pub(crate) struct Config { #[serde(default = "false_fn")] pub(crate) conduit_compat: bool, #[serde(default = "default_listen")] pub(crate) listen: Vec, pub(crate) tls: Option, /// The name of this homeserver /// /// This is the value that will appear e.g. in user IDs and room aliases. pub(crate) server_name: OwnedServerName, pub(crate) server_discovery: ServerDiscovery, pub(crate) database: DatabaseConfig, #[serde(default)] pub(crate) federation: FederationConfig, #[serde(default = "default_cache_capacity_modifier")] pub(crate) cache_capacity_modifier: f64, #[serde(default = "default_pdu_cache_capacity")] pub(crate) pdu_cache_capacity: usize, #[serde(default = "default_cleanup_second_interval")] pub(crate) cleanup_second_interval: u32, #[serde(default = "default_max_request_size")] pub(crate) max_request_size: u32, #[serde(default = "false_fn")] pub(crate) allow_registration: bool, pub(crate) registration_token: Option, #[serde(default = "true_fn")] pub(crate) allow_encryption: bool, #[serde(default = "true_fn")] pub(crate) allow_room_creation: bool, #[serde(default = "false_fn")] pub(crate) serve_media_unauthenticated: bool, #[serde(default = "default_default_room_version")] pub(crate) default_room_version: RoomVersionId, #[serde(default)] pub(crate) proxy: ProxyConfig, pub(crate) jwt_secret: Option, #[serde(default)] pub(crate) observability: ObservabilityConfig, #[serde(default)] pub(crate) turn: TurnConfig, pub(crate) emergency_password: Option, } #[derive(Debug, Deserialize)] pub(crate) struct ServerDiscovery { /// Server-server discovery configuration #[serde(default)] pub(crate) server: ServerServerDiscovery, /// Client-server discovery configuration pub(crate) client: ClientServerDiscovery, } /// Server-server discovery configuration #[derive(Debug, Default, Deserialize)] pub(crate) struct ServerServerDiscovery { /// The alternative authority to make server-server API requests to pub(crate) authority: Option, } /// Client-server discovery configuration #[derive(Debug, Deserialize)] pub(crate) struct ClientServerDiscovery { /// The base URL to make client-server API requests to pub(crate) base_url: Url, #[serde(default, rename = "advertise_buggy_sliding_sync")] pub(crate) advertise_sliding_sync: bool, } #[derive(Debug, Deserialize)] pub(crate) struct TlsConfig { pub(crate) certs: String, pub(crate) key: String, } #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, EnumIter, Display, )] #[serde(rename_all = "snake_case")] #[strum(serialize_all = "snake_case")] pub(crate) enum ListenComponent { Client, Federation, Metrics, WellKnown, } impl ListenComponent { fn all_components() -> HashSet { Self::iter().collect() } } #[derive(Clone, Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub(crate) enum ListenTransport { Tcp { #[serde(default = "default_address")] address: IpAddr, #[serde(default = "default_port")] port: u16, #[serde(default = "false_fn")] tls: bool, }, } impl Display for ListenTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ListenTransport::Tcp { address, port, tls: false, } => write!(f, "http://{address}:{port}"), ListenTransport::Tcp { address, port, tls: true, } => write!(f, "https://{address}:{port}"), } } } #[derive(Clone, Debug, Deserialize)] pub(crate) struct ListenConfig { #[serde(default = "ListenComponent::all_components")] pub(crate) components: HashSet, #[serde(flatten)] pub(crate) transport: ListenTransport, } impl Display for ListenConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "{} ({})", self.transport, self.components .iter() .map(ListenComponent::to_string) .collect::>() .join(", ") ) } } #[derive(Copy, Clone, Default, Debug, Deserialize)] #[serde(rename_all = "snake_case")] pub(crate) enum LogFormat { /// Use the [`tracing_subscriber::fmt::format::Pretty`] formatter Pretty, /// Use the [`tracing_subscriber::fmt::format::Full`] formatter #[default] Full, /// Use the [`tracing_subscriber::fmt::format::Compact`] formatter Compact, /// Use the [`tracing_subscriber::fmt::format::Json`] formatter Json, } #[derive(Clone, Debug, Deserialize)] #[serde(default)] pub(crate) struct TurnConfig { pub(crate) username: String, pub(crate) password: String, pub(crate) uris: Vec, pub(crate) secret: String, pub(crate) ttl: u64, } impl Default for TurnConfig { fn default() -> Self { Self { username: String::new(), password: String::new(), uris: Vec::new(), secret: String::new(), ttl: 60 * 60 * 24, } } } #[cfg(feature = "rocksdb")] #[derive(Clone, Debug, Deserialize)] pub(crate) struct RocksdbConfig { pub(crate) path: PathBuf, #[serde(default = "default_rocksdb_max_open_files")] pub(crate) max_open_files: i32, #[serde(default = "default_rocksdb_cache_capacity_bytes")] pub(crate) cache_capacity_bytes: usize, } #[cfg(feature = "sqlite")] #[derive(Clone, Debug, Deserialize)] pub(crate) struct SqliteConfig { pub(crate) path: PathBuf, #[serde(default = "default_sqlite_cache_capacity_kilobytes")] pub(crate) cache_capacity_kilobytes: u32, } #[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "lowercase", tag = "backend")] pub(crate) enum DatabaseConfig { #[cfg(feature = "rocksdb")] Rocksdb(RocksdbConfig), #[cfg(feature = "sqlite")] Sqlite(SqliteConfig), } impl DatabaseConfig { pub(crate) fn path(&self) -> &Path { match self { #[cfg(feature = "rocksdb")] DatabaseConfig::Rocksdb(x) => &x.path, #[cfg(feature = "sqlite")] DatabaseConfig::Sqlite(x) => &x.path, } } } impl Display for DatabaseConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { #[cfg(feature = "rocksdb")] DatabaseConfig::Rocksdb(_) => write!(f, "RocksDB"), #[cfg(feature = "sqlite")] DatabaseConfig::Sqlite(_) => write!(f, "SQLite"), } } } #[derive(Clone, Debug, Default, Deserialize)] #[serde(default)] pub(crate) struct MetricsConfig { pub(crate) enable: bool, } #[derive(Debug, Deserialize)] #[serde(default)] pub(crate) struct OtelTraceConfig { pub(crate) enable: bool, pub(crate) filter: EnvFilterClone, pub(crate) endpoint: Option, pub(crate) service_name: String, } impl Default for OtelTraceConfig { fn default() -> Self { Self { enable: false, filter: default_tracing_filter(), endpoint: None, service_name: env!("CARGO_PKG_NAME").to_owned(), } } } #[derive(Debug, Deserialize)] #[serde(default)] pub(crate) struct FlameConfig { pub(crate) enable: bool, pub(crate) filter: EnvFilterClone, pub(crate) filename: String, } impl Default for FlameConfig { fn default() -> Self { Self { enable: false, filter: default_tracing_filter(), filename: "./tracing.folded".to_owned(), } } } #[derive(Debug, Deserialize)] #[serde(default)] pub(crate) struct LogConfig { pub(crate) filter: EnvFilterClone, pub(crate) colors: bool, pub(crate) format: LogFormat, pub(crate) timestamp: bool, } impl Default for LogConfig { fn default() -> Self { Self { filter: default_tracing_filter(), colors: true, format: LogFormat::default(), timestamp: true, } } } #[derive(Debug, Default, Deserialize)] #[serde(default)] pub(crate) struct ObservabilityConfig { /// Prometheus metrics pub(crate) metrics: MetricsConfig, /// OpenTelemetry traces pub(crate) traces: OtelTraceConfig, /// Folded inferno stack traces pub(crate) flame: FlameConfig, /// Logging to stdout pub(crate) logs: LogConfig, } #[derive(Debug, Deserialize)] #[serde(default)] pub(crate) struct FederationConfig { pub(crate) enable: bool, pub(crate) self_test: bool, pub(crate) trusted_servers: Vec, pub(crate) max_fetch_prev_events: u16, pub(crate) max_concurrent_requests: u16, pub(crate) old_verify_keys: BTreeMap, } impl Default for FederationConfig { fn default() -> Self { Self { enable: true, self_test: true, trusted_servers: vec![ OwnedServerName::try_from("matrix.org").unwrap() ], max_fetch_prev_events: 100, max_concurrent_requests: 100, old_verify_keys: BTreeMap::new(), } } } fn false_fn() -> bool { false } fn true_fn() -> bool { true } fn default_listen() -> Vec { vec![ListenConfig { components: ListenComponent::all_components(), transport: ListenTransport::Tcp { address: default_address(), port: default_port(), tls: false, }, }] } fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } fn default_port() -> u16 { 6167 } fn default_cache_capacity_modifier() -> f64 { 1.0 } #[cfg(feature = "rocksdb")] fn default_rocksdb_max_open_files() -> i32 { 1000 } #[cfg(feature = "rocksdb")] fn default_rocksdb_cache_capacity_bytes() -> usize { 300 * 1024 * 1024 } #[cfg(feature = "sqlite")] fn default_sqlite_cache_capacity_kilobytes() -> u32 { 300 * 1024 } fn default_pdu_cache_capacity() -> usize { 150_000 } fn default_cleanup_second_interval() -> u32 { // every minute 60 } fn default_max_request_size() -> u32 { // Default to 20 MB 20 * 1024 * 1024 } fn default_tracing_filter() -> EnvFilterClone { "info,ruma_state_res=warn" .parse() .expect("hardcoded env filter should be valid") } // I know, it's a great name pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } /// Search default locations for a configuration file /// /// If one isn't found, the list of tried paths is returned. fn search() -> Result { use error::ConfigSearch as Error; xdg::BaseDirectories::new()? .find_config_file(&*DEFAULT_PATH) .ok_or(Error::NotFound) } /// Load the configuration from the given path or XDG Base Directories pub(crate) async fn load

(path: Option

) -> Result where P: AsRef, { use error::Config as Error; let path = match path.as_ref().map(AsRef::as_ref) { Some(x) => Cow::Borrowed(x), None => Cow::Owned(search()?), }; let path = path.as_ref(); let config: Config = toml::from_str( &tokio::fs::read_to_string(path) .await .map_err(|e| Error::Read(e, path.to_owned()))?, ) .map_err(|e| Error::Parse(e, path.to_owned()))?; if config.registration_token.as_deref() == Some("") { return Err(Error::RegistrationTokenEmpty); } Ok(config) }