factor remote device key request logic into helper functions

This is pure code-motion, with no behavior changes. The new structure
will make it easier to fix the backoff behavior, and makes the code
somewhat less of a nightmare to follow.
This commit is contained in:
Olivia Lee 2024-12-01 15:57:06 -08:00
parent 9e3738d330
commit 79cedccdb6
No known key found for this signature in database
GPG key ID: 54D568A15B9CD1F9

View file

@ -17,7 +17,8 @@ use ruma::{
federation,
},
serde::Raw,
OneTimeKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
OneTimeKeyAlgorithm, OwnedDeviceId, OwnedServerName, OwnedUserId,
ServerName, UserId,
};
use serde_json::json;
use tracing::debug;
@ -385,80 +386,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new();
let back_off = |id| async {
match services().globals.bad_query_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1 + 1);
}
}
};
let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter()
.map(|(server, vec)| async move {
if let Some((time, tries)) = services()
.globals
.bad_query_ratelimiter
.read()
.await
.get(server)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if let Some(remaining) =
min_elapsed_duration.checked_sub(time.elapsed())
{
debug!(%server, %tries, ?remaining, "Backing off from server");
return (
server,
Err(Error::BadServerResponse(
"bad query, still backing off",
)),
);
}
}
let mut device_keys_input_fed = BTreeMap::new();
for (user_id, keys) in vec {
device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
}
// TODO: switch .and_then(|result| result) to .flatten() when stable
// <https://github.com/rust-lang/rust/issues/70142>
(
server,
tokio::time::timeout(
Duration::from_secs(25),
services().sending.send_federation_request(
server,
federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed,
},
),
)
.await
.map_err(|_e| Error::BadServerResponse("Query took too long"))
.and_then(|result| result),
)
.map(|(server, keys)| async move {
(server, request_keys_from(server, keys).await)
})
.collect();
while let Some((server, response)) = futures.next().await {
let response = match response {
Ok(response) => response,
Err(error) => {
back_off(server.to_owned()).await;
debug!(%server, %error, "remote device key query failed");
failures.insert(server.to_string(), json!({}));
continue;
}
let Ok(response) = response else {
failures.insert(server.to_string(), json!({}));
continue;
};
for (user, masterkey) in response.master_keys {
@ -511,6 +449,84 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
})
}
/// Returns `Err` if key requests to the server are being backed off due to
/// previous errors.
async fn check_key_requests_back_off(server: &ServerName) -> Result<()> {
if let Some((time, tries)) =
services().globals.bad_query_ratelimiter.read().await.get(server)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if let Some(remaining) =
min_elapsed_duration.checked_sub(time.elapsed())
{
debug!(%server, %tries, ?remaining, "Backing off from server");
return Err(Error::BadServerResponse(
"bad query, still backing off",
));
}
}
Ok(())
}
/// Backs off future remote device key requests to a server after a failure.
async fn back_off_key_requests(server: OwnedServerName) {
match services().globals.bad_query_ratelimiter.write().await.entry(server) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1 + 1);
}
}
}
/// Requests device keys from a remote server, unless the server is in backoff.
///
/// Updates backoff state depending on the result of the request.
async fn request_keys_from(
server: &ServerName,
keys: Vec<(&UserId, &Vec<OwnedDeviceId>)>,
) -> Result<federation::keys::get_keys::v1::Response> {
let result = request_keys_from_inner(server, keys).await;
if let Err(error) = &result {
debug!(%server, %error, "remote device key query failed");
back_off_key_requests(server.to_owned()).await;
}
result
}
async fn request_keys_from_inner(
server: &ServerName,
keys: Vec<(&UserId, &Vec<OwnedDeviceId>)>,
) -> Result<federation::keys::get_keys::v1::Response> {
check_key_requests_back_off(server).await?;
let mut device_keys_input_fed = BTreeMap::new();
for (user_id, keys) in keys {
device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
}
// TODO: switch .and_then(|result| result) to .flatten() when stable
// <https://github.com/rust-lang/rust/issues/70142>
tokio::time::timeout(
Duration::from_secs(25),
services().sending.send_federation_request(
server,
federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed,
},
),
)
.await
.map_err(|_e| Error::BadServerResponse("Query took too long"))
.and_then(|result| result)
}
fn add_unsigned_device_display_name(
keys: &mut Raw<ruma::encryption::DeviceKeys>,
metadata: ruma::api::client::device::Device,