diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index c4422079..40a2ccb7 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -17,7 +17,8 @@ use ruma::{ federation, }, serde::Raw, - OneTimeKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, + OneTimeKeyAlgorithm, OwnedDeviceId, OwnedServerName, OwnedUserId, + ServerName, UserId, }; use serde_json::json; use tracing::debug; @@ -385,80 +386,17 @@ pub(crate) async fn get_keys_helper bool>( let mut failures = BTreeMap::new(); - let back_off = |id| async { - match services().globals.bad_query_ratelimiter.write().await.entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); - } - } - }; - let mut futures: FuturesUnordered<_> = get_over_federation .into_iter() - .map(|(server, vec)| async move { - if let Some((time, tries)) = services() - .globals - .bad_query_ratelimiter - .read() - .await - .get(server) - { - // Exponential backoff - let mut min_elapsed_duration = - Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if let Some(remaining) = - min_elapsed_duration.checked_sub(time.elapsed()) - { - debug!(%server, %tries, ?remaining, "Backing off from server"); - return ( - server, - Err(Error::BadServerResponse( - "bad query, still backing off", - )), - ); - } - } - - let mut device_keys_input_fed = BTreeMap::new(); - for (user_id, keys) in vec { - device_keys_input_fed.insert(user_id.to_owned(), keys.clone()); - } - // TODO: switch .and_then(|result| result) to .flatten() when stable - // - ( - server, - tokio::time::timeout( - Duration::from_secs(25), - services().sending.send_federation_request( - server, - federation::keys::get_keys::v1::Request { - device_keys: device_keys_input_fed, - }, - ), - ) - .await - .map_err(|_e| Error::BadServerResponse("Query took too long")) - .and_then(|result| result), - ) + .map(|(server, keys)| async move { + (server, request_keys_from(server, keys).await) }) .collect(); while let Some((server, response)) = futures.next().await { - let response = match response { - Ok(response) => response, - Err(error) => { - back_off(server.to_owned()).await; - debug!(%server, %error, "remote device key query failed"); - failures.insert(server.to_string(), json!({})); - continue; - } + let Ok(response) = response else { + failures.insert(server.to_string(), json!({})); + continue; }; for (user, masterkey) in response.master_keys { @@ -511,6 +449,84 @@ pub(crate) async fn get_keys_helper bool>( }) } +/// Returns `Err` if key requests to the server are being backed off due to +/// previous errors. +async fn check_key_requests_back_off(server: &ServerName) -> Result<()> { + if let Some((time, tries)) = + services().globals.bad_query_ratelimiter.read().await.get(server) + { + // Exponential backoff + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if let Some(remaining) = + min_elapsed_duration.checked_sub(time.elapsed()) + { + debug!(%server, %tries, ?remaining, "Backing off from server"); + return Err(Error::BadServerResponse( + "bad query, still backing off", + )); + } + } + Ok(()) +} + +/// Backs off future remote device key requests to a server after a failure. +async fn back_off_key_requests(server: OwnedServerName) { + match services().globals.bad_query_ratelimiter.write().await.entry(server) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + } + } +} + +/// Requests device keys from a remote server, unless the server is in backoff. +/// +/// Updates backoff state depending on the result of the request. +async fn request_keys_from( + server: &ServerName, + keys: Vec<(&UserId, &Vec)>, +) -> Result { + let result = request_keys_from_inner(server, keys).await; + if let Err(error) = &result { + debug!(%server, %error, "remote device key query failed"); + back_off_key_requests(server.to_owned()).await; + } + result +} + +async fn request_keys_from_inner( + server: &ServerName, + keys: Vec<(&UserId, &Vec)>, +) -> Result { + check_key_requests_back_off(server).await?; + + let mut device_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in keys { + device_keys_input_fed.insert(user_id.to_owned(), keys.clone()); + } + // TODO: switch .and_then(|result| result) to .flatten() when stable + // + tokio::time::timeout( + Duration::from_secs(25), + services().sending.send_federation_request( + server, + federation::keys::get_keys::v1::Request { + device_keys: device_keys_input_fed, + }, + ), + ) + .await + .map_err(|_e| Error::BadServerResponse("Query took too long")) + .and_then(|result| result) +} + fn add_unsigned_device_display_name( keys: &mut Raw, metadata: ruma::api::client::device::Device,