mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-16 15:21:24 +01:00
Merge branch 'rate-limiting' into 'main'
Draft: Rate limiting See merge request matrix/grapevine!183
This commit is contained in:
commit
612a5b51cc
8 changed files with 631 additions and 9 deletions
52
Cargo.lock
generated
52
Cargo.lock
generated
|
|
@ -521,6 +521,12 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
|
|
@ -558,6 +564,20 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "6.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.8.0"
|
||||
|
|
@ -922,6 +942,7 @@ dependencies = [
|
|||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"clap",
|
||||
"dashmap",
|
||||
"futures-util",
|
||||
"hmac",
|
||||
"html-escape",
|
||||
|
|
@ -945,6 +966,7 @@ dependencies = [
|
|||
"predicates",
|
||||
"prometheus",
|
||||
"proxy-header",
|
||||
"quanta",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
|
|
@ -1000,6 +1022,12 @@ version = "0.12.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.2"
|
||||
|
|
@ -2226,6 +2254,21 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.7"
|
||||
|
|
@ -2355,6 +2398,15 @@ dependencies = [
|
|||
"getrandom 0.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.10"
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ axum-server = { git = "https://gitlab.computer.surgery/matrix/thirdparty/axum-se
|
|||
base64 = "0.22.1"
|
||||
bytes = "1.10.1"
|
||||
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
|
||||
dashmap = "6.1.0"
|
||||
futures-util = { version = "0.3.31", default-features = false }
|
||||
hmac = "0.12.1"
|
||||
html-escape = "0.2.13"
|
||||
|
|
@ -116,6 +117,7 @@ phf = { version = "0.11.3", features = ["macros"] }
|
|||
pin-project-lite = "0.2.16"
|
||||
prometheus = "0.13.4"
|
||||
proxy-header = { version = "0.1.2", features = ["tokio"] }
|
||||
quanta = "0.12.5"
|
||||
rand = "0.8.5"
|
||||
regex = "1.11.1"
|
||||
reqwest = { version = "0.12.15", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
use std::{collections::BTreeMap, iter::FromIterator, str};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
iter::FromIterator,
|
||||
str::{self, FromStr as _},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
|
|
@ -17,8 +21,8 @@ use http::{Request, StatusCode};
|
|||
use http_body_util::BodyExt;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata,
|
||||
OutgoingResponse,
|
||||
client::error::{ErrorKind, RetryAfter},
|
||||
AuthScheme, IncomingRequest, Metadata, OutgoingResponse,
|
||||
},
|
||||
server_util::authorization::XMatrix,
|
||||
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
|
||||
|
|
@ -28,7 +32,13 @@ use serde::Deserialize;
|
|||
use tracing::{error, warn};
|
||||
|
||||
use super::{Ar, Ra};
|
||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
||||
use crate::{
|
||||
service::{
|
||||
appservice::RegistrationInfo,
|
||||
rate_limiting::{Target, XForwardedFor},
|
||||
},
|
||||
services, Error, Result,
|
||||
};
|
||||
|
||||
enum Token {
|
||||
Appservice(Box<RegistrationInfo>),
|
||||
|
|
@ -51,10 +61,13 @@ struct ArPieces {
|
|||
/// Non-generic part of [`Ar::from_request()`]. Splitting this out reduces
|
||||
/// binary size by ~10%.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn ar_from_request_inner(
|
||||
async fn ar_from_request_inner<T>(
|
||||
req: axum::extract::Request,
|
||||
metadata: Metadata,
|
||||
) -> Result<ArPieces> {
|
||||
) -> Result<ArPieces>
|
||||
where
|
||||
T: IncomingRequest,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
struct QueryParams {
|
||||
access_token: Option<String>,
|
||||
|
|
@ -350,6 +363,45 @@ async fn ar_from_request_inner(
|
|||
}
|
||||
};
|
||||
|
||||
let mut target =
|
||||
Err(Error::BadRequest(ErrorKind::forbidden(), "Missing identifier."));
|
||||
|
||||
if let Some(servername) = sender_servername.as_ref() {
|
||||
target = Ok(Target::Server(servername.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(user) = sender_user.as_ref() {
|
||||
target = Ok(Target::User(user.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(info) = appservice_info.as_ref() {
|
||||
if info.registration.rate_limited.unwrap_or(true) {
|
||||
target = Ok(Target::Appservice(info.registration.id.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(header) = parts.headers.get("X-Forwarded-For") {
|
||||
let value = header.to_str().ok();
|
||||
|
||||
let x_forwarded_for =
|
||||
value.map(XForwardedFor::from_str).and_then(Result::ok);
|
||||
|
||||
if let Some(ip) = x_forwarded_for.and_then(|value| value.ip) {
|
||||
target = Ok(Target::Ip(ip));
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(retry_after) =
|
||||
{ services().rate_limiting.update_or_reject::<T>(target?) }
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after: Some(RetryAfter::Delay(retry_after)),
|
||||
},
|
||||
"Rate limit exceeded.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut http_request =
|
||||
Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
|
|
@ -411,7 +463,7 @@ where
|
|||
req: axum::extract::Request,
|
||||
_state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let pieces = ar_from_request_inner(req, T::METADATA).await?;
|
||||
let pieces = ar_from_request_inner::<T>(req, T::METADATA).await?;
|
||||
|
||||
let body =
|
||||
T::try_from_http_request(pieces.http_request, &pieces.path_params)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{BTreeMap, HashSet},
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fmt::{self, Display},
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
path::{Path, PathBuf},
|
||||
|
|
@ -19,6 +19,7 @@ use crate::{error, utils::partial_canonicalize};
|
|||
|
||||
mod env_filter_clone;
|
||||
mod proxy;
|
||||
pub(crate) mod rate_limiting;
|
||||
|
||||
pub(crate) use env_filter_clone::EnvFilterClone;
|
||||
use proxy::ProxyConfig;
|
||||
|
|
@ -69,6 +70,9 @@ pub(crate) struct Config {
|
|||
pub(crate) observability: ObservabilityConfig,
|
||||
#[serde(default)]
|
||||
pub(crate) turn: TurnConfig,
|
||||
#[serde(default = "rate_limiting::default_rate_limit")]
|
||||
pub(crate) rate_limiting:
|
||||
HashMap<rate_limiting::Endpoint, rate_limiting::Config>,
|
||||
|
||||
pub(crate) emergency_password: Option<String>,
|
||||
}
|
||||
|
|
|
|||
93
src/config/rate_limiting.rs
Normal file
93
src/config/rate_limiting.rs
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
use std::{collections::HashMap, num::NonZeroU64};
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(
|
||||
Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) enum Endpoint {
|
||||
Registration,
|
||||
Login,
|
||||
RegistrationTokenValidity,
|
||||
Message,
|
||||
Join,
|
||||
Invite,
|
||||
Knock,
|
||||
CreateMedia,
|
||||
Transaction,
|
||||
FederatedJoin,
|
||||
FederatedInvite,
|
||||
FederatedKnock,
|
||||
|
||||
#[default]
|
||||
Global,
|
||||
|
||||
#[serde(untagged)]
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct Config {
|
||||
#[serde(default = "default_non_zero", flatten)]
|
||||
pub(crate) timeframe: Timeframe,
|
||||
#[serde(default = "default_non_zero")]
|
||||
pub(crate) burst_capacity: NonZeroU64,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub(crate) fn delay_tolerance(&self) -> u64 {
|
||||
self.timeframe.emission_interval_ns() * self.burst_capacity.get()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
timeframe: Timeframe::PerSecond(NonZeroU64::MIN),
|
||||
burst_capacity: NonZeroU64::MIN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
pub(crate) enum Timeframe {
|
||||
PerSecond(NonZeroU64),
|
||||
PerMinute(NonZeroU64),
|
||||
PerHour(NonZeroU64),
|
||||
PerDay(NonZeroU64),
|
||||
}
|
||||
|
||||
impl Timeframe {
|
||||
pub(crate) fn time_window_ms(&self) -> u64 {
|
||||
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
|
||||
|
||||
let s = match self {
|
||||
PerSecond(_) => 1_u64,
|
||||
PerMinute(_) => 60,
|
||||
PerHour(_) => 60 * 60,
|
||||
PerDay(_) => 60 * 60 * 24,
|
||||
};
|
||||
|
||||
s.checked_mul(1_000_000_000).expect("time window overflow")
|
||||
}
|
||||
|
||||
pub(crate) fn emission_interval_ns(&self) -> u64 {
|
||||
use Timeframe::{PerDay, PerHour, PerMinute, PerSecond};
|
||||
|
||||
let tf = match self {
|
||||
PerSecond(tf) | PerMinute(tf) | PerHour(tf) | PerDay(tf) => tf,
|
||||
};
|
||||
|
||||
self.time_window_ms() / tf.get()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_non_zero() -> NonZeroU64 {
|
||||
NonZeroU64::MIN
|
||||
}
|
||||
pub(super) fn default_rate_limit() -> HashMap<Endpoint, Config> {
|
||||
HashMap::from_iter([(Endpoint::default(), Config::default())])
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@ pub(crate) mod key_backups;
|
|||
pub(crate) mod media;
|
||||
pub(crate) mod pdu;
|
||||
pub(crate) mod pusher;
|
||||
pub(crate) mod rate_limiting;
|
||||
pub(crate) mod rooms;
|
||||
pub(crate) mod sending;
|
||||
pub(crate) mod transaction_ids;
|
||||
|
|
@ -36,6 +37,7 @@ pub(crate) struct Services {
|
|||
pub(crate) key_backups: key_backups::Service,
|
||||
pub(crate) media: media::Service,
|
||||
pub(crate) sending: Arc<sending::Service>,
|
||||
pub(crate) rate_limiting: rate_limiting::Service,
|
||||
}
|
||||
|
||||
impl Services {
|
||||
|
|
@ -121,6 +123,10 @@ impl Services {
|
|||
db,
|
||||
},
|
||||
sending: sending::Service::new(db, &config),
|
||||
rate_limiting: rate_limiting::Service::build(
|
||||
services().globals.config.rate_limiting.clone(),
|
||||
quanta::Clock::new(),
|
||||
),
|
||||
|
||||
globals: globals::Service::new(db, config, reload_handles)?,
|
||||
})
|
||||
|
|
|
|||
413
src/service/rate_limiting.rs
Normal file
413
src/service/rate_limiting.rs
Normal file
|
|
@ -0,0 +1,413 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
hash::Hash,
|
||||
net::{AddrParseError, IpAddr},
|
||||
str::FromStr,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use quanta::Clock;
|
||||
use ruma::{
|
||||
api::{IncomingRequest, Metadata},
|
||||
OwnedServerName, OwnedUserId,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::rate_limiting::{self, Endpoint},
|
||||
Result,
|
||||
};
|
||||
|
||||
pub(crate) struct Service {
|
||||
clock: Clock,
|
||||
store: DashMap<(Target, Endpoint), AtomicU64>,
|
||||
config: HashMap<Endpoint, rate_limiting::Config>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub(crate) fn build(
|
||||
config: HashMap<Endpoint, rate_limiting::Config>,
|
||||
clock: Clock,
|
||||
) -> Self {
|
||||
Self {
|
||||
clock,
|
||||
store: DashMap::new(),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn update_or_reject<IR: IncomingRequest>(
|
||||
&self,
|
||||
target: Target,
|
||||
) -> Result<(), Duration> {
|
||||
let now_ns = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||
|
||||
let endpoint = Endpoint::from(IR::METADATA);
|
||||
|
||||
let config = self.config.get(&endpoint).cloned().unwrap_or_default();
|
||||
|
||||
let (emission_interval_ns, delay_tolerance) =
|
||||
(config.timeframe.emission_interval_ns(), config.delay_tolerance());
|
||||
|
||||
let entry = self.store.entry((target, endpoint)).or_insert_with(|| {
|
||||
AtomicU64::new(now_ns.saturating_sub(delay_tolerance))
|
||||
});
|
||||
|
||||
let theoretical_arrival_time = entry.load(Ordering::Acquire);
|
||||
|
||||
let new_theoretical_arrival_time =
|
||||
theoretical_arrival_time.max(now_ns) + emission_interval_ns;
|
||||
|
||||
if (new_theoretical_arrival_time - now_ns) <= delay_tolerance {
|
||||
entry.store(new_theoretical_arrival_time, Ordering::Release);
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Duration::from_nanos(
|
||||
theoretical_arrival_time.saturating_sub(now_ns),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub(crate) enum Target {
|
||||
User(OwnedUserId),
|
||||
// Server endpoints should be rate-limited on a server and room basis
|
||||
Server(OwnedServerName),
|
||||
Appservice(String),
|
||||
Ip(IpAddr),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) struct XForwardedFor {
|
||||
pub(crate) ip: Option<IpAddr>,
|
||||
pub(crate) proxies: Vec<IpAddr>,
|
||||
}
|
||||
|
||||
impl FromStr for XForwardedFor {
|
||||
type Err = AddrParseError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
let ips = s.split(',').try_fold(Vec::new(), |mut ips, part| {
|
||||
let trimmed = part.trim();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
Ok(ips)
|
||||
} else {
|
||||
match IpAddr::from_str(trimmed) {
|
||||
Ok(ip) => {
|
||||
ips.push(ip);
|
||||
|
||||
Ok(ips)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
let (ip, proxies) = if ips.is_empty() {
|
||||
(None, Vec::new())
|
||||
} else {
|
||||
(Some(ips[0]), ips[1..].to_vec())
|
||||
};
|
||||
|
||||
Ok(XForwardedFor {
|
||||
ip,
|
||||
proxies,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Metadata> for Endpoint {
|
||||
fn from(metadata: Metadata) -> Self {
|
||||
use ruma::api::{
|
||||
client::{
|
||||
account::{check_registration_token_validity, register},
|
||||
knock::knock_room,
|
||||
media::{create_content, create_content_async},
|
||||
membership::{
|
||||
invite_user, join_room_by_id, join_room_by_id_or_alias,
|
||||
},
|
||||
message::send_message_event,
|
||||
session::login,
|
||||
state::send_state_event,
|
||||
},
|
||||
federation::{
|
||||
knock::send_knock,
|
||||
membership::{create_invite, create_join_event},
|
||||
transactions::send_transaction_message,
|
||||
},
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
match metadata {
|
||||
register::v3::Request::METADATA => Endpoint::Registration,
|
||||
login::v3::Request::METADATA => Endpoint::Login,
|
||||
check_registration_token_validity::v1::Request::METADATA => {
|
||||
Endpoint::RegistrationTokenValidity
|
||||
}
|
||||
send_message_event::v3::Request::METADATA
|
||||
| send_state_event::v3::Request::METADATA => Endpoint::Message,
|
||||
join_room_by_id::v3::Request::METADATA
|
||||
| join_room_by_id_or_alias::v3::Request::METADATA => Endpoint::Join,
|
||||
invite_user::v3::Request::METADATA => Endpoint::Invite,
|
||||
create_content::v3::Request::METADATA
|
||||
| create_content_async::v3::Request::METADATA => {
|
||||
Endpoint::CreateMedia
|
||||
}
|
||||
send_transaction_message::v1::Request::METADATA => {
|
||||
Endpoint::Transaction
|
||||
}
|
||||
create_join_event::v1::Request::METADATA
|
||||
| create_join_event::v2::Request::METADATA => {
|
||||
Endpoint::FederatedJoin
|
||||
}
|
||||
create_invite::v1::Request::METADATA
|
||||
| create_invite::v2::Request::METADATA => Endpoint::FederatedInvite,
|
||||
send_knock::v1::Request::METADATA => Endpoint::FederatedKnock,
|
||||
knock_room::v3::Request::METADATA => Endpoint::Knock,
|
||||
_ => Self::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::Ipv4Addr, num::NonZeroU64};
|
||||
|
||||
use quanta::Clock;
|
||||
use ruma::api::client::{account::register, session::login};
|
||||
|
||||
use super::*;
|
||||
use crate::config::rate_limiting::{Config, Timeframe};
|
||||
|
||||
struct LoginRequest;
|
||||
|
||||
impl IncomingRequest for LoginRequest {
|
||||
type EndpointError =
|
||||
<login::v3::Request as IncomingRequest>::EndpointError;
|
||||
type OutgoingResponse =
|
||||
<login::v3::Request as IncomingRequest>::OutgoingResponse;
|
||||
|
||||
const METADATA: Metadata =
|
||||
<login::v3::Request as IncomingRequest>::METADATA;
|
||||
|
||||
fn try_from_http_request<B, S>(
|
||||
_: http::Request<B>,
|
||||
_: &[S],
|
||||
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
|
||||
where
|
||||
B: AsRef<[u8]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
struct RegisterRequest;
|
||||
|
||||
impl IncomingRequest for RegisterRequest {
|
||||
type EndpointError =
|
||||
<register::v3::Request as IncomingRequest>::EndpointError;
|
||||
type OutgoingResponse =
|
||||
<register::v3::Request as IncomingRequest>::OutgoingResponse;
|
||||
|
||||
const METADATA: Metadata =
|
||||
<register::v3::Request as IncomingRequest>::METADATA;
|
||||
|
||||
fn try_from_http_request<B, S>(
|
||||
_: http::Request<B>,
|
||||
_: &[S],
|
||||
) -> std::result::Result<Self, ruma::api::error::FromHttpRequestError>
|
||||
where
|
||||
B: AsRef<[u8]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn test_config(emission_interval_ns: u64, burst_capacity: u64) -> Config {
|
||||
Config {
|
||||
// Directly use nanoseconds without conversion
|
||||
timeframe: Timeframe::PerSecond(
|
||||
NonZeroU64::new(1_000_000_000 / emission_interval_ns).unwrap(),
|
||||
),
|
||||
burst_capacity: NonZeroU64::new(burst_capacity).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_rate_limiting() {
|
||||
let (clock, mock) = Clock::mock();
|
||||
|
||||
let config = test_config(1_000_000_000, 1);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||
clock,
|
||||
);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
assert!(service
|
||||
.update_or_reject::<LoginRequest>(target.clone())
|
||||
.is_err());
|
||||
|
||||
mock.increment(1_000_000_000);
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_burst_capacity() {
|
||||
let (clock, mock) = Clock::mock();
|
||||
|
||||
let config = test_config(1_000_000_000, 3);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||
clock,
|
||||
);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
assert!(service
|
||||
.update_or_reject::<LoginRequest>(target.clone())
|
||||
.is_err());
|
||||
|
||||
mock.increment(1_000_000_000);
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_target_types() {
|
||||
let config = test_config(1_000_000_000, 1);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||
Clock::new(),
|
||||
);
|
||||
|
||||
let targets = vec![
|
||||
Target::User(OwnedUserId::try_from("@user:example.com").unwrap()),
|
||||
Target::Server(OwnedServerName::try_from("example.com").unwrap()),
|
||||
Target::Appservice("my_appservice".to_owned()),
|
||||
Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
|
||||
];
|
||||
|
||||
for target in targets {
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_endpoints() {
|
||||
let config = test_config(1_000_000_000, 1);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([
|
||||
(Endpoint::Login, config.clone()),
|
||||
(Endpoint::Registration, config),
|
||||
]),
|
||||
Clock::new(),
|
||||
);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
service.update_or_reject::<RegisterRequest>(target.clone()).unwrap();
|
||||
|
||||
assert!(service
|
||||
.update_or_reject::<LoginRequest>(target.clone())
|
||||
.is_err());
|
||||
|
||||
assert!(service.update_or_reject::<RegisterRequest>(target).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let (clock, mock) = Clock::mock();
|
||||
|
||||
let service = Service::build(HashMap::new(), clock);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
assert!(service
|
||||
.update_or_reject::<LoginRequest>(target.clone())
|
||||
.is_err());
|
||||
|
||||
mock.increment(1_000_000_000);
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_high_rate() {
|
||||
let (clock, mock) = Clock::mock();
|
||||
|
||||
let config = test_config(1_000, 10);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||
clock,
|
||||
);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
for _ in 0..10 {
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
}
|
||||
|
||||
let value = mock.value();
|
||||
mock.decrement(value);
|
||||
|
||||
assert!(service.update_or_reject::<LoginRequest>(target).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_returned_delay() {
|
||||
let (clock, mock) = Clock::mock();
|
||||
|
||||
let config = test_config(1_000_000_000, 1);
|
||||
|
||||
let service = Service::build(
|
||||
HashMap::from_iter([(Endpoint::Login, config)]),
|
||||
clock,
|
||||
);
|
||||
|
||||
let target = Target::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
|
||||
|
||||
service.update_or_reject::<LoginRequest>(target.clone()).unwrap();
|
||||
|
||||
if let Err(delay) =
|
||||
service.update_or_reject::<LoginRequest>(target.clone())
|
||||
{
|
||||
assert_eq!(delay, Duration::from_secs(1));
|
||||
} else {
|
||||
panic!("Expected rate limit error");
|
||||
}
|
||||
|
||||
mock.increment(500_000_000);
|
||||
|
||||
if let Err(delay) = service.update_or_reject::<LoginRequest>(target) {
|
||||
assert_eq!(delay, Duration::from_millis(500));
|
||||
} else {
|
||||
panic!("Expected rate limit error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,4 +8,4 @@ Error: failed to validate configuration
|
|||
|
|
||||
1 | some_name = "example.com"
|
||||
| ^^^^^^^^^
|
||||
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `emergency_password`
|
||||
unknown field `some_name`, expected one of `conduit_compat`, `listen`, `tls`, `server_name`, `server_discovery`, `database`, `media`, `federation`, `cache`, `cleanup_second_interval`, `max_request_size`, `allow_registration`, `registration_token`, `allow_encryption`, `allow_room_creation`, `default_room_version`, `proxy`, `jwt_secret`, `observability`, `turn`, `rate_limiting`, `emergency_password`
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue