Merge branch 'registration-tokens' into 'main'

Draft: introduce temporary registration tokens

Closes #108

See merge request matrix/grapevine!181
This commit is contained in:
mikoto 2025-06-03 04:50:45 +00:00
commit a7a9b244f3
13 changed files with 312 additions and 23 deletions

41
Cargo.lock generated
View file

@ -56,6 +56,12 @@ dependencies = [
"password-hash", "password-hash",
] ]
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]] [[package]]
name = "as_variant" name = "as_variant"
version = "1.3.0" version = "1.3.0"
@ -407,6 +413,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "clang-sys" name = "clang-sys"
version = "1.8.1" version = "1.8.1"
@ -623,6 +638,20 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]]
name = "duration-str"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9add086174f60bcbcfde7175e71dcfd99da24dfd12f611d0faf74f4f26e15a06"
dependencies = [
"chrono",
"rust_decimal",
"serde",
"thiserror 2.0.12",
"time",
"winnow",
]
[[package]] [[package]]
name = "ed25519" name = "ed25519"
version = "2.2.3" version = "2.2.3"
@ -922,6 +951,7 @@ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"clap", "clap",
"duration-str",
"futures-util", "futures-util",
"hmac", "hmac",
"html-escape", "html-escape",
@ -964,6 +994,7 @@ dependencies = [
"thiserror 2.0.12", "thiserror 2.0.12",
"thread_local", "thread_local",
"tikv-jemallocator", "tikv-jemallocator",
"time",
"tokio", "tokio",
"toml", "toml",
"tower 0.5.2", "tower 0.5.2",
@ -2723,6 +2754,16 @@ dependencies = [
"rust-librocksdb-sys", "rust-librocksdb-sys",
] ]
[[package]]
name = "rust_decimal"
version = "1.37.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.24" version = "0.1.24"

View file

@ -96,6 +96,7 @@ axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-se
base64 = "0.22.1" base64 = "0.22.1"
bytes = "1.10.1" bytes = "1.10.1"
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
duration-str = "0.17.0"
futures-util = { version = "0.3.31", default-features = false } futures-util = { version = "0.3.31", default-features = false }
hmac = "0.12.1" hmac = "0.12.1"
html-escape = "0.2.13" html-escape = "0.2.13"
@ -131,6 +132,7 @@ serde_json = { version = "1.0.140", features = ["raw_value"] }
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
sha-1 = "0.10.1" sha-1 = "0.10.1"
strum = { version = "0.27.1", features = ["derive"] } strum = { version = "0.27.1", features = ["derive"] }
time = "0.3.41"
thiserror = "2.0.12" thiserror = "2.0.12"
thread_local = "1.1.8" thread_local = "1.1.8"
tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }

View file

