diff --git a/Cargo.lock b/Cargo.lock index a3dd7de4..edd472e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -521,6 +521,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.6" @@ -558,6 +564,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.8.0" @@ -922,6 +942,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "dashmap", "futures-util", "hmac", "html-escape", @@ -945,6 +966,7 @@ dependencies = [ "predicates", "prometheus", "proxy-header", + "quanta", "rand 0.8.5", "regex", "reqwest", @@ -1000,6 +1022,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -2226,6 +2254,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "quanta" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.7" @@ -2355,6 +2398,15 @@ dependencies = [ "getrandom 0.3.2", ] +[[package]] +name = "raw-cpuid" +version = "11.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" +dependencies = [ + "bitflags 2.9.0", +] + [[package]] name = "redox_syscall" version = "0.5.10" diff --git a/Cargo.toml b/Cargo.toml index 6c61f19e..397976d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,6 +96,7 @@ axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-se base64 = "0.22.1" bytes = "1.10.1" clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } +dashmap = "6.1.0" futures-util = { version = "0.3.31", default-features = false } hmac = "0.12.1" html-escape = "0.2.13" @@ -116,6 +117,7 @@ phf = { version = "0.11.3", features = ["macros"] } pin-project-lite = "0.2.16" prometheus = "0.13.4" proxy-header = { version = "0.1.2", features = ["tokio"] } +quanta = "0.12.5" rand = "0.8.5" regex = "1.11.1" reqwest = { version = "0.12.15", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index a514c6b8..69bde2b2 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,8 @@ -use std::{collections::BTreeMap, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + iter::FromIterator, + str::{self, FromStr as _}, +}; use axum::{ async_trait, @@ -17,8 +21,8 @@ use http::{Request, StatusCode}; use http_body_util::BodyExt; use ruma::{ api::{ - client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata, - OutgoingResponse, + client::error::{ErrorKind, RetryAfter}, + AuthScheme, IncomingRequest, Metadata, OutgoingResponse, }, server_util::authorization::XMatrix, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, @@ -28,7 +32,13 @@ use serde::Deserialize; use tracing::{error, warn}; use super::{Ar, Ra}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{ + appservice::RegistrationInfo, + rate_limiting::{Target, XForwardedFor}, + }, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -51,10 +61,13 @@ struct ArPieces { /// Non-generic part of [`Ar::from_request()`]. Splitting this out reduces /// binary size by ~10%. #[allow(clippy::too_many_lines)] -async fn ar_from_request_inner( +async fn ar_from_request_inner( req: axum::extract::Request, metadata: Metadata, -) -> Result { +) -> Result +where + T: IncomingRequest, +{ #[derive(Deserialize)] struct QueryParams { access_token: Option, @@ -350,6 +363,45 @@ async fn ar_from_request_inner( } }; + let mut target = + Err(Error::BadRequest(ErrorKind::forbidden(), "Missing identifier.")); + + if let Some(servername) = sender_servername.as_ref() { + target = Ok(Target::Server(servername.to_owned())); + } + + if let Some(user) = sender_user.as_ref() { + target = Ok(Target::User(user.to_owned())); + } + + if let Some(info) = appservice_info.as_ref() { + if info.registration.rate_limited.unwrap_or(true) { + target = Ok(Target::Appservice(info.registration.id.clone())); + } + } + + if let Some(header) = parts.headers.get("X-Forwarded-For") { + let value = header.to_str().ok(); + + let x_forwarded_for = + value.map(XForwardedFor::from_str).and_then(Result::ok); + + if let Some(ip) = x_forwarded_for.and_then(|value| value.ip) { + target = Ok(Target::Ip(ip)); + } + } + + if let Err(retry_after) = + { services().rate_limiting.update_or_reject::(target?) } + { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { + retry_after: Some(RetryAfter::Delay(retry_after)), + }, + "Rate limit exceeded.", + )); + } + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; @@ -411,7 +463,7 @@ where req: axum::extract::Request, _state: &S, ) -> Result { - let pieces = ar_from_request_inner(req, T::METADATA).await?; + let pieces = ar_from_request_inner::(req, T::METADATA).await?; let body = T::try_from_http_request(pieces.http_request, &pieces.path_params) diff --git a/src/config.rs b/src/config.rs index b597e4a9..31654211 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use std::{ borrow::Cow, - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt::{self, Display}, net::{IpAddr, Ipv4Addr}, path::{Path, PathBuf}, @@ -19,6 +19,7 @@ use crate::{error, utils::partial_canonicalize}; mod env_filter_clone; mod proxy; +pub(crate) mod rate_limiting; pub(crate) use env_filter_clone::EnvFilterClone; use proxy::ProxyConfig; @@ -69,6 +70,9 @@ pub(crate) struct Config { pub(crate) observability: ObservabilityConfig, #[serde(default)] pub(crate) turn: TurnConfig, + #[serde(default = "rate_limiting::default_rate_limit")] + pub(crate) rate_limiting: + HashMap, pub(crate) emergency_password: Option, } diff --git a/src/config/rate_limiting.rs b/src/config/rate_limiting.rs new file mode 100644 index 00000000..4209a48c --- /dev/null +++ b/src/config/rate_limiting.rs @@ -0,0 +1,93 @@ +use std::{collections::HashMap, num::NonZeroU64}; + +use serde::Deserialize; + +#[derive( + Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, +)] +#[serde(rename_all = "snake_case")] +pub(crate) enum Endpoint { + Registration, + Login, + RegistrationTokenValidity, + Message, + Join, + Invite, + Knock, + CreateMedia, + Transaction, + FederatedJoin, + FederatedInvite, + FederatedKnock, + + #[default] + Global, + + #[serde(untagged)] + Custom(String), +} + +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct Config { + #[serde(default = "default_non_zero", flatten)] + pub(crate) timeframe: Timeframe, + #[serde(default = "default_non_zero")] + pub(crate) burst_capacity: NonZeroU64, +} + +impl Config { + pub(crate) fn delay_tolerance(&self) -> u64 { + self.timeframe.emission_interval_ns() * self.burst_capacity.get() + } +} + +impl Default for Config { + fn default() -> Self { + Config { + timeframe: Timeframe::PerSecond(NonZeroU64::MIN), + burst_capacity: NonZeroU64::MIN, + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +#[allow(clippy::enum_variant_names)] +pub(crate) enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub(crate) fn time_window_ms(&self) -> u64 { + use Timeframe::{PerDay, PerHour, PerMinute, PerSecond}; + + let s = match self { + PerSecond(_) => 1_u64, + PerMinute(_) => 60, + PerHour(_) => 60 * 60, + PerDay(_) => 60 * 60 * 24, + }; + + s.checked_mul(1_000_000_000).expect("time window overflow") + } + + pub(crate) fn emission_interval_ns(&self) -> u64 { + use Timeframe::{PerDay, PerHour, PerMinute, PerSecond}; + + let tf = match self { + PerSecond(tf) | PerMinute(tf) | PerHour(tf) | PerDay(tf) => tf, + }; + + self.time_window_ms() / tf.get() + } +} + +fn default_non_zero() -> NonZeroU64 { + NonZeroU64::MIN +} +pub(super) fn default_rate_limit() -> HashMap { + HashMap::from_iter([(Endpoint::default(), Config::default())]) +} diff --git a/src/service.rs b/src/service.rs index 38d44941..12f6281b 100644 --- a/src/service.rs +++ b/src/service.rs @@ -10,6 +10,7 @@ pub(crate) mod key_backups; pub(crate) mod media; pub(crate) mod pdu; pub(crate) mod pusher; +pub(crate) mod rate_limiting; pub(crate) mod rooms; pub(crate) mod sending; pub(crate) mod transaction_ids; @@ -36,6 +37,7 @@ pub(crate) struct Services { pub(crate) key_backups: key_backups::Service, pub(crate) media: media::Service, pub(crate) sending: Arc, + pub(crate) rate_limiting: rate_limiting::Service, } impl Services { @@ -121,6 +123,10 @@ impl Services { db, }, sending: sending::Service::new(db, &config), + rate_limiting: rate_limiting::Service::build( + services().globals.config.rate_limiting.clone(), + quanta::Clock::new(), + ), globals: globals::Service::new(db, config, reload_handles)?, }) diff --git a/src/service/rate_limiting.rs b/src/service/rate_limiting.rs new file mode 100644 index 00000000..1b7cad2d --- /dev/null +++ b/src/service/rate_limiting.rs @@ -0,0 +1,413 @@ +use std::{ + collections::HashMap, + hash::Hash, + net::{AddrParseError, IpAddr}, + str::FromStr, + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; + +use dashmap::DashMap; +use quanta::Clock; +use ruma::{ + api::{IncomingRequest, Metadata}, + OwnedServerName, OwnedUserId, +}; + +use crate::{ + config::rate_limiting::{self, Endpoint}, + Result, +}; + +pub(crate) struct Service { + clock: Clock, + store: DashMap<(Target, Endpoint), AtomicU64>, + config: HashMap, +} + +impl Service { + pub(crate) fn build( + config: HashMap, + clock: Clock, + ) -> Self { + Self { + clock, + store: DashMap::new(), + config, + } + } + + pub(crate) fn update_or_reject( + &self, + target: Target, + ) -> Result<(), Duration> { + let now_ns = self.clock.delta_as_nanos(0, self.clock.raw()); + + let endpoint = Endpoint::from(IR::METADATA); + + let config = self.config.get(&endpoint).cloned().unwrap_or_default(); + + let (emission_interval_ns, delay_tolerance) = + (config.timeframe.emission_interval_ns(), config.delay_tolerance()); + + let entry = self.store.entry((target, endpoint)).or_insert_with(|| { + AtomicU64::new(now_ns.saturating_sub(delay_tolerance)) + }); + + let theoretical_arrival_time = entry.load(Ordering::Acquire); + + let new_theoretical_arrival_time = + theoretical_arrival_time.max(now_ns) + emission_interval_ns; + + if (new_theoretical_arrival_time - now_ns) <= delay_tolerance { + entry.store(new_theoretical_arrival_time, Ordering::Release); + + Ok(()) + } else { + Err(Duration::from_nanos( + theoretical_arrival_time.saturating_sub(now_ns), + )) + } + } +} + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub(crate) enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice(String), + Ip(IpAddr), +} + +#[derive(Debug, PartialEq)] +pub(crate) struct XForwardedFor { + pub(crate) ip: Option, + pub(crate) proxies: Vec, +} + +impl FromStr for XForwardedFor { + type Err = AddrParseError; + + fn from_str(s: &str) -> std::result::Result { + let ips = s.split(',').try_fold(Vec::new(), |mut ips, part| { + let trimmed = part.trim(); + + if trimmed.is_empty() { + Ok(ips) + } else { + match IpAddr::from_str(trimmed) { + Ok(ip) => { + ips.push(ip); + + Ok(ips) + } + Err(e) => Err(e), + } + } + })?; + + let (ip, proxies) = if ips.is_empty() { + (None, Vec::new()) + } else { + (Some(ips[0]), ips[1..].to_vec()) + }; + + Ok(XForwardedFor { + ip, + proxies, + }) + } +} + +impl From for Endpoint { + fn from(metadata: Metadata) -> Self { + use ruma::api::{ + client::{ + account::{check_registration_token_validity, register}, + knock::knock_room, + media::{create_content, create_content_async}, + membership::{ + invite_user, join_room_by_id, join_room_by_id_or_alias, + }, + message::send_message_event, + session::login, + state::send_state_event, + }, + federation::{ + knock::send_knock, + membership::{create_invite, create_join_event}, + transactions::send_transaction_message, + }, + }; + + #[allow(deprecated)] + match metadata { + register::v3::Request::METADATA => Endpoint::Registration, + login::v3::Request::METADATA => Endpoint::Login, + check_registration_token_validity::v1::Request::METADATA => { + Endpoint::RegistrationTokenValidity + } + send_message_event::v3::Request::METADATA + | send_state_event::v3::Request::METADATA => Endpoint::Message, + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Endpoint::Join, + invite_user::v3::Request::METADATA => Endpoint::Invite, + create_content::v3::Request::METADATA + | create_content_async::v3::Request::METADATA => { + Endpoint::CreateMedia + } + send_transaction_message::v1::Request::METADATA => { + Endpoint::Transaction + } + create_join_event::v1::Request::METADATA + | create_join_event::v2::Request::METADATA => { + Endpoint::FederatedJoin + } + create_invite::v1::Request::METADATA + | create_invite::v2::Request::METADATA => Endpoint::FederatedInvite, + send_knock::v1::Request::METADATA => Endpoint::FederatedKnock, + knock_room::v3::Request::METADATA => Endpoint::Knock, + _ => Self::default(), + } + } +} + +#[cfg(test)] +mod tests { + use std::{net::Ipv4Addr, num::NonZeroU64}; + + use quanta::Clock; + use ruma::api::client::{account::register, session::login}; + + use super::*; + use crate::config::rate_limiting::{Config, Timeframe}; + + struct LoginRequest; + + impl IncomingRequest for LoginRequest { + type EndpointError = + ::EndpointError; + type OutgoingResponse = + ::OutgoingResponse; + + const METADATA: Metadata = + ::METADATA; + + fn try_from_http_request( + _: http::Request, + _: &[S], + ) -> std::result::Result + where + B: AsRef<[u8]>, + S: AsRef, + { + unimplemented!() + } + } + + struct RegisterRequest; + + impl IncomingRequest for RegisterRequest { + type EndpointError = + ::EndpointError; + type OutgoingResponse = + ::OutgoingResponse; + + const METADATA: Metadata = + ::METADATA; + + fn try_from_http_request( + _: http::Request, + _: &[S], + ) -> std::result::Result + where + B: AsRef<[u8]>, + S: AsRef, + { + unimplemented!() + } + } + + fn test_config(emission_interval_ns: u64, burst_capacity: u64) -> Config { + Config { + // Directly use nanoseconds without conversion + timeframe: Timeframe::PerSecond( + NonZeroU64::new(1_000_000_000 / emission_interval_ns).unwrap(), + ), + burst_capacity: NonZeroU64::new(burst_capacity).unwrap(), + } + } + + #[test] + fn test_basic_rate_limiting() { + let (clock, mock) = Clock::mock(); + + let config = test_config(1_000_000_000, 1); + + let service = Service::build( + HashMap::from_iter([(Endpoint::Login, config)]), + clock, + ); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + service.update_or_reject::(target.clone()).unwrap(); + + assert!(service + .update_or_reject::(target.clone()) + .is_err()); + + mock.increment(1_000_000_000); + + service.update_or_reject::(target).unwrap(); + } + + #[test] + fn test_burst_capacity() { + let (clock, mock) = Clock::mock(); + + let config = test_config(1_000_000_000, 3); + + let service = Service::build( + HashMap::from_iter([(Endpoint::Login, config)]), + clock, + ); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + service.update_or_reject::(target.clone()).unwrap(); + service.update_or_reject::(target.clone()).unwrap(); + service.update_or_reject::(target.clone()).unwrap(); + + assert!(service + .update_or_reject::(target.clone()) + .is_err()); + + mock.increment(1_000_000_000); + + service.update_or_reject::(target).unwrap(); + } + + #[test] + fn test_different_target_types() { + let config = test_config(1_000_000_000, 1); + + let service = Service::build( + HashMap::from_iter([(Endpoint::Login, config)]), + Clock::new(), + ); + + let targets = vec![ + Target::User(OwnedUserId::try_from("@user:example.com").unwrap()), + Target::Server(OwnedServerName::try_from("example.com").unwrap()), + Target::Appservice("my_appservice".to_owned()), + Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), + ]; + + for target in targets { + service.update_or_reject::(target.clone()).unwrap(); + + assert!(service.update_or_reject::(target).is_err()); + } + } + + #[test] + fn test_different_endpoints() { + let config = test_config(1_000_000_000, 1); + + let service = Service::build( + HashMap::from_iter([ + (Endpoint::Login, config.clone()), + (Endpoint::Registration, config), + ]), + Clock::new(), + ); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + service.update_or_reject::(target.clone()).unwrap(); + + service.update_or_reject::(target.clone()).unwrap(); + + assert!(service + .update_or_reject::(target.clone()) + .is_err()); + + assert!(service.update_or_reject::(target).is_err()); + } + + #[test] + fn test_default_config() { + let (clock, mock) = Clock::mock(); + + let service = Service::build(HashMap::new(), clock); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + service.update_or_reject::(target.clone()).unwrap(); + + assert!(service + .update_or_reject::(target.clone()) + .is_err()); + + mock.increment(1_000_000_000); + + service.update_or_reject::(target).unwrap(); + } + + #[test] + fn test_very_high_rate() { + let (clock, mock) = Clock::mock(); + + let config = test_config(1_000, 10); + + let service = Service::build( + HashMap::from_iter([(Endpoint::Login, config)]), + clock, + ); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + for _ in 0..10 { + service.update_or_reject::(target.clone()).unwrap(); + } + + let value = mock.value(); + mock.decrement(value); + + assert!(service.update_or_reject::(target).is_err()); + } + + #[test] + fn test_returned_delay() { + let (clock, mock) = Clock::mock(); + + let config = test_config(1_000_000_000, 1); + + let service = Service::build( + HashMap::from_iter([(Endpoint::Login, config)]), + clock, + ); + + let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); + + service.update_or_reject::(target.clone()).unwrap(); + + if let Err(delay) = + service.update_or_reject::(target.clone()) + { + assert_eq!(delay, Duration::from_secs(1)); + } else { + panic!("Expected rate limit error"); + } + + mock.increment(500_000_000); + + if let Err(delay) = service.update_or_reject::(target) { + assert_eq!(delay, Duration::from_millis(500)); + } else { + panic!("Expected rate limit error"); + } + } +} diff --git a/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap b/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap index 83cd650b..0b645eae 100644 --- a/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap +++ b/tests/integrations/snapshots/integrations__check_config__invalid_keys@stderr.snap @@ -8,4 +8,4 @@ Error: failed to validate configuration | 1 | some_name = "example.com" | ^^^^^^^^^ -unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password` +unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `rate_limiting`, `emergency_password`