From 07b52339809aae4616c1db8e10a6a5962b160cee Mon Sep 17 00:00:00 2001 From: Lambda Date: Mon, 27 May 2024 19:27:08 +0000 Subject: [PATCH] Use OnDemandHashMap for servername_ratelimiter This way, semaphores are actually cleaned up eventually. --- src/service/globals.rs | 8 +++++--- src/service/rooms/event_handler.rs | 22 +++------------------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/src/service/globals.rs b/src/service/globals.rs index 02ae4e93..bf82d4e1 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -34,7 +34,7 @@ use trust_dns_resolver::TokioAsyncResolver; use crate::{ api::server_server::FedDest, observability::FilterReloadHandles, services, - Config, Error, Result, + utils::on_demand_hashmap::OnDemandHashMap, Config, Error, Result, }; type WellKnownMap = HashMap; @@ -66,7 +66,7 @@ pub(crate) struct Service { pub(crate) bad_query_ratelimiter: Arc>>, pub(crate) servername_ratelimiter: - Arc>>>, + OnDemandHashMap, pub(crate) roomid_mutex_insert: RwLock>>>, pub(crate) roomid_mutex_state: RwLock>>>, @@ -263,7 +263,9 @@ impl Service { bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + servername_ratelimiter: OnDemandHashMap::new( + "servername_ratelimiter".to_owned(), + ), roomid_mutex_state: RwLock::new(HashMap::new()), roomid_mutex_insert: RwLock::new(HashMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()), diff --git a/src/service/rooms/event_handler.rs b/src/service/rooms/event_handler.rs index 427e65e2..be1c3276 100644 --- a/src/service/rooms/event_handler.rs +++ b/src/service/rooms/event_handler.rs @@ -1935,25 +1935,9 @@ impl Service { let permit = services() .globals .servername_ratelimiter - .read() - .await - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); - - let permit = if let Some(p) = permit { - p - } else { - let mut write = - services().globals.servername_ratelimiter.write().await; - let s = Arc::clone( - write - .entry(origin.to_owned()) - .or_insert_with(|| Arc::new(Semaphore::new(1))), - ); - - s.acquire_owned() - } - .await; + .get_or_insert_with(origin.to_owned(), || Semaphore::new(1)) + .await; + let permit = permit.acquire().await; let back_off = |id| async { match services()