@ -158,7 +158,7 @@ pub(crate) async fn register_route(
// UIAA // UIAA
let mut uiaainfo; let mut uiaainfo;
let skip_auth = if services().globals.config.registration_token.is_some() { let skip_auth = if services().globals.config.require_registration_token {
// Registration token required // Registration token required
uiaainfo = UiaaInfo { uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {

View file

@ -14,6 +14,7 @@ use ruma::{
}; };
use serde::Deserialize; use serde::Deserialize;
use strum::{Display, EnumIter, IntoEnumIterator}; use strum::{Display, EnumIter, IntoEnumIterator};
use tracing::warn;
use crate::{error, utils::partial_canonicalize}; use crate::{error, utils::partial_canonicalize};
@ -57,6 +58,8 @@ pub(crate) struct Config {
pub(crate) allow_registration: bool, pub(crate) allow_registration: bool,
pub(crate) registration_token: Option<String>, pub(crate) registration_token: Option<String>,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub(crate) require_registration_token: bool,
#[serde(default = "true_fn")]
pub(crate) allow_encryption: bool, pub(crate) allow_encryption: bool,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub(crate) allow_room_creation: bool, pub(crate) allow_room_creation: bool,
@ -557,8 +560,11 @@ where
) )
.map_err(|e| Error::Parse(e, path.to_owned()))?; .map_err(|e| Error::Parse(e, path.to_owned()))?;
if config.registration_token.as_deref() == Some("") { if config.registration_token.is_some() {
return Err(Error::RegistrationTokenEmpty); warn!(
"configuration registration token is no longer supported, use the \
admin room to generate one"
);
} }
match &config.media.backend { match &config.media.backend {

View file

@ -35,6 +35,7 @@ pub(crate) struct KeyValueDatabase {
// Trees "owned" by `self::key_value::globals` // Trees "owned" by `self::key_value::globals`
pub(super) global: Arc<dyn KvTree>, pub(super) global: Arc<dyn KvTree>,
pub(super) server_signingkeys: Arc<dyn KvTree>, pub(super) server_signingkeys: Arc<dyn KvTree>,
pub(super) registration_tokens: Arc<dyn KvTree>,
// Trees "owned" by `self::key_value::users` // Trees "owned" by `self::key_value::users`
pub(super) userid_password: Arc<dyn KvTree>, pub(super) userid_password: Arc<dyn KvTree>,
@ -445,6 +446,7 @@ impl KeyValueDatabase {
senderkey_pusher: builder.open_tree("senderkey_pusher")?, senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?, global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
registration_tokens: builder.open_tree("registration_tokens")?,
}; };
Ok(db) Ok(db)

View file

@ -3,7 +3,7 @@ use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::federation::discovery::{OldVerifyKey, ServerSigningKeys}, api::federation::discovery::{OldVerifyKey, ServerSigningKeys},
signatures::Ed25519KeyPair, signatures::Ed25519KeyPair,
DeviceId, ServerName, UserId, DeviceId, MilliSecondsSinceUnixEpoch, ServerName, UserId,
}; };
use crate::{ use crate::{
@ -279,6 +279,102 @@ impl service::globals::Data for KeyValueDatabase {
Ok(signingkeys) Ok(signingkeys)
} }
fn create_registration_token(
&self,
registration_token: String,
uses: Option<std::num::NonZeroU64>,
expiry_ts: Option<MilliSecondsSinceUnixEpoch>,
) -> Result<String> {
let mut value = Vec::<u8>::with_capacity(16);
if let Some(expiry_ts) = expiry_ts {
value.extend_from_slice(&u64::from(expiry_ts.get()).to_be_bytes());
}
if let Some(uses) = uses {
value.extend_from_slice(&uses.get().to_be_bytes());
}
self.registration_tokens
.insert(registration_token.as_bytes(), &value)?;
Ok(registration_token)
}
#[allow(clippy::similar_names)]
fn take_registration_token(
&self,
registration_token: &str,
) -> Result<
Option<(
Option<std::num::NonZeroU64>,
Option<MilliSecondsSinceUnixEpoch>,
)>,
> {
let Some(value) =
self.registration_tokens.get(registration_token.as_bytes())?
else {
return Ok(None);
};
let (expiry_ts, uses) = match value.split_at_checked(8) {
Some((expiry_ts, uses)) => {
let expiry_ts = MilliSecondsSinceUnixEpoch(
ruma::UInt::try_from(u64::from_be_bytes(
expiry_ts.try_into().unwrap(),
))
.unwrap(),
);
if uses.is_empty() {
// unlimited uses, limited validity
(Some(expiry_ts), None)
} else {
let uses = std::num::NonZeroU64::new(u64::from_be_bytes(
uses.try_into().unwrap(),
))
.unwrap();
// limited uses, limited validity
(Some(expiry_ts), Some(uses))
}
}
// unlimited uses, unlimited validity
None => (None, None),
};
self.registration_tokens.remove(registration_token.as_bytes())?;
// TODO: consider tokens expiring in less than X minutes invalid?
let expired = expiry_ts.is_some_and(|expiry_ts| {
MilliSecondsSinceUnixEpoch::now() < expiry_ts
});
let used = uses.is_some_and(|uses| uses.get() < 2);
match (used, expired) {
(false, false) => {
self.create_registration_token(
registration_token.to_owned(),
uses.map(|uses| {
std::num::NonZeroU64::new(uses.get() - 1).unwrap()
}),
expiry_ts,
)?;
Ok(Some((uses, expiry_ts)))
}
_ => Ok(None),
}
}
fn revoke_registration_token(
&self,
registration_token: &str,
) -> Result<()> {
self.registration_tokens.remove(registration_token.as_bytes())
}
fn database_version(&self) -> Result<u64> { fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| { self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| { utils::u64_from_bytes(&version).map_err(|_| {

View file

@ -137,9 +137,6 @@ pub(crate) enum Config {
#[error("failed to canonicalize path {}", .1.display())] #[error("failed to canonicalize path {}", .1.display())]
Canonicalize(#[source] std::io::Error, PathBuf), Canonicalize(#[source] std::io::Error, PathBuf),
#[error("registration token must not be empty")]
RegistrationTokenEmpty,
#[error("database and media paths overlap")] #[error("database and media paths overlap")]
DatabaseMediaOverlap, DatabaseMediaOverlap,
} }

View file

@ -1,4 +1,10 @@
use std::{collections::BTreeMap, fmt::Write, sync::Arc, time::Instant}; use std::{
collections::BTreeMap,
fmt::Write,
num::NonZeroU64,
sync::Arc,
time::{Instant, SystemTime},
};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use regex::Regex; use regex::Regex;
@ -28,6 +34,7 @@ use ruma::{
OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use time::{macros::format_description, OffsetDateTime};
use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::warn; use tracing::warn;
@ -210,6 +217,19 @@ enum AdminCommand {
#[command(subcommand)] #[command(subcommand)]
cmd: TracingFilterCommand, cmd: TracingFilterCommand,
}, },
/// Generate n-use registration token
GenRegistrationToken {
#[arg(value_parser = parse_non_zero_u64)]
uses: Option<NonZeroU64>,
#[arg(value_parser = parse_expiry)]
expiry_ts: Option<MilliSecondsSinceUnixEpoch>,
},
/// Revoke registration token
RevokeRegistrationToken {
registration_token: String,
},
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -261,6 +281,47 @@ enum TracingBackend {
Traces, Traces,
} }
fn parse_expiry(input: &str) -> Result<MilliSecondsSinceUnixEpoch, String> {
if let Ok(duration) = duration_str::parse(input) {
let time = SystemTime::now()
.checked_add(duration)
.ok_or_else(|| "Duration is too big".to_owned())?;
return MilliSecondsSinceUnixEpoch::from_system_time(time)
.ok_or_else(|| "Expiry too large to be represented".to_owned());
}
if let Ok(time) = OffsetDateTime::parse(
input,
&format_description!("[year]-[month]-[day]"),
)
.or_else(|_| {
OffsetDateTime::parse(
input,
&format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second]"
),
)
}) {
return MilliSecondsSinceUnixEpoch::from_system_time(time.into())
.ok_or_else(|| "Expiry too large to be represented".to_owned());
}
Err("Could not parse expiry".to_owned())
}
fn parse_non_zero_u64(
input: &str,
) -> Result<MilliSecondsSinceUnixEpoch, String> {
let n =
NonZeroU64::new(input.parse().map_err(|_| "Invalid amount of uses")?)
.ok_or_else(|| "Amount of uses cannot be zero".to_owned())?;
ruma::UInt::new(n.get())
.map(MilliSecondsSinceUnixEpoch)
.ok_or_else(|| "UInt overflow".to_owned())
}
impl Service { impl Service {
pub(crate) fn new() -> Arc<Self> { pub(crate) fn new() -> Arc<Self> {
let (sender, receiver) = mpsc::unbounded_channel(); let (sender, receiver) = mpsc::unbounded_channel();
@ -1240,6 +1301,44 @@ impl Service {
"Filter reloaded", "Filter reloaded",
)); ));
} }
AdminCommand::GenRegistrationToken {
uses,
expiry_ts,
} => {
let uses_fmt = uses.as_ref().map_or_else(
|| "unlimited".to_owned(),
NonZeroU64::to_string,
);
let expiry_ts_fmt = expiry_ts
.as_ref()
.map_or_else(|| "never".to_owned(), |ts| format!("{ts:?}"));
// TODO: hash registration tokens?
let registration_token =
services().globals.create_registration_token(
utils::random_string(32),
uses,
expiry_ts,
)?;
// TODO: use matrix spoiler here?
RoomMessageEventContent::text_plain(format!(
"token: {registration_token} | uses: {uses_fmt} | expiry \
timestamp: {expiry_ts_fmt}"
))
}
AdminCommand::RevokeRegistrationToken {
registration_token,
} => {
services()
.globals
.revoke_registration_token(&registration_token)?;
RoomMessageEventContent::text_plain(
"Successfully revoked registration token.",
)
}
}; };
Ok(reply_message_content) Ok(reply_message_content)

View file

@ -565,6 +565,34 @@ impl Service {
} }
} }
pub(crate) fn create_registration_token(
&self,
registration_token: String,
uses: Option<std::num::NonZeroU64>,
expiry_ts: Option<MilliSecondsSinceUnixEpoch>,
) -> Result<String> {
self.db.create_registration_token(registration_token, uses, expiry_ts)
}
pub(crate) fn take_registration_token(
&self,
registration_token: &str,
) -> Result<
Option<(
Option<std::num::NonZeroU64>,
Option<MilliSecondsSinceUnixEpoch>,
)>,
> {
self.db.take_registration_token(registration_token)
}
pub(crate) fn revoke_registration_token(
&self,
registration_token: &str,
) -> Result<()> {
self.db.revoke_registration_token(registration_token)
}
/// Filters the key map of multiple servers down to keys that should be /// Filters the key map of multiple servers down to keys that should be
/// accepted given the expiry time, room version, and timestamp of the /// accepted given the expiry time, room version, and timestamp of the
/// paramters /// paramters

