mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-18 00:01:24 +01:00
Merge branch 'rate-limiting' into 'main'
Draft: Rate limiting See merge request matrix/grapevine!183
This commit is contained in:
commit
612a5b51cc
8 changed files with 631 additions and 9 deletions
52
Cargo.lock
generated
52
Cargo.lock
generated
|
|
@ -521,6 +521,12 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crypto-common"
|
name = "crypto-common"
|
||||||
version = "0.1.6"
|
version = "0.1.6"
|
||||||
|
|
@ -558,6 +564,20 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dashmap"
|
||||||
|
version = "6.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"lock_api",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "data-encoding"
|
name = "data-encoding"
|
||||||
version = "2.8.0"
|
version = "2.8.0"
|
||||||
|
|
@ -922,6 +942,7 @@ dependencies = [
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"bytes",
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dashmap",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"hmac",
|
"hmac",
|
||||||
"html-escape",
|
"html-escape",
|
||||||
|
|
@ -945,6 +966,7 @@ dependencies = [
|
||||||
"predicates",
|
"predicates",
|
||||||
"prometheus",
|
"prometheus",
|
||||||
"proxy-header",
|
"proxy-header",
|
||||||
|
"quanta",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"regex",
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
|
@ -1000,6 +1022,12 @@ version = "0.12.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.14.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.15.2"
|
version = "0.15.2"
|
||||||
|
|
@ -2226,6 +2254,21 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quanta"
|
||||||
|
version = "0.12.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
"libc",
|
||||||
|
"once_cell",
|
||||||
|
"raw-cpuid",
|
||||||
|
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||||
|
"web-sys",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quinn"
|
name = "quinn"
|
||||||
version = "0.11.7"
|
version = "0.11.7"
|
||||||
|
|
@ -2355,6 +2398,15 @@ dependencies = [
|
||||||
"getrandom 0.3.2",
|
"getrandom 0.3.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "raw-cpuid"
|
||||||
|
version = "11.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.9.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.10"
|
version = "0.5.10"
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,7 @@ axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-se
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
bytes = "1.10.1"
|
bytes = "1.10.1"
|
||||||
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
|
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
|
||||||
|
dashmap = "6.1.0"
|
||||||
futures-util = { version = "0.3.31", default-features = false }
|
futures-util = { version = "0.3.31", default-features = false }
|
||||||
hmac = "0.12.1"
|
hmac = "0.12.1"
|
||||||
html-escape = "0.2.13"
|
html-escape = "0.2.13"
|
||||||
|
|
@ -116,6 +117,7 @@ phf = { version = "0.11.3", features = ["macros"] }
|
||||||
pin-project-lite = "0.2.16"
|
pin-project-lite = "0.2.16"
|
||||||
prometheus = "0.13.4"
|
prometheus = "0.13.4"
|
||||||
proxy-header = { version = "0.1.2", features = ["tokio"] }
|
proxy-header = { version = "0.1.2", features = ["tokio"] }
|
||||||
|
quanta = "0.12.5"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
regex = "1.11.1"
|
regex = "1.11.1"
|
||||||
reqwest = { version = "0.12.15", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }
|
reqwest = { version = "0.12.15", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
use std::{collections::BTreeMap, iter::FromIterator, str};
|
use std::{
|
||||||
|
collections::BTreeMap,
|
||||||
|
iter::FromIterator,
|
||||||
|
str::{self, FromStr as _},
|
||||||
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
|
@ -17,8 +21,8 @@ use http::{Request, StatusCode};
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata,
|
client::error::{ErrorKind, RetryAfter},
|
||||||
OutgoingResponse,
|
AuthScheme, IncomingRequest, Metadata, OutgoingResponse,
|
||||||
},
|
},
|
||||||
server_util::authorization::XMatrix,
|
server_util::authorization::XMatrix,
|
||||||
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
|
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
|
||||||
|
|
@ -28,7 +32,13 @@ use serde::Deserialize;
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
|
|
||||||
use super::{Ar, Ra};
|
use super::{Ar, Ra};
|
||||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
use crate::{
|
||||||
|
service::{
|
||||||
|
appservice::RegistrationInfo,
|
||||||
|
rate_limiting::{Target, XForwardedFor},
|
||||||
|
},
|
||||||
|
services, Error, Result,
|
||||||
|
};
|
||||||
|
|
||||||
enum Token {
|
enum Token {
|
||||||
Appservice(Box<RegistrationInfo>),
|
Appservice(Box<RegistrationInfo>),
|
||||||
|
|
@ -51,10 +61,13 @@ struct ArPieces {
|
||||||
/// Non-generic part of [`Ar::from_request()`]. Splitting this out reduces
|
/// Non-generic part of [`Ar::from_request()`]. Splitting this out reduces
|
||||||
/// binary size by ~10%.
|
/// binary size by ~10%.
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
async fn ar_from_request_inner(
|
async fn ar_from_request_inner<T>(
|
||||||
req: axum::extract::Request,
|
req: axum::extract::Request,
|
||||||
metadata: Metadata,
|
metadata: Metadata,
|
||||||
) -> Result<ArPieces> {
|
) -> Result<ArPieces>
|
||||||
|
where
|
||||||
|
T: IncomingRequest,
|
||||||
|
{
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct QueryParams {
|
struct QueryParams {
|
||||||
access_token: Option<String>,
|
access_token: Option<String>,
|
||||||
|
|
@ -350,6 +363,45 @@ async fn ar_from_request_inner(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut target =
|
||||||
|
Err(Error::BadRequest(ErrorKind::forbidden(), "Missing identifier."));
|
||||||
|
|
||||||
|
if let Some(servername) = sender_servername.as_ref() {
|
||||||
|
target = Ok(Target::Server(servername.to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(user) = sender_user.as_ref() {
|
||||||
|
target = Ok(Target::User(user.to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(info) = appservice_info.as_ref() {
|
||||||
|
if info.registration.rate_limited.unwrap_or(true) {
|
||||||
|
target = Ok(Target::Appservice(info.registration.id.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(header) = parts.headers.get("X-Forwarded-For") {
|
||||||
|
let value = header.to_str().ok();
|
||||||
|
|
||||||
|
let x_forwarded_for =
|
||||||
|
value.map(XForwardedFor::from_str).and_then(Result::ok);
|
||||||
|
|
||||||
|
if let Some(ip) = x_forwarded_for.and_then(|value| value.ip) {
|
||||||
|
target = Ok(Target::Ip(ip));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(retry_after) =
|
||||||
|
{ services().rate_limiting.update_or_reject::<T>(target?) }
|
||||||
|
{
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::LimitExceeded {
|
||||||
|
retry_after: Some(RetryAfter::Delay(retry_after)),
|
||||||
|
},
|
||||||
|
"Rate limit exceeded.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let mut http_request =
|
let mut http_request =
|
||||||
Request::builder().uri(parts.uri).method(parts.method);
|
Request::builder().uri(parts.uri).method(parts.method);
|
||||||
*http_request.headers_mut().unwrap() = parts.headers;
|
*http_request.headers_mut().unwrap() = parts.headers;
|
||||||
|
|
@ -411,7 +463,7 @@ where
|
||||||
req: axum::extract::Request,
|
req: axum::extract::Request,
|
||||||
_state: &S,
|
_state: &S,
|
||||||
) -> Result<Self, Self::Rejection> {
|
) -> Result<Self, Self::Rejection> {
|
||||||
let pieces = ar_from_request_inner(req, T::METADATA).await?;
|
let pieces = ar_from_request_inner::<T>(req, T::METADATA).await?;
|
||||||
|
|
||||||
let body =
|
let body =
|
||||||
T::try_from_http_request(pieces.http_request, &pieces.path_params)
|
T::try_from_http_request(pieces.http_request, &pieces.path_params)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
borrow::Cow,
|
||||||
collections::{BTreeMap, HashSet},
|
collections::{BTreeMap, HashMap, HashSet},
|
||||||
fmt::{self, Display},
|
fmt::{self, Display},
|
||||||
net::{IpAddr, Ipv4Addr},
|
net::{IpAddr, Ipv4Addr},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
|
|
@ -19,6 +19,7 @@ use crate::{error, utils::partial_canonicalize};
|
||||||
|
|
||||||
mod env_filter_clone;
|
mod env_filter_clone;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
|
pub(crate) mod rate_limiting;
|
||||||
|
|
||||||
pub(crate) use env_filter_clone::EnvFilterClone;
|
pub(crate) use env_filter_clone::EnvFilterClone;
|
||||||
use proxy::ProxyConfig;
|
use proxy::ProxyConfig;
|
||||||
|
|
@ -69,6 +70,9 @@ pub(crate) struct Config {
|
||||||
pub(crate) observability: ObservabilityConfig,
|
pub(crate) observability: ObservabilityConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub(crate) turn: TurnConfig,
|
pub(crate) turn: TurnConfig,
|
||||||
|
#[serde(default = "rate_limiting::default_rate_limit")]
|
||||||
|
pub(crate) rate_limiting:
|
||||||
|
HashMap<rate_limiting::Endpoint, rate_limiting::Config>,
|
||||||
|
|
||||||
pub(crate) emergency_password: Option<String>,
|
pub(crate) emergency_password: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
93
src/config/rate_limiting.rs
Normal file
93
src/config/rate_limiting.rs
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
use std::{collections::HashMap, num::NonZeroU64};
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(
|
||||||
|
Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd,
|
||||||
|
)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub(crate) enum Endpoint {
|
||||||
|
Registration,
|
||||||
|
Login,
|
||||||
|
RegistrationTokenValidity,
|
||||||
|
Message,
|
||||||
|
Join,
|
||||||
|
Invite,
|
||||||
|
Knock,
|
||||||
|
CreateMedia,
|
||||||
|
Transaction,
|
||||||
|
FederatedJoin,
|
||||||
|
FederatedInvite,
|
||||||
|
FederatedKnock,
|
||||||
|
|
||||||
|
#[default]
|
||||||
|
Global,
|
||||||
|
|
||||||
|
#[serde(untagged)]
|
||||||
|
Custom(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub(crate) struct Config {
|
||||||
|
#[serde(default = "default_non_zero", flatten)]
|
||||||
|
pub(crate) timeframe: Timeframe,
|
||||||
|
#[serde(default = "default_non_zero")]
|
||||||
|
pub(crate) burst_capacity: NonZeroU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub(crate) fn delay_tolerance(&self) -> u64 {
|
||||||
|
self.timeframe.emission_interval_ns() * self.burst_capacity.get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Config {
|
||||||
|
timeframe: Timeframe::PerSecond(NonZeroU64::MIN),
|
||||||
|
burst_capacity: NonZeroU64::MIN,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
#[allow(clippy::enum_variant_names)]
|
||||||
|
pub(crate) enum Timeframe {
|
||||||
|
PerSecond(NonZeroU64),
|
||||||
|
PerMinute(NonZeroU64),
|
||||||
|
PerHour(NonZeroU64),
|
||||||
|
PerDay(NonZeroU64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Timeframe {
|
||||||
|
pub(crate) fn time_window_ms(&self) -> u64 {
|
||||||
|
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
|
||||||
|
|
||||||
|
let s = match self {
|
||||||
|
PerSecond(_) => 1_u64,
|
||||||
|
PerMinute(_) => 60,
|
||||||
|
PerHour(_) => 60 * 60,
|
||||||
|
PerDay(_) => 60 * 60 * 24,
|
||||||
|
};
|
||||||
|
|
||||||
|
s.checked_mul(1_000_000_000).expect("time window overflow")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn emission_interval_ns(&self) -> u64 {
|
||||||
|
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
|
||||||
|
|
||||||
|
let tf = match self {
|
||||||
|
PerSecond(tf) | PerMinute(tf) | PerHour(tf) | PerDay(tf) => tf,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.time_window_ms() / tf.get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_non_zero() -> NonZeroU64 {
|
||||||
|
NonZeroU64::MIN
|
||||||
|
}
|
||||||
|
pub(super) fn default_rate_limit() -> HashMap<Endpoint, Config> {
|
||||||
|
HashMap::from_iter([(Endpoint::default(), Config::default())])
|
||||||
|
}
|
||||||
|
|
@ -10,6 +10,7 @@ pub(crate) mod key_backups;
|
||||||
pub(crate) mod media;
|
pub(crate) mod media;
|
||||||
pub(crate) mod pdu;
|
pub(crate) mod pdu;
|
||||||
pub(crate) mod pusher;
|
pub(crate) mod pusher;
|
||||||
|
pub(crate) mod rate_limiting;
|
||||||
pub(crate) mod rooms;
|
pub(crate) mod rooms;
|
||||||
pub(crate) mod sending;
|
pub(crate) mod sending;
|
||||||
pub(crate) mod transaction_ids;
|
pub(crate) mod transaction_ids;
|
||||||
|
|
@ -36,6 +37,7 @@ pub(crate) struct Services {
|
||||||
pub(crate) key_backups: key_backups::Service,
|
pub(crate) key_backups: key_backups::Service,
|
||||||
pub(crate) media: media::Service,
|
pub(crate) media: media::Service,
|
||||||
pub(crate) sending: Arc<sending::Service>,
|
pub(crate) sending: Arc<sending::Service>,
|
||||||
|
pub(crate) rate_limiting: rate_limiting::Service,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Services {
|
impl Services {
|
||||||
|
|
@ -121,6 +123,10 @@ impl Services {
|
||||||
db,
|
db,
|
||||||
},
|
},
|
||||||
sending: sending::Service::new(db, &config),
|
sending: sending::Service::new(db, &config),
|
||||||
|
rate_limiting: rate_limiting::Service::build(
|
||||||
|
services().globals.config.rate_limiting.clone(),
|
||||||
|
quanta::Clock::new(),
|
||||||
|
),
|
||||||
|
|
||||||
globals: globals::Service::new(db, config, reload_handles)?,
|
globals: globals::Service::new(db, config, reload_handles)?,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
413
src/service/rate_limiting.rs
Normal file
413
src/service/rate_limiting.rs
Normal file
|
|
@ -0,0 +1,413 @@
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
hash::Hash,
|
||||||
|
net::{AddrParseError, IpAddr},
|
||||||
|
str::FromStr,
|
||||||
|
sync::atomic::{AtomicU64, Ordering},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use quanta::Clock;
|
||||||
|
use ruma::{
|
||||||
|
api::{IncomingRequest, Metadata},
|
||||||
|
OwnedServerName, OwnedUserId,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::rate_limiting::{self, Endpoint},
|
||||||
|
Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub(crate) struct Service {
|
||||||
|
clock: Clock,
|
||||||
|
store: DashMap<(Target, Endpoint), AtomicU64>,
|
||||||
|
config: HashMap<Endpoint, rate_limiting::Config>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub(crate) fn build(
|
||||||
|
config: HashMap<Endpoint, rate_limiting::Config>,
|
||||||
|
clock: Clock,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
clock,
|
||||||
|
store: DashMap::new(),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn update_or_reject<IR: IncomingRequest>(
|
||||||
|
&self,
|
||||||
|
target: Target,
|
||||||
|
) -> Result<(), Duration> {
|
||||||
|
let now_ns = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||||
|
|
||||||
|
let endpoint = Endpoint::from(IR::METADATA);
|
||||||
|
|
||||||
|
let config = self.config.get(&endpoint).cloned().unwrap_or_default();
|
||||||
|
|
||||||
|
let (emission_interval_ns, delay_tolerance) =
|
||||||
|
(config.timeframe.emission_interval_ns(), config.delay_tolerance());
|
||||||
|
|
||||||
|
let entry = self.store.entry((target, endpoint)).or_insert_with(|| {
|
||||||
|
AtomicU64::new(now_ns.saturating_sub(delay_tolerance))
|
||||||
|
});
|
||||||
|
|
||||||
|
let theoretical_arrival_time = entry.load(Ordering::Acquire);
|
||||||
|
|
||||||
|
let new_theoretical_arrival_time =
|
||||||
|
theoretical_arrival_time.max(now_ns) + emission_interval_ns;
|
||||||
|
|
||||||
|
if (new_theoretical_arrival_time - now_ns) <= delay_tolerance {
|
||||||
|
entry.store(new_theoretical_arrival_time, Ordering::Release);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(Duration::from_nanos(
|
||||||
|
theoretical_arrival_time.saturating_sub(now_ns),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub(crate) enum Target {
|
||||||
|
User(OwnedUserId),
|
||||||
|
// Server endpoints should be rate-limited on a server and room basis
|
||||||
|
Server(OwnedServerName),
|
||||||
|
Appservice(String),
|
||||||
|
Ip(IpAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub(crate) struct XForwardedFor {
|
||||||
|
pub(crate) ip: Option<IpAddr>,
|
||||||
|
pub(crate) proxies: Vec<IpAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for XForwardedFor {
|
||||||
|
type Err = AddrParseError;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
|
let ips = s.split(',').try_fold(Vec::new(), |mut ips, part| {
|
||||||
|
let trimmed = part.trim();
|
||||||
|
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
Ok(ips)
|
||||||
|
} else {
|
||||||
|
match IpAddr::from_str(trimmed) {
|
||||||
|
Ok(ip) => {
|
||||||
|
ips.push(ip);
|
||||||
|
|
||||||
|
Ok(ips)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (ip, proxies) = if ips.is_empty() {
|
||||||
|
(None, Vec::new())
|
||||||
|
} else {
|
||||||
|
(Some(ips[0]), ips[1..].to_vec())
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(XForwardedFor {
|
||||||
|
ip,
|
||||||
|
proxies,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Metadata> for Endpoint {
|
||||||
|
fn from(metadata: Metadata) -> Self {
|
||||||
|
use ruma::api::{
|
||||||
|
client::{
|
||||||
|
account::{check_registration_token_validity, register},
|
||||||
|
knock::knock_room,
|
||||||
|
media::{create_content, create_content_async},
|
||||||
|
membership::{
|
||||||
|
invite_user, join_room_by_id, join_room_by_id_or_alias,
|
||||||
|
},
|
||||||
|
message::send_message_event,
|
||||||
|
session::login,
|
||||||
|
state::send_state_event,
|
||||||
|
},
|
||||||
|
federation::{
|
||||||
|
knock::send_knock,
|
||||||
|
membership::{create_invite, create_join_event},
|
||||||
|
transactions::send_transaction_message,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[allow(deprecated)]
|
||||||
|
match metadata {
|
||||||
|
register::v3::Request::METADATA => Endpoint::Registration,
|
||||||
|
login::v3::Request::METADATA => Endpoint::Login,
|
||||||
|
check_registration_token_validity::v1::Request::METADATA => {
|
||||||
|
Endpoint::RegistrationTokenValidity
|
||||||
|
}
|
||||||
|
send_message_event::v3::Request::METADATA
|
||||||
|
| send_state_event::v3::Request::METADATA => Endpoint::Message,
|
||||||
|
join_room_by_id::v3::Request::METADATA
|
||||||
|
| join_room_by_id_or_alias::v3::Request::METADATA => Endpoint::Join,
|
||||||
|
invite_user::v3::Request::METADATA => Endpoint::Invite,
|
||||||
|
create_content::v3::Request::METADATA
|
||||||
|
| create_content_async::v3::Request::METADATA => {
|
||||||
|
Endpoint::CreateMedia
|
||||||
|
}
|
||||||
|
send_transaction_message::v1::Request::METADATA => {
|
||||||
|
Endpoint::Transaction
|
||||||
|
}
|
||||||
|
create_join_event::v1::Request::METADATA
|
||||||
|
| create_join_event::v2::Request::METADATA => {
|
||||||
|
Endpoint::FederatedJoin
|
||||||
|
}
|
||||||
|
create_invite::v1::Request::METADATA
|
||||||
|
| create_invite::v2::Request::METADATA => Endpoint::FederatedInvite,
|
||||||
|
send_knock::v1::Request::METADATA => Endpoint::FederatedKnock,
|
||||||
|
knock_room::v3::Request::METADATA => Endpoint::Knock,
|
||||||
|
_ => Self::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::{net::Ipv4Addr, num::NonZeroU64};
|
||||||
|
|
||||||
|
use quanta::Clock;
|
||||||
|
use ruma::api::client::{account::register, session::login};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::config::rate_limiting::{Config, Timeframe};
|
||||||
|
|
||||||
|
struct LoginRequest;
|
||||||
|
|
||||||
|
impl IncomingRequest for LoginRequest {
|
||||||
|
type EndpointError =
|
||||||
|
<login::v3::Request as IncomingRequest>::EndpointError;
|
||||||
|
type OutgoingResponse =
|
||||||
|
<login::v3::Request as IncomingRequest>::OutgoingResponse;
|
||||||
|
|
||||||
|
const METADATA: Metadata =
|
||||||
|
<login::v3::Request as IncomingRequest>::METADATA;
|
||||||
|
|
||||||
|
fn try_from_http_request<B, S>(
|
||||||
|
_: http::Request<B>,
|
||||||
|
_: &[S],
|
||||||
|
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
|
||||||
|
where
|
||||||
|
B: AsRef<[u8]>,
|
||||||
|
S: AsRef<str>,
|
||||||
|
{
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RegisterRequest;
|
||||||
|
|
||||||
|
impl IncomingRequest for RegisterRequest {
|
||||||
|
type EndpointError =
|
||||||
|
<register::v3::Request as IncomingRequest>::EndpointError;
|
||||||
|
type OutgoingResponse =
|
||||||
|
<register::v3::Request as IncomingRequest>::OutgoingResponse;
|
||||||
|
|
||||||
|
const METADATA: Metadata =
|
||||||
|
<register::v3::Request as IncomingRequest>::METADATA;
|
||||||
|
|
||||||
|
fn try_from_http_request<B, S>(
|
||||||
|
_: http::Request<B>,
|
||||||
|
_: &[S],
|
||||||
|
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
|
||||||
|
where
|
||||||
|
B: AsRef<[u8]>,
|
||||||
|
S: AsRef<str>,
|
||||||
|
{
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_config(emission_interval_ns: u64, burst_capacity: u64) -> Config {
|
||||||
|
Config {
|
||||||
|
// Directly use nanoseconds without conversion
|
||||||
|
timeframe: Timeframe::PerSecond(
|
||||||
|
NonZeroU64::new(1_000_000_000 / emission_interval_ns).unwrap(),
|
||||||
|
),
|
||||||
|
burst_capacity: NonZeroU64::new(burst_capacity).unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_basic_rate_limiting() {
|
||||||
|
let (clock, mock) = Clock::mock();
|
||||||
|
|
||||||
|
let config = test_config(1_000_000_000, 1);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||||
|
clock,
|
||||||
|
);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
assert!(service
|
||||||
|
.update_or_reject::<LoginRequest>(target.clone())
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
mock.increment(1_000_000_000);
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_burst_capacity() {
|
||||||
|
let (clock, mock) = Clock::mock();
|
||||||
|
|
||||||
|
let config = test_config(1_000_000_000, 3);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||||
|
clock,
|
||||||
|
);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
assert!(service
|
||||||
|
.update_or_reject::<LoginRequest>(target.clone())
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
mock.increment(1_000_000_000);
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_different_target_types() {
|
||||||
|
let config = test_config(1_000_000_000, 1);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||||
|
Clock::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let targets = vec![
|
||||||
|
Target::User(OwnedUserId::try_from("@user:example.com").unwrap()),
|
||||||
|
Target::Server(OwnedServerName::try_from("example.com").unwrap()),
|
||||||
|
Target::Appservice("my_appservice".to_owned()),
|
||||||
|
Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
|
||||||
|
];
|
||||||
|
|
||||||
|
for target in targets {
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_different_endpoints() {
|
||||||
|
let config = test_config(1_000_000_000, 1);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([
|
||||||
|
(Endpoint::Login, config.clone()),
|
||||||
|
(Endpoint::Registration, config),
|
||||||
|
]),
|
||||||
|
Clock::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
service.update_or_reject::<RegisterRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
assert!(service
|
||||||
|
.update_or_reject::<LoginRequest>(target.clone())
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
assert!(service.update_or_reject::<RegisterRequest>(target).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config() {
|
||||||
|
let (clock, mock) = Clock::mock();
|
||||||
|
|
||||||
|
let service = Service::build(HashMap::new(), clock);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
assert!(service
|
||||||
|
.update_or_reject::<LoginRequest>(target.clone())
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
mock.increment(1_000_000_000);
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_very_high_rate() {
|
||||||
|
let (clock, mock) = Clock::mock();
|
||||||
|
|
||||||
|
let config = test_config(1_000, 10);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||||
|
clock,
|
||||||
|
);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
for _ in 0..10 {
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = mock.value();
|
||||||
|
mock.decrement(value);
|
||||||
|
|
||||||
|
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_returned_delay() {
|
||||||
|
let (clock, mock) = Clock::mock();
|
||||||
|
|
||||||
|
let config = test_config(1_000_000_000, 1);
|
||||||
|
|
||||||
|
let service = Service::build(
|
||||||
|
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||||
|
clock,
|
||||||
|
);
|
||||||
|
|
||||||
|
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||||
|
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||||
|
|
||||||
|
if let Err(delay) =
|
||||||
|
service.update_or_reject::<LoginRequest>(target.clone())
|
||||||
|
{
|
||||||
|
assert_eq!(delay, Duration::from_secs(1));
|
||||||
|
} else {
|
||||||
|
panic!("Expected rate limit error");
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.increment(500_000_000);
|
||||||
|
|
||||||
|
if let Err(delay) = service.update_or_reject::<LoginRequest>(target) {
|
||||||
|
assert_eq!(delay, Duration::from_millis(500));
|
||||||
|
} else {
|
||||||
|
panic!("Expected rate limit error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,4 +8,4 @@ Error: failed to validate configuration
|
||||||
|
|
|
|
||||||
1 | some_name = "example.com"
|
1 | some_name = "example.com"
|
||||||
| ^^^^^^^^^
|
| ^^^^^^^^^
|
||||||
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password`
|
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `rate_limiting`, `emergency_password`
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue