From 080fe5af422b9b019fe9a2ac3d83ec51145e113f Mon Sep 17 00:00:00 2001
From: avdb13
Date: Mon, 19 Aug 2024 09:51:52 +0200
Subject: [PATCH] feat: configurable federation backoff

---
 src/config.rs                 | 59 ++++++++++++++++++++++++++++++++++-
 src/service/server_backoff.rs | 44 ++++++++------------------
 2 files changed, 71 insertions(+), 32 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index a440a455..b9171a88 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -12,7 +12,7 @@ use ruma::{
     api::federation::discovery::OldVerifyKey, OwnedServerName, OwnedServerSigningKeyId,
     RoomVersionId,
 };
-use serde::Deserialize;
+use serde::{Deserialize, Deserializer};
 use strum::{Display, EnumIter, IntoEnumIterator};
 
 use crate::error;
@@ -407,6 +407,7 @@ pub(crate) struct FederationConfig {
     pub(crate) max_fetch_prev_events: u16,
     pub(crate) max_concurrent_requests: u16,
     pub(crate) old_verify_keys: BTreeMap<OwnedServerSigningKeyId, OldVerifyKey>,
+    pub(crate) backoff: BackoffConfig,
 }
 
 impl Default for FederationConfig {
@@ -420,6 +421,44 @@ impl Default for FederationConfig {
             max_fetch_prev_events: 100,
             max_concurrent_requests: 100,
             old_verify_keys: BTreeMap::new(),
+            backoff: BackoffConfig::default(),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(default)]
+pub(crate) struct BackoffConfig {
+    /// Minimum number of consecutive failures for a server before starting to
+    /// delay requests.
+    pub(crate) failure_threshold: u8,
+
+    /// Initial delay between requests in seconds, after the number of
+    /// consecutive failures to a server first exceeds the threshold.
+    pub(crate) base_delay: f64,
+
+    /// Factor to increase delay by after each additional consecutive failure.
+    pub(crate) multiplier: f64,
+
+    /// Maximum delay between requests to a server in seconds.
+    pub(crate) max_delay: f64,
+
+    /// Range of random multipliers to request delay.
+    #[serde(deserialize_with = "deserialize_jitter_range")]
+    pub(crate) jitter_range: std::ops::Range<f64>,
+}
+
+impl Default for BackoffConfig {
+    fn default() -> Self {
+        // After the first 3 consecutive failed requests, increase delay
+        // exponentially from 5s to 24h over the next 24 failures. It takes an
+        // average of 4.3 days of failures to reach the maximum delay of 24h.
+        Self {
+            failure_threshold: 3,
+            base_delay: 5.0,
+            multiplier: 1.5,
+            max_delay: 60.0 * 60.0 * 24.0,
+            jitter_range: 0.5..1.5,
         }
     }
 }
@@ -482,6 +521,24 @@ pub(crate) fn default_default_room_version() -> RoomVersionId {
     RoomVersionId::V10
 }
 
+fn deserialize_jitter_range<'de, D>(
+    deserializer: D,
+) -> Result<std::ops::Range<f64>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let s = String::deserialize(deserializer)?;
+    let Some((a, b)) = s.split_once("..") else {
+        return Err(serde::de::Error::custom(crate::Error::bad_config(
+            "invalid jitter range",
+        )));
+    };
+
+    a.parse()
+        .and_then(|a| b.parse().map(|b| a..b))
+        .map_err(serde::de::Error::custom)
+}
+
 /// Search default locations for a configuration file
 ///
 /// If one isn't found, the list of tried paths is returned.
diff --git a/src/service/server_backoff.rs b/src/service/server_backoff.rs
index e276a16a..aa1f3857 100644
--- a/src/service/server_backoff.rs
+++ b/src/service/server_backoff.rs
@@ -1,6 +1,5 @@
 use std::{
     collections::HashMap,
-    ops::Range,
     sync::{Arc, RwLock},
     time::{Duration, Instant},
 };
@@ -9,7 +8,7 @@ use rand::{thread_rng, Rng};
 use ruma::{OwnedServerName, ServerName};
 use tracing::{debug, instrument};
 
-use crate::{Error, Result};
+use crate::{services, Error, Result};
 
 /// Service to handle backing off requests to offline servers.
 ///
@@ -39,25 +38,6 @@ pub(crate) struct Service {
     servers: RwLock<HashMap<OwnedServerName, Arc<RwLock<BackoffState>>>>,
 }
 
-// After the first 5 consecutive failed requests, increase delay
-// exponentially from 5s to 24h over the next 24 failures. It takes an
-// average of 4.3 days of failures to reach the maximum delay of 24h.
-
-// TODO: consider making these configurable
-
-/// Minimum number of consecutive failures for a server before starting to delay
-/// requests.
-const FAILURE_THRESHOLD: u8 = 5;
-/// Initial delay between requests after the number of consecutive failures
-/// to a server first exceeds [`FAILURE_THRESHOLD`].
-const BASE_DELAY: Duration = Duration::from_secs(5);
-/// Factor to increase delay by after each additional consecutive failure.
-const MULTIPLIER: f64 = 1.5;
-/// Maximum delay between requests to a server.
-const MAX_DELAY: Duration = Duration::from_secs(60 * 60 * 24);
-/// Range of random multipliers to request delay.
-const JITTER_RANGE: Range<f64> = 0.5..1.5;
-
 /// Guard to record the result of an attempted request to a server.
 ///
 /// If the request succeeds, call [`BackoffGuard::success`]. If the request
@@ -173,19 +153,18 @@ impl BackoffState {
     /// Returns the remaining time before ready to attempt another request to
     /// this server.
     fn remaining_delay(&self) -> Option<Duration> {
+        let config = &services().globals.config.federation.backoff;
+
         let last_failure = self.last_failure?;
-        if self.failure_count <= FAILURE_THRESHOLD {
+        if self.failure_count <= config.failure_threshold {
             return None;
         }
 
-        let excess_failure_count = self.failure_count - FAILURE_THRESHOLD;
-        // Converting to float is fine because we don't expect max_delay
-        // to be large enough that the loss of precision matters. The
-        // largest typical value is 24h, with a precision of 0.01ns.
-        let base_delay_secs = BASE_DELAY.as_secs_f64();
-        let max_delay_secs = MAX_DELAY.as_secs_f64();
-        let delay_secs = max_delay_secs.min(
-            base_delay_secs * MULTIPLIER.powi(i32::from(excess_failure_count)),
+        let excess_failure_count =
+            self.failure_count - config.failure_threshold;
+        let delay_secs = config.max_delay.min(
+            config.base_delay
+                * config.multiplier.powi(i32::from(excess_failure_count)),
         ) * self.jitter_coeff;
         let delay = Duration::from_secs_f64(delay_secs);
         delay.checked_sub(last_failure.elapsed())
@@ -214,11 +193,14 @@ impl BackoffGuard {
     /// a 404 from an endpoint that is not specced to return 404.
     #[instrument(skip(self))]
     pub(crate) fn hard_failure(self) {
+        let config = &services().globals.config.federation.backoff;
+
         let mut state = self.backoff.write().unwrap();
 
         if state.last_failure == self.last_failure {
             state.failure_count = state.failure_count.saturating_add(1);
-            state.jitter_coeff = thread_rng().gen_range(JITTER_RANGE);
+            state.jitter_coeff =
+                thread_rng().gen_range(config.jitter_range.clone());
             state.last_failure = Some(Instant::now());
 
             debug!(