View file

@ -115,6 +115,23 @@ pub(crate) trait Data: Send + Sync {
&self, &self,
origin: &ServerName, origin: &ServerName,
) -> Result<Option<SigningKeys>>; ) -> Result<Option<SigningKeys>>;
fn create_registration_token(
&self,
registration_token: String,
uses: Option<std::num::NonZeroU64>,
expiry_ts: Option<MilliSecondsSinceUnixEpoch>,
) -> Result<String>;
fn take_registration_token(
&self,
registration_token: &str,
) -> Result<
Option<(
Option<std::num::NonZeroU64>,
Option<MilliSecondsSinceUnixEpoch>,
)>,
>;
fn revoke_registration_token(&self, registration_token: &str)
-> Result<()>;
fn database_version(&self) -> Result<u64>; fn database_version(&self) -> Result<u64>;
fn bump_database_version(&self, new_version: u64) -> Result<()>; fn bump_database_version(&self, new_version: u64) -> Result<()>;
} }

View file

@ -120,19 +120,21 @@ impl Service {
// Password was correct! Let's add it to `completed` // Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password); uiaainfo.completed.push(AuthType::Password);
} }
AuthData::RegistrationToken(t) => { AuthData::RegistrationToken(t)
if Some(t.token.trim()) if services()
== services().globals.config.registration_token.as_deref() .globals
{ .take_registration_token(&t.token)?
uiaainfo.completed.push(AuthType::RegistrationToken); .is_some() =>
} else { {
uiaainfo.auth_error = uiaainfo.completed.push(AuthType::RegistrationToken);
Some(ruma::api::client::error::StandardErrorBody { }
kind: ErrorKind::forbidden(), AuthData::RegistrationToken(_) => {
message: "Invalid registration token.".to_owned(), uiaainfo.auth_error =
}); Some(ruma::api::client::error::StandardErrorBody {
return Ok((false, uiaainfo)); kind: ErrorKind::forbidden(),
} message: "Invalid registration token.".to_owned(),
});
return Ok((false, uiaainfo));
} }
AuthData::Dummy(_) => { AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy); uiaainfo.completed.push(AuthType::Dummy);

View file

@ -1,7 +1,6 @@
--- ---
source: tests/integrations/check_config.rs source: tests/integrations/check_config.rs
description: A config with invalid keys fails description: A config with invalid keys fails
snapshot_kind: text
--- ---
Some( Some(
1, 1,

View file

@ -8,4 +8,4 @@ Error: failed to validate configuration
| |
1 | some_name = "example.com" 1 | some_name = "example.com"
| ^^^^^^^^^ | ^^^^^^^^^
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password` unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `require_registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password`