diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index c8523117..2a17ab3b 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,5 +1,6 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, + future::IntoFuture, time::Duration, }; @@ -398,12 +399,15 @@ pub(crate) async fn get_keys_helper bool>( server, tokio::time::timeout( Duration::from_secs(25), - services().sending.send_federation_request( - server, - federation::keys::get_keys::v1::Request { - device_keys: device_keys_input_fed, - }, - ), + services() + .sending + .send_federation_request( + server, + federation::keys::get_keys::v1::Request { + device_keys: device_keys_input_fed, + }, + ) + .into_future(), ) .await .map_err(|_e| Error::BadServerResponse("Query took too long")) diff --git a/src/service/sending.rs b/src/service/sending.rs index 07a72972..ac861af3 100644 --- a/src/service/sending.rs +++ b/src/service/sending.rs @@ -1,6 +1,8 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, + future::{Future, IntoFuture}, + pin::Pin, sync::Arc, time::{Duration, Instant}, }; @@ -116,6 +118,12 @@ pub(crate) struct RequestData { requester_span: Span, } +#[must_use = "The request builder must be awaited for the request to be sent"] +pub(crate) struct SendFederationRequestBuilder<'a, T> { + destination: &'a ServerName, + request: T, +} + pub(crate) struct Service { db: &'static dyn Data, @@ -667,57 +675,18 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, request))] - pub(crate) async fn send_federation_request( + // Allowed because `SendFederationRequestBuilder::into_future` uses + // `services()` + #[allow(clippy::unused_self)] + pub(crate) fn send_federation_request<'a, T>( &self, - destination: &ServerName, + destination: &'a ServerName, request: T, - ) -> Result - where - T: OutgoingRequest + Debug, - { - debug!("Waiting for permit"); - let permit = self.maximum_requests.acquire().await; - debug!("Got permit"); - - let backoff_guard = - services().server_backoff.server_ready(destination)?; - - let response = tokio::time::timeout( - Duration::from_secs(2 * 60), - server_server::send_request( - destination, - request, - LogRequestError::Yes, - AllowLoopbackRequests::No, - ), - ) - .await - .map_err(|_| { - warn!("Timeout waiting for server response"); - Error::BadServerResponse("Timeout waiting for server response") - }) - .and_then(|result| result); - drop(permit); - - match &response { - Err(Error::Federation(_, error)) => { - if error.error_kind().is_some() { - // Other errors may occur during normal operation with a - // healthy server, so don't increment the failure counter. - backoff_guard.soft_failure(); - } else { - // The error wasn't in the expected format for matrix API - // responses. This almost certainly indicates the server - // is unhealthy or offline. - backoff_guard.hard_failure(); - } - } - Err(_) => backoff_guard.hard_failure(), - Ok(_) => backoff_guard.success(), + ) -> SendFederationRequestBuilder<'a, T> { + SendFederationRequestBuilder { + destination, + request, } - - response } /// Sends a request to an appservice @@ -745,6 +714,70 @@ impl Service { } } +impl<'a, T> IntoFuture for SendFederationRequestBuilder<'a, T> +where + T: OutgoingRequest + Send + Debug + 'a, + T::IncomingResponse: Send, +{ + // TODO: get rid of the Box once impl_trait_in_assoc_type is stable + // + type IntoFuture = Pin + Send + 'a>>; + type Output = Result; + + #[tracing::instrument( + name = "send_federation_request", + skip(self), + fields(destination = %self.destination) + )] + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { + debug!("Waiting for permit"); + let permit = services().sending.maximum_requests.acquire().await; + debug!("Got permit"); + + let backoff_guard = + services().server_backoff.server_ready(self.destination)?; + + let response = tokio::time::timeout( + Duration::from_secs(2 * 60), + server_server::send_request( + self.destination, + self.request, + LogRequestError::Yes, + AllowLoopbackRequests::No, + ), + ) + .await + .map_err(|_| { + warn!("Timeout waiting for server response"); + Error::BadServerResponse("Timeout waiting for server response") + }) + .and_then(|result| result); + drop(permit); + + match &response { + Err(Error::Federation(_, error)) => { + if error.error_kind().is_some() { + // Other errors may occur during normal operation with a + // healthy server, so don't increment the failure + // counter. + backoff_guard.soft_failure(); + } else { + // The error wasn't in the expected format for matrix + // API responses. This almost certainly indicates the + // server is unhealthy or offline. + backoff_guard.hard_failure(); + } + } + Err(_) => backoff_guard.hard_failure(), + Ok(_) => backoff_guard.success(), + } + + response + }) + } +} + #[tracing::instrument(skip(events))] async fn handle_appservice_event( id: &str,