diff --git a/src/service.rs b/src/service.rs
index 85e7a302..3bca91fc 100644
--- a/src/service.rs
+++ b/src/service.rs
@@ -12,6 +12,7 @@ pub(crate) mod pdu;
 pub(crate) mod pusher;
 pub(crate) mod rooms;
 pub(crate) mod sending;
+pub(crate) mod server_backoff;
 pub(crate) mod transaction_ids;
 pub(crate) mod uiaa;
 pub(crate) mod users;
@@ -35,6 +36,7 @@ pub(crate) struct Services {
     pub(crate) globals: globals::Service,
     pub(crate) key_backups: key_backups::Service,
     pub(crate) media: media::Service,
+    pub(crate) server_backoff: Arc<server_backoff::Service>,
     pub(crate) sending: Arc<sending::Service>,
 }
 
@@ -120,6 +122,7 @@ impl Services {
             media: media::Service {
                 db,
             },
+            server_backoff: server_backoff::Service::build(),
             sending: sending::Service::new(db, &config),
             globals: globals::Service::new(db, config, reload_handles)?,
diff --git a/src/service/server_backoff.rs b/src/service/server_backoff.rs
new file mode 100644
index 00000000..e276a16a
--- /dev/null
+++ b/src/service/server_backoff.rs
@@ -0,0 +1,243 @@
+use std::{
+    collections::HashMap,
+    ops::Range,
+    sync::{Arc, RwLock},
+    time::{Duration, Instant},
+};
+
+use rand::{thread_rng, Rng};
+use ruma::{OwnedServerName, ServerName};
+use tracing::{debug, instrument};
+
+use crate::{Error, Result};
+
+/// Service to handle backing off requests to offline servers.
+///
+/// Matrix is full of servers that are either temporarily or permanently
+/// offline. It's important not to flood offline servers with federation
+/// traffic, since this can consume resources on both ends.
+///
+/// To limit traffic to offline servers, we track a global exponential backoff
+/// state for federation requests to each server name. This mechanism is *only*
+/// intended to handle offline servers. Rate limiting and backoff retries for
+/// specific requests have different considerations and need to be handled
+/// elsewhere.
+///
+/// Exponential backoff is typically used in a retry loop for a single request.
+/// Because the state of this backoff is global, and requests may be issued
+/// concurrently, we do a couple of unusual things:
+///
+/// First, we wait for a certain number of consecutive failed requests before
+/// we start delaying further requests. This is to avoid delaying requests to a
+/// server that is not offline but fails on a small fraction of requests.
+///
+/// Second, we only increment the failure counter once for every batch of
+/// concurrent requests, instead of on every failed request. This avoids
+/// rapidly increasing the counter, proportional to the rate of outgoing
+/// requests, when the server is only briefly offline.
+pub(crate) struct Service {
+    servers: RwLock<HashMap<OwnedServerName, Arc<RwLock<BackoffState>>>>,
+}
+
+// After the first 5 consecutive failed requests, increase delay
+// exponentially from 5s to 24h over the next 24 failures. It takes an
+// average of 4.3 days of failures to reach the maximum delay of 24h.
+
+// TODO: consider making these configurable
+
+/// Minimum number of consecutive failures for a server before starting to
+/// delay requests.
+const FAILURE_THRESHOLD: u8 = 5;
+/// Initial delay between requests after the number of consecutive failures
+/// to a server first exceeds [`FAILURE_THRESHOLD`].
+const BASE_DELAY: Duration = Duration::from_secs(5);
+/// Factor to increase delay by after each additional consecutive failure.
+const MULTIPLIER: f64 = 1.5;
+/// Maximum delay between requests to a server.
+const MAX_DELAY: Duration = Duration::from_secs(60 * 60 * 24);
+/// Range of random multipliers to request delay.
+const JITTER_RANGE: Range<f64> = 0.5..1.5;
+
+/// Guard to record the result of an attempted request to a server.
+///
+/// If the request succeeds, call [`BackoffGuard::success`]. If the request
+/// fails in a way that indicates the server is unavailable, call
+/// [`BackoffGuard::hard_failure`]. If the request fails in a way that doesn't
+/// necessarily indicate that the server is unavailable, call
+/// [`BackoffGuard::soft_failure`]. Note that this choice is security-sensitive.
+/// If an attacker is able to trigger hard failures for an online server, they
+/// can cause us to incorrectly mark it as offline and block outgoing requests
+/// to it.
+#[must_use]
+pub(crate) struct BackoffGuard {
+    backoff: Arc<RwLock<BackoffState>>,
+    /// Store the last failure timestamp observed when this request started. If
+    /// there was another failure recorded since the request started, do not
+    /// increment the failure count. This ensures that only one failure will
+    /// be recorded for every batch of concurrent requests, as discussed in
+    /// the documentation of [`Service`].
+    last_failure: Option<Instant>,
+}
+
+/// State of exponential backoff for a specific server.
+#[derive(Clone, Debug)]
+struct BackoffState {
+    server_name: OwnedServerName,
+
+    /// Count of consecutive failed requests to this server.
+    failure_count: u8,
+    /// Timestamp of the last failed request to this server.
+    last_failure: Option<Instant>,
+    /// Random multiplier to request delay.
+    ///
+    /// This is updated to a new random value after each batch of concurrent
+    /// requests containing a failure.
+    jitter_coeff: f64,
+}
+
+impl Service {
+    pub(crate) fn build() -> Arc<Service> {
+        Arc::new(Service {
+            servers: RwLock::default(),
+        })
+    }
+
+    /// If ready to attempt another request to a server, returns a guard to
+    /// record the result.
+    ///
+    /// If still in the backoff period for this server, returns `Err`.
+    #[instrument(skip(self))]
+    pub(crate) fn server_ready(
+        &self,
+        server_name: &ServerName,
+    ) -> Result<BackoffGuard> {
+        let state = self.server_state(server_name);
+
+        let last_failure = {
+            let state_lock = state.read().unwrap();
+
+            if let Some(remaining_delay) = state_lock.remaining_delay() {
+                debug!(failures = %state_lock.failure_count, ?remaining_delay, "backing off from server");
+                return Err(Error::ServerBackoff {
+                    server: server_name.to_owned(),
+                    remaining_delay,
+                });
+            }
+
+            state_lock.last_failure
+        };
+
+        Ok(BackoffGuard {
+            backoff: state,
+            last_failure,
+        })
+    }
+
+    fn server_state(
+        &self,
+        server_name: &ServerName,
+    ) -> Arc<RwLock<BackoffState>> {
+        let servers = self.servers.read().unwrap();
+        if let Some(state) = servers.get(server_name) {
+            Arc::clone(state)
+        } else {
+            drop(servers);
+            let mut servers = self.servers.write().unwrap();
+
+            // We have to check again because it's possible for another thread
+            // to write in between us dropping the read lock and taking the
+            // write lock.
+            if let Some(state) = servers.get(server_name) {
+                Arc::clone(state)
+            } else {
+                let state = Arc::new(RwLock::new(BackoffState::new(
+                    server_name.to_owned(),
+                )));
+                servers.insert(server_name.to_owned(), Arc::clone(&state));
+                state
+            }
+        }
+    }
+}
+
+impl BackoffState {
+    fn new(server_name: OwnedServerName) -> BackoffState {
+        BackoffState {
+            server_name,
+            failure_count: 0,
+            last_failure: None,
+            jitter_coeff: 0.0,
+        }
+    }
+
+    /// Returns the remaining time before ready to attempt another request to
+    /// this server.
+    fn remaining_delay(&self) -> Option<Duration> {
+        let last_failure = self.last_failure?;
+        if self.failure_count <= FAILURE_THRESHOLD {
+            return None;
+        }
+
+        let excess_failure_count = self.failure_count - FAILURE_THRESHOLD;
+        // Converting to float is fine because we don't expect max_delay
+        // to be large enough that the loss of precision matters. The
+        // largest typical value is 24h, with a precision of 0.01ns.
+        let base_delay_secs = BASE_DELAY.as_secs_f64();
+        let max_delay_secs = MAX_DELAY.as_secs_f64();
+        let delay_secs = max_delay_secs.min(
+            base_delay_secs * MULTIPLIER.powi(i32::from(excess_failure_count)),
+        ) * self.jitter_coeff;
+        let delay = Duration::from_secs_f64(delay_secs);
+        delay.checked_sub(last_failure.elapsed())
+    }
+}
+
+impl BackoffGuard {
+    /// Record a successful request.
+    #[instrument(skip(self))]
+    pub(crate) fn success(self) {
+        let mut state = self.backoff.write().unwrap();
+
+        if state.failure_count != 0 {
+            debug!(
+                server_name = %&state.server_name,
+                "successful request to server, resetting failure count"
+            );
+        }
+
+        state.failure_count = 0;
+    }
+
+    /// Record a failed request indicating that the server may be unavailable.
+    ///
+    /// Examples of failures in this category are a timeout, a 500 status, or
+    /// a 404 from an endpoint that is not specced to return 404.
+    #[instrument(skip(self))]
+    pub(crate) fn hard_failure(self) {
+        let mut state = self.backoff.write().unwrap();
+
+        if state.last_failure == self.last_failure {
+            state.failure_count = state.failure_count.saturating_add(1);
+            state.jitter_coeff = thread_rng().gen_range(JITTER_RANGE);
+            state.last_failure = Some(Instant::now());
+
+            debug!(
+                server_name = %state.server_name,
+                failure_count = state.failure_count,
+                "hard failure sending request to server, incrementing failure count"
+            );
+        }
+    }
+
+    /// Record a request that failed, but where the failure is likely to occur
+    /// in normal operation even if the server is not unavailable.
+    ///
+    /// An example of a failure in this category is 404 from querying a user
+    /// profile. This might occur if the server no longer exists, but will also
+    /// occur if the userid doesn't exist.
+    // Taking `self` here is intentional, to allow callers to destroy the guard
+    // without triggering the `must_use` warning.
+    #[allow(clippy::unused_self)]
+    #[instrument(skip(self))]
+    pub(crate) fn soft_failure(self) {}
+}
diff --git a/src/utils/error.rs b/src/utils/error.rs
index 49b5ab15..d4e1e417 100644
--- a/src/utils/error.rs
+++ b/src/utils/error.rs
@@ -1,4 +1,4 @@
-use std::convert::Infallible;
+use std::{convert::Infallible, time::Duration};
 
 use http::StatusCode;
 use ruma::{
@@ -84,6 +84,13 @@ pub(crate) enum Error {
     UnsupportedRoomVersion(ruma::RoomVersionId),
     #[error("{0} in {1}")]
     InconsistentRoomState(&'static str, ruma::OwnedRoomId),
+    #[error(
+        "backing off requests to {server} for the next {remaining_delay:?}"
+    )]
+    ServerBackoff {
+        server: OwnedServerName,
+        remaining_delay: Duration,
+    },
 }
 
 impl Error {
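
Usage illustration only, not part of this diff: a caller on the federation send path would ask the service for a guard before sending, then report the outcome through exactly one of the guard methods. The `services()` accessor, `do_request`, `Response`, and the error-classification helpers below are hypothetical stand-ins; only `server_ready`, `success`, `hard_failure`, and `soft_failure` come from the code above.

// Hypothetical caller sketch. Assumes some `services()` accessor exposing the
// `server_backoff` field added above, and a `do_request` function standing in
// for the actual federation send.
async fn send_with_backoff(server: &ServerName) -> Result<Response> {
    // Bail out early while still inside the backoff window; the returned
    // Error::ServerBackoff carries the remaining delay.
    let guard = services().server_backoff.server_ready(server)?;

    match do_request(server).await {
        Ok(response) => {
            // Resets the consecutive-failure counter for this server.
            guard.success();
            Ok(response)
        }
        // Timeouts and 5xx responses suggest the server is unreachable.
        Err(e) if e.is_timeout() || e.is_server_error() => {
            guard.hard_failure();
            Err(e)
        }
        // Failures that can happen in normal operation (e.g. an expected 404)
        // must not count towards marking the server offline.
        Err(e) => {
            guard.soft_failure();
            Err(e)
        }
    }
}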