Merge branch 'rate-limiting' into 'main'

Draft: Rate limiting

See merge request matrix/grapevine!183
This commit is contained in:
mikoto 2025-08-06 14:33:06 +00:00
commit 612a5b51cc
8 changed files with 631 additions and 9 deletions

52
Cargo.lock generated
View file

@ -521,6 +521,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crypto-common"
version = "0.1.6"
@ -558,6 +564,20 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.8.0"
@ -922,6 +942,7 @@ dependencies = [
"base64 0.22.1",
"bytes",
"clap",
"dashmap",
"futures-util",
"hmac",
"html-escape",
@ -945,6 +966,7 @@ dependencies = [
"predicates",
"prometheus",
"proxy-header",
"quanta",
"rand 0.8.5",
"regex",
"reqwest",
@ -1000,6 +1022,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.2"
@ -2226,6 +2254,21 @@ dependencies = [
"tokio",
]
[[package]]
name = "quanta"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]]
name = "quinn"
version = "0.11.7"
@ -2355,6 +2398,15 @@ dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "raw-cpuid"
version = "11.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146"
dependencies = [
"bitflags 2.9.0",
]
[[package]]
name = "redox_syscall"
version = "0.5.10"

View file

@ -96,6 +96,7 @@ axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-se
base64 = "0.22.1"
bytes = "1.10.1"
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
dashmap = "6.1.0"
futures-util = { version = "0.3.31", default-features = false }
hmac = "0.12.1"
html-escape = "0.2.13"
@ -116,6 +117,7 @@ phf = { version = "0.11.3", features = ["macros"] }
pin-project-lite = "0.2.16"
prometheus = "0.13.4"
proxy-header = { version = "0.1.2", features = ["tokio"] }
quanta = "0.12.5"
rand = "0.8.5"
regex = "1.11.1"
reqwest = { version = "0.12.15", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }

View file

