diff --git a/Cargo.lock b/Cargo.lock index d257d117..a1f17fd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,6 +56,12 @@ dependencies = [ "password-hash", ] +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "as_variant" version = "1.3.0" @@ -408,6 +414,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" +dependencies = [ + "num-traits", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -624,6 +639,20 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "duration-str" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9add086174f60bcbcfde7175e71dcfd99da24dfd12f611d0faf74f4f26e15a06" +dependencies = [ + "chrono", + "rust_decimal", + "serde", + "thiserror 2.0.12", + "time", + "winnow", +] + [[package]] name = "ed25519" version = "2.2.3" @@ -923,6 +952,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "duration-str", "futures-util", "hmac", "html-escape", @@ -965,6 +995,7 @@ dependencies = [ "thiserror 2.0.12", "thread_local", "tikv-jemallocator", + "time", "tokio", "toml", "tower 0.5.2", @@ -2734,6 +2765,16 @@ dependencies = [ "rust-librocksdb-sys", ] +[[package]] +name = "rust_decimal" +version = "1.37.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50" +dependencies = [ + "arrayvec", + "num-traits", +] + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/Cargo.toml b/Cargo.toml index 0cf41846..750021d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,6 +96,7 @@ axum-server = { version = "0.7.2", features = ["tls-rustls-no-provider"] } base64 = "0.22.1" bytes = "1.10.1" clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } +duration-str = "0.17.0" futures-util = { version = "0.3.31", default-features = false } hmac = "0.12.1" html-escape = "0.2.13" @@ -131,6 +132,7 @@ serde_json = { version = "1.0.140", features = ["raw_value"] } serde_yaml = "0.9.34" sha-1 = "0.10.1" strum = { version = "0.27.1", features = ["derive"] } +time = "0.3.41" thiserror = "2.0.12" thread_local = "1.1.8" tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index d660c6ba..2e45029a 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -158,7 +158,7 @@ pub(crate) async fn register_route( // UIAA let mut uiaainfo; - let skip_auth = if services().globals.config.registration_token.is_some() { + let skip_auth = if services().globals.config.require_registration_token { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { diff --git a/src/config.rs b/src/config.rs index 9ad6e1ec..3e3efd96 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,6 +14,7 @@ use ruma::{ }; use serde::Deserialize; use strum::{Display, EnumIter, IntoEnumIterator}; +use tracing::warn; use crate::{error, utils::partial_canonicalize}; @@ -57,6 +58,8 @@ pub(crate) struct Config { pub(crate) allow_registration: bool, pub(crate) registration_token: Option, #[serde(default = "true_fn")] + pub(crate) require_registration_token: bool, + #[serde(default = "true_fn")] pub(crate) allow_encryption: bool, #[serde(default = "true_fn")] pub(crate) allow_room_creation: bool, @@ -537,8 +540,11 @@ where ) .map_err(|e| Error::Parse(e, path.to_owned()))?; - if config.registration_token.as_deref() == Some("") { - return Err(Error::RegistrationTokenEmpty); + if config.registration_token.is_some() { + warn!( + "configuration registration token is no longer supported, use the \ + admin room to generate one" + ); } match &config.media.backend { diff --git a/src/database.rs b/src/database.rs index f7cb1352..f862ca31 100644 --- a/src/database.rs +++ b/src/database.rs @@ -36,6 +36,7 @@ pub(crate) struct KeyValueDatabase { // Trees "owned" by `self::key_value::globals` pub(super) global: Arc, pub(super) server_signingkeys: Arc, + pub(super) registration_tokens: Arc, // Trees "owned" by `self::key_value::users` pub(super) userid_password: Arc, @@ -446,6 +447,7 @@ impl KeyValueDatabase { senderkey_pusher: builder.open_tree("senderkey_pusher")?, global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, + registration_tokens: builder.open_tree("registration_tokens")?, }; Ok(db) diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index e3293571..88cfed3c 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -3,7 +3,7 @@ use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::discovery::{OldVerifyKey, ServerSigningKeys}, signatures::Ed25519KeyPair, - DeviceId, ServerName, UserId, + DeviceId, MilliSecondsSinceUnixEpoch, ServerName, UserId, }; use crate::{ @@ -279,6 +279,102 @@ impl service::globals::Data for KeyValueDatabase { Ok(signingkeys) } + fn create_registration_token( + &self, + registration_token: String, + uses: Option, + expiry_ts: Option, + ) -> Result { + let mut value = Vec::::with_capacity(16); + + if let Some(expiry_ts) = expiry_ts { + value.extend_from_slice(&u64::from(expiry_ts.get()).to_be_bytes()); + } + + if let Some(uses) = uses { + value.extend_from_slice(&uses.get().to_be_bytes()); + } + + self.registration_tokens + .insert(registration_token.as_bytes(), &value)?; + + Ok(registration_token) + } + + #[allow(clippy::similar_names)] + fn take_registration_token( + &self, + registration_token: &str, + ) -> Result< + Option<( + Option, + Option, + )>, + > { + let Some(value) = + self.registration_tokens.get(registration_token.as_bytes())? + else { + return Ok(None); + }; + + let (expiry_ts, uses) = match value.split_at_checked(8) { + Some((expiry_ts, uses)) => { + let expiry_ts = MilliSecondsSinceUnixEpoch( + ruma::UInt::try_from(u64::from_be_bytes( + expiry_ts.try_into().unwrap(), + )) + .unwrap(), + ); + + if uses.is_empty() { + // unlimited uses, limited validity + (Some(expiry_ts), None) + } else { + let uses = std::num::NonZeroU64::new(u64::from_be_bytes( + uses.try_into().unwrap(), + )) + .unwrap(); + + // limited uses, limited validity + (Some(expiry_ts), Some(uses)) + } + } + // unlimited uses, unlimited validity + None => (None, None), + }; + + self.registration_tokens.remove(registration_token.as_bytes())?; + + // TODO: consider tokens expiring in less than X minutes invalid? + let expired = expiry_ts.is_some_and(|expiry_ts| { + MilliSecondsSinceUnixEpoch::now() < expiry_ts + }); + + let used = uses.is_some_and(|uses| uses.get() < 2); + + match (used, expired) { + (false, false) => { + self.create_registration_token( + registration_token.to_owned(), + uses.map(|uses| { + std::num::NonZeroU64::new(uses.get() - 1).unwrap() + }), + expiry_ts, + )?; + + Ok(Some((uses, expiry_ts))) + } + _ => Ok(None), + } + } + + fn revoke_registration_token( + &self, + registration_token: &str, + ) -> Result<()> { + self.registration_tokens.remove(registration_token.as_bytes()) + } + fn database_version(&self) -> Result { self.global.get(b"version")?.map_or(Ok(0), |version| { utils::u64_from_bytes(&version).map_err(|_| { diff --git a/src/error.rs b/src/error.rs index 49cf7307..ff43cdfd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -137,9 +137,6 @@ pub(crate) enum Config { #[error("failed to canonicalize path {}", .1.display())] Canonicalize(#[source] std::io::Error, PathBuf), - #[error("registration token must not be empty")] - RegistrationTokenEmpty, - #[error("database and media paths overlap")] DatabaseMediaOverlap, } diff --git a/src/service/admin.rs b/src/service/admin.rs index 5e3ee770..50f42dc0 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -1,4 +1,10 @@ -use std::{collections::BTreeMap, fmt::Write, sync::Arc, time::Instant}; +use std::{ + collections::BTreeMap, + fmt::Write, + num::NonZeroU64, + sync::Arc, + time::{Instant, SystemTime}, +}; use clap::{Parser, Subcommand, ValueEnum}; use regex::Regex; @@ -27,6 +33,7 @@ use ruma::{ OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, }; use serde_json::value::to_raw_value; +use time::{macros::format_description, OffsetDateTime}; use tokio::sync::{mpsc, Mutex, RwLock}; use tracing::warn; @@ -209,6 +216,19 @@ enum AdminCommand { #[command(subcommand)] cmd: TracingFilterCommand, }, + + /// Generate n-use registration token + GenRegistrationToken { + #[arg(value_parser = parse_non_zero_u64)] + uses: Option, + #[arg(value_parser = parse_expiry)] + expiry_ts: Option, + }, + + /// Revoke registration token + RevokeRegistrationToken { + registration_token: String, + }, } #[derive(Debug, Subcommand)] @@ -260,6 +280,47 @@ enum TracingBackend { Traces, } +fn parse_expiry(input: &str) -> Result { + if let Ok(duration) = duration_str::parse(input) { + let time = SystemTime::now() + .checked_add(duration) + .ok_or_else(|| "Duration is too big".to_owned())?; + + return MilliSecondsSinceUnixEpoch::from_system_time(time) + .ok_or_else(|| "Expiry too large to be represented".to_owned()); + } + + if let Ok(time) = OffsetDateTime::parse( + input, + &format_description!("[year]-[month]-[day]"), + ) + .or_else(|_| { + OffsetDateTime::parse( + input, + &format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second]" + ), + ) + }) { + return MilliSecondsSinceUnixEpoch::from_system_time(time.into()) + .ok_or_else(|| "Expiry too large to be represented".to_owned()); + } + + Err("Could not parse expiry".to_owned()) +} + +fn parse_non_zero_u64( + input: &str, +) -> Result { + let n = + NonZeroU64::new(input.parse().map_err(|_| "Invalid amount of uses")?) + .ok_or_else(|| "Amount of uses cannot be zero".to_owned())?; + + ruma::UInt::new(n.get()) + .map(MilliSecondsSinceUnixEpoch) + .ok_or_else(|| "UInt overflow".to_owned()) +} + impl Service { pub(crate) fn new() -> Arc { let (sender, receiver) = mpsc::unbounded_channel(); @@ -1247,6 +1308,44 @@ impl Service { "Filter reloaded", )); } + AdminCommand::GenRegistrationToken { + uses, + expiry_ts, + } => { + let uses_fmt = uses.as_ref().map_or_else( + || "unlimited".to_owned(), + NonZeroU64::to_string, + ); + + let expiry_ts_fmt = expiry_ts + .as_ref() + .map_or_else(|| "never".to_owned(), |ts| format!("{ts:?}")); + + // TODO: hash registration tokens? + let registration_token = + services().globals.create_registration_token( + utils::random_string(32), + uses, + expiry_ts, + )?; + + // TODO: use matrix spoiler here? + RoomMessageEventContent::text_plain(format!( + "token: {registration_token} | uses: {uses_fmt} | expiry \ + timestamp: {expiry_ts_fmt}" + )) + } + AdminCommand::RevokeRegistrationToken { + registration_token, + } => { + services() + .globals + .revoke_registration_token(®istration_token)?; + + RoomMessageEventContent::text_plain( + "Successfully revoked registration token.", + ) + } }; Ok(reply_message_content) diff --git a/src/service/globals.rs b/src/service/globals.rs index ede9f5c6..2e5db687 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -570,6 +570,34 @@ impl Service { } } + pub(crate) fn create_registration_token( + &self, + registration_token: String, + uses: Option, + expiry_ts: Option, + ) -> Result { + self.db.create_registration_token(registration_token, uses, expiry_ts) + } + + pub(crate) fn take_registration_token( + &self, + registration_token: &str, + ) -> Result< + Option<( + Option, + Option, + )>, + > { + self.db.take_registration_token(registration_token) + } + + pub(crate) fn revoke_registration_token( + &self, + registration_token: &str, + ) -> Result<()> { + self.db.revoke_registration_token(registration_token) + } + /// Filters the key map of multiple servers down to keys that should be /// accepted given the expiry time, room version, and timestamp of the /// paramters diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 124e81b3..27bdc138 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -115,6 +115,23 @@ pub(crate) trait Data: Send + Sync { &self, origin: &ServerName, ) -> Result>; + fn create_registration_token( + &self, + registration_token: String, + uses: Option, + expiry_ts: Option, + ) -> Result; + fn take_registration_token( + &self, + registration_token: &str, + ) -> Result< + Option<( + Option, + Option, + )>, + >; + fn revoke_registration_token(&self, registration_token: &str) + -> Result<()>; fn database_version(&self) -> Result; fn bump_database_version(&self, new_version: u64) -> Result<()>; } diff --git a/src/service/uiaa.rs b/src/service/uiaa.rs index 566db934..9e46d992 100644 --- a/src/service/uiaa.rs +++ b/src/service/uiaa.rs @@ -120,19 +120,21 @@ impl Service { // Password was correct! Let's add it to `completed` uiaainfo.completed.push(AuthType::Password); } - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) - == services().globals.config.registration_token.as_deref() - { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { - uiaainfo.auth_error = - Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::forbidden(), - message: "Invalid registration token.".to_owned(), - }); - return Ok((false, uiaainfo)); - } + AuthData::RegistrationToken(t) + if services() + .globals + .take_registration_token(&t.token)? + .is_some() => + { + uiaainfo.completed.push(AuthType::RegistrationToken); + } + AuthData::RegistrationToken(_) => { + uiaainfo.auth_error = + Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); } AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); diff --git a/tests/integrations/snapshots/integrations__check_config__invalid_keys@status_code.snap b/tests/integrations/snapshots/integrations__check_config__invalid_keys@status_code.snap index 2278288a..568f467c 100644 --- a/tests/integrations/snapshots/integrations__check_config__invalid_keys@status_code.snap +++ b/tests/integrations/snapshots/integrations__check_config__invalid_keys@status_code.snap @@ -1,7 +1,6 @@ --- source: tests/integrations/check_config.rs description: A config with invalid keys fails -snapshot_kind: text --- Some( 1, diff --git a/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap b/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap index 83cd650b..1e828065 100644 --- a/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap +++ b/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap @@ -8,4 +8,4 @@ Error: failed to validate configuration | 1 | some_name = "example.com" | ^^^^^^^^^ -unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password` +unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `require_registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password`