@ -1,4 +1,8 @@
use std::{collections::BTreeMap, iter::FromIterator, str};
use std::{
collections::BTreeMap,
iter::FromIterator,
str::{self, FromStr as _},
};
use axum::{
async_trait,
@ -17,8 +21,8 @@ use http::{Request, StatusCode};
use http_body_util::BodyExt;
use ruma::{
api::{
client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata,
OutgoingResponse,
client::error::{ErrorKind, RetryAfter},
AuthScheme, IncomingRequest, Metadata, OutgoingResponse,
},
server_util::authorization::XMatrix,
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
@ -28,7 +32,13 @@ use serde::Deserialize;
use tracing::{error, warn};
use super::{Ar, Ra};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
use crate::{
service::{
appservice::RegistrationInfo,
rate_limiting::{Target, XForwardedFor},
},
services, Error, Result,
};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -51,10 +61,13 @@ struct ArPieces {
/// Non-generic part of [`Ar::from_request()`]. Splitting this out reduces
/// binary size by ~10%.
#[allow(clippy::too_many_lines)]
async fn ar_from_request_inner(
async fn ar_from_request_inner<T>(
req: axum::extract::Request,
metadata: Metadata,
) -> Result<ArPieces> {
) -> Result<ArPieces>
where
T: IncomingRequest,
{
#[derive(Deserialize)]
struct QueryParams {
access_token: Option<String>,
@ -350,6 +363,45 @@ async fn ar_from_request_inner(
}
};
let mut target =
Err(Error::BadRequest(ErrorKind::forbidden(), "Missing identifier."));
if let Some(servername) = sender_servername.as_ref() {
target = Ok(Target::Server(servername.to_owned()));
}
if let Some(user) = sender_user.as_ref() {
target = Ok(Target::User(user.to_owned()));
}
if let Some(info) = appservice_info.as_ref() {
if info.registration.rate_limited.unwrap_or(true) {
target = Ok(Target::Appservice(info.registration.id.clone()));
}
}
if let Some(header) = parts.headers.get("X-Forwarded-For") {
let value = header.to_str().ok();
let x_forwarded_for =
value.map(XForwardedFor::from_str).and_then(Result::ok);
if let Some(ip) = x_forwarded_for.and_then(|value| value.ip) {
target = Ok(Target::Ip(ip));
}
}
if let Err(retry_after) =
{ services().rate_limiting.update_or_reject::<T>(target?) }
{
return Err(Error::BadRequest(
ErrorKind::LimitExceeded {
retry_after: Some(RetryAfter::Delay(retry_after)),
},
"Rate limit exceeded.",
));
}
let mut http_request =
Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers;
@ -411,7 +463,7 @@ where
req: axum::extract::Request,
_state: &S,
) -> Result<Self, Self::Rejection> {
let pieces = ar_from_request_inner(req, T::METADATA).await?;
let pieces = ar_from_request_inner::<T>(req, T::METADATA).await?;
let body =
T::try_from_http_request(pieces.http_request, &pieces.path_params)

View file

@ -1,6 +1,6 @@
use std::{
borrow::Cow,
collections::{BTreeMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
fmt::{self, Display},
net::{IpAddr, Ipv4Addr},
path::{Path, PathBuf},
@ -19,6 +19,7 @@ use crate::{error, utils::partial_canonicalize};
mod env_filter_clone;
mod proxy;
pub(crate) mod rate_limiting;
pub(crate) use env_filter_clone::EnvFilterClone;
use proxy::ProxyConfig;
@ -69,6 +70,9 @@ pub(crate) struct Config {
pub(crate) observability: ObservabilityConfig,
#[serde(default)]
pub(crate) turn: TurnConfig,
#[serde(default = "rate_limiting::default_rate_limit")]
pub(crate) rate_limiting:
HashMap<rate_limiting::Endpoint, rate_limiting::Config>,
pub(crate) emergency_password: Option<String>,
}

View file

@ -0,0 +1,93 @@
use std::{collections::HashMap, num::NonZeroU64};
use serde::Deserialize;
#[derive(
Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd,
)]
#[serde(rename_all = "snake_case")]
pub(crate) enum Endpoint {
Registration,
Login,
RegistrationTokenValidity,
Message,
Join,
Invite,
Knock,
CreateMedia,
Transaction,
FederatedJoin,
FederatedInvite,
FederatedKnock,
#[default]
Global,
#[serde(untagged)]
Custom(String),
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct Config {
#[serde(default = "default_non_zero", flatten)]
pub(crate) timeframe: Timeframe,
#[serde(default = "default_non_zero")]
pub(crate) burst_capacity: NonZeroU64,
}
impl Config {
pub(crate) fn delay_tolerance(&self) -> u64 {
self.timeframe.emission_interval_ns() * self.burst_capacity.get()
}
}
impl Default for Config {
fn default() -> Self {
Config {
timeframe: Timeframe::PerSecond(NonZeroU64::MIN),
burst_capacity: NonZeroU64::MIN,
}
}
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(clippy::enum_variant_names)]
pub(crate) enum Timeframe {
PerSecond(NonZeroU64),
PerMinute(NonZeroU64),
PerHour(NonZeroU64),
PerDay(NonZeroU64),
}
impl Timeframe {
pub(crate) fn time_window_ms(&self) -> u64 {
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
let s = match self {
PerSecond(_) => 1_u64,
PerMinute(_) => 60,
PerHour(_) => 60 * 60,
PerDay(_) => 60 * 60 * 24,
};
s.checked_mul(1_000_000_000).expect("time window overflow")
}
pub(crate) fn emission_interval_ns(&self) -> u64 {
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
let tf = match self {
PerSecond(tf) | PerMinute(tf) | PerHour(tf) | PerDay(tf) => tf,
};
self.time_window_ms() / tf.get()
}
}
fn default_non_zero() -> NonZeroU64 {
NonZeroU64::MIN
}
pub(super) fn default_rate_limit() -> HashMap<Endpoint, Config> {
HashMap::from_iter([(Endpoint::default(), Config::default())])
}

View file

@ -10,6 +10,7 @@ pub(crate) mod key_backups;
pub(crate) mod media;
pub(crate) mod pdu;
pub(crate) mod pusher;
pub(crate) mod rate_limiting;
pub(crate) mod rooms;
pub(crate) mod sending;
pub(crate) mod transaction_ids;
@ -36,6 +37,7 @@ pub(crate) struct Services {
pub(crate) key_backups: key_backups::Service,
pub(crate) media: media::Service,
pub(crate) sending: Arc<sending::Service>,
pub(crate) rate_limiting: rate_limiting::Service,
}
impl Services {
@ -121,6 +123,10 @@ impl Services {
db,
},
sending: sending::Service::new(db, &config),
rate_limiting: rate_limiting::Service::build(
services().globals.config.rate_limiting.clone(),
quanta::Clock::new(),
),
globals: globals::Service::new(db, config, reload_handles)?,
})

View file

@ -0,0 +1,413 @@
use std::{
collections::HashMap,
hash::Hash,
net::{AddrParseError, IpAddr},
str::FromStr,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use dashmap::DashMap;
use quanta::Clock;
use ruma::{
api::{IncomingRequest, Metadata},
OwnedServerName, OwnedUserId,
};
use crate::{
config::rate_limiting::{self, Endpoint},
Result,
};
pub(crate) struct Service {
clock: Clock,
store: DashMap<(Target, Endpoint), AtomicU64>,
config: HashMap<Endpoint, rate_limiting::Config>,
}
impl Service {
pub(crate) fn build(
config: HashMap<Endpoint, rate_limiting::Config>,
clock: Clock,
) -> Self {
Self {
clock,
store: DashMap::new(),
config,
}
}
pub(crate) fn update_or_reject<IR: IncomingRequest>(
&self,
target: Target,
) -> Result<(), Duration> {
let now_ns = self.clock.delta_as_nanos(0, self.clock.raw());
let endpoint = Endpoint::from(IR::METADATA);
let config = self.config.get(&endpoint).cloned().unwrap_or_default();
let (emission_interval_ns, delay_tolerance) =
(config.timeframe.emission_interval_ns(), config.delay_tolerance());
let entry = self.store.entry((target, endpoint)).or_insert_with(|| {
AtomicU64::new(now_ns.saturating_sub(delay_tolerance))
});
let theoretical_arrival_time = entry.load(Ordering::Acquire);
let new_theoretical_arrival_time =
theoretical_arrival_time.max(now_ns) + emission_interval_ns;
if (new_theoretical_arrival_time - now_ns) <= delay_tolerance {
entry.store(new_theoretical_arrival_time, Ordering::Release);
Ok(())
} else {
Err(Duration::from_nanos(
theoretical_arrival_time.saturating_sub(now_ns),
))
}
}
}
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) enum Target {
User(OwnedUserId),
// Server endpoints should be rate-limited on a server and room basis
Server(OwnedServerName),
Appservice(String),
Ip(IpAddr),
}
#[derive(Debug, PartialEq)]
pub(crate) struct XForwardedFor {
pub(crate) ip: Option<IpAddr>,
pub(crate) proxies: Vec<IpAddr>,
}
impl FromStr for XForwardedFor {
type Err = AddrParseError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let ips = s.split(',').try_fold(Vec::new(), |mut ips, part| {
let trimmed = part.trim();
if trimmed.is_empty() {
Ok(ips)
} else {
match IpAddr::from_str(trimmed) {
Ok(ip) => {
ips.push(ip);
Ok(ips)
}
Err(e) => Err(e),
}
}
})?;
let (ip, proxies) = if ips.is_empty() {
(None, Vec::new())
} else {
(Some(ips[0]), ips[1..].to_vec())
};
Ok(XForwardedFor {
ip,
proxies,
})
}
}
impl From<Metadata> for Endpoint {
fn from(metadata: Metadata) -> Self {
use ruma::api::{
client::{
account::{check_registration_token_validity, register},
knock::knock_room,
media::{create_content, create_content_async},
membership::{
invite_user, join_room_by_id, join_room_by_id_or_alias,
},
message::send_message_event,
session::login,
state::send_state_event,
},
federation::{
knock::send_knock,
membership::{create_invite, create_join_event},
transactions::send_transaction_message,
},
};
#[allow(deprecated)]
match metadata {
register::v3::Request::METADATA => Endpoint::Registration,
login::v3::Request::METADATA => Endpoint::Login,
check_registration_token_validity::v1::Request::METADATA => {
Endpoint::RegistrationTokenValidity
}
send_message_event::v3::Request::METADATA
| send_state_event::v3::Request::METADATA => Endpoint::Message,
join_room_by_id::v3::Request::METADATA
| join_room_by_id_or_alias::v3::Request::METADATA => Endpoint::Join,
invite_user::v3::Request::METADATA => Endpoint::Invite,
create_content::v3::Request::METADATA
| create_content_async::v3::Request::METADATA => {
Endpoint::CreateMedia
}
send_transaction_message::v1::Request::METADATA => {
Endpoint::Transaction
}
create_join_event::v1::Request::METADATA
| create_join_event::v2::Request::METADATA => {
Endpoint::FederatedJoin
}
create_invite::v1::Request::METADATA
| create_invite::v2::Request::METADATA => Endpoint::FederatedInvite,
send_knock::v1::Request::METADATA => Endpoint::FederatedKnock,
knock_room::v3::Request::METADATA => Endpoint::Knock,
_ => Self::default(),
}
}
}
#[cfg(test)]
mod tests {
use std::{net::Ipv4Addr, num::NonZeroU64};
use quanta::Clock;
use ruma::api::client::{account::register, session::login};
use super::*;
use crate::config::rate_limiting::{Config, Timeframe};
struct LoginRequest;
impl IncomingRequest for LoginRequest {
type EndpointError =
<login::v3::Request as IncomingRequest>::EndpointError;
type OutgoingResponse =
<login::v3::Request as IncomingRequest>::OutgoingResponse;
const METADATA: Metadata =
<login::v3::Request as IncomingRequest>::METADATA;
fn try_from_http_request<B, S>(
_: http::Request<B>,
_: &[S],
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
unimplemented!()
}
}
struct RegisterRequest;
impl IncomingRequest for RegisterRequest {
type EndpointError =
<register::v3::Request as IncomingRequest>::EndpointError;
type OutgoingResponse =
<register::v3::Request as IncomingRequest>::OutgoingResponse;
const METADATA: Metadata =
<register::v3::Request as IncomingRequest>::METADATA;
fn try_from_http_request<B, S>(
_: http::Request<B>,
_: &[S],
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
unimplemented!()
}
}
fn test_config(emission_interval_ns: u64, burst_capacity: u64) -> Config {
Config {
// Directly use nanoseconds without conversion
timeframe: Timeframe::PerSecond(
NonZeroU64::new(1_000_000_000 / emission_interval_ns).unwrap(),
),
burst_capacity: NonZeroU64::new(burst_capacity).unwrap(),
}
}
#[test]
fn test_basic_rate_limiting() {
let (clock, mock) = Clock::mock();
let config = test_config(1_000_000_000, 1);
let service = Service::build(
HashMap::from_iter([(Endpoint::Login, config)]),
clock,
);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
assert!(service
.update_or_reject::<LoginRequest>(target.clone())
.is_err());
mock.increment(1_000_000_000);
service.update_or_reject::<LoginRequest>(target).unwrap();
}
#[test]
fn test_burst_capacity() {
let (clock, mock) = Clock::mock();
let config = test_config(1_000_000_000, 3);
let service = Service::build(
HashMap::from_iter([(Endpoint::Login, config)]),
clock,
);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
assert!(service
.update_or_reject::<LoginRequest>(target.clone())
.is_err());
mock.increment(1_000_000_000);
service.update_or_reject::<LoginRequest>(target).unwrap();
}
#[test]
fn test_different_target_types() {
let config = test_config(1_000_000_000, 1);
let service = Service::build(
HashMap::from_iter([(Endpoint::Login, config)]),
Clock::new(),
);
let targets = vec![
Target::User(OwnedUserId::try_from("@user:example.com").unwrap()),
Target::Server(OwnedServerName::try_from("example.com").unwrap()),
Target::Appservice("my_appservice".to_owned()),
Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
];
for target in targets {
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
}
}
#[test]
fn test_different_endpoints() {
let config = test_config(1_000_000_000, 1);
let service = Service::build(
HashMap::from_iter([
(Endpoint::Login, config.clone()),
(Endpoint::Registration, config),
]),
Clock::new(),
);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
service.update_or_reject::<RegisterRequest>(target.clone()).unwrap();
assert!(service
.update_or_reject::<LoginRequest>(target.clone())
.is_err());
assert!(service.update_or_reject::<RegisterRequest>(target).is_err());
}
#[test]
fn test_default_config() {
let (clock, mock) = Clock::mock();
let service = Service::build(HashMap::new(), clock);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
assert!(service
.update_or_reject::<LoginRequest>(target.clone())
.is_err());
mock.increment(1_000_000_000);
service.update_or_reject::<LoginRequest>(target).unwrap();
}
#[test]
fn test_very_high_rate() {
let (clock, mock) = Clock::mock();
let config = test_config(1_000, 10);
let service = Service::build(
HashMap::from_iter([(Endpoint::Login, config)]),
clock,
);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
for _ in 0..10 {
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
}
let value = mock.value();
mock.decrement(value);
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
}
#[test]
fn test_returned_delay() {
let (clock, mock) = Clock::mock();
let config = test_config(1_000_000_000, 1);
let service = Service::build(
HashMap::from_iter([(Endpoint::Login, config)]),
clock,
);
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
if let Err(delay) =
service.update_or_reject::<LoginRequest>(target.clone())
{
assert_eq!(delay, Duration::from_secs(1));
} else {
panic!("Expected rate limit error");
}
mock.increment(500_000_000);
if let Err(delay) = service.update_or_reject::<LoginRequest>(target) {
assert_eq!(delay, Duration::from_millis(500));
} else {
panic!("Expected rate limit error");
}
}
}

View file

@ -8,4 +8,4 @@ Error: failed to validate configuration
|
1 | some_name = "example.com"
| ^^^^^^^^^
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password`
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `rate_limiting`, `emergency_password`