From 34ccb2cd06448cb398b7042908b02b55ebab29ed Mon Sep 17 00:00:00 2001 From: Lambda Date: Sun, 23 Jun 2024 18:39:56 +0000 Subject: [PATCH] Use TokenSet for roomid_mutex_state --- src/api/client_server/membership.rs | 160 +++++++++----------------- src/api/client_server/message.rs | 25 ++-- src/api/client_server/profile.rs | 46 ++------ src/api/client_server/redact.rs | 22 ++-- src/api/client_server/room.rs | 97 +++++----------- src/api/client_server/state.rs | 18 +-- src/api/server_server.rs | 20 ++-- src/database/key_value/rooms/state.rs | 18 +-- src/service/admin.rs | 85 +++++--------- src/service/globals.rs | 20 +++- src/service/rooms/event_handler.rs | 23 ++-- src/service/rooms/state.rs | 35 +++--- src/service/rooms/state/data.rs | 15 +-- src/service/rooms/state_accessor.rs | 16 +-- src/service/rooms/timeline.rs | 72 +++++------- 15 files changed, 243 insertions(+), 429 deletions(-) diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 0a2725f2..993aedd1 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -226,16 +226,11 @@ pub(crate) async fn kick_user_route( event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; services() .rooms @@ -250,12 +245,11 @@ pub(crate) async fn kick_user_route( redacts: None, }, sender_user, - &body.room_id, - &state_lock, + &room_token, ) .await?; - drop(state_lock); + drop(room_token); Ok(Ra(kick_user::v3::Response::new())) } @@ -302,16 +296,11 @@ pub(crate) async fn ban_user_route( }, )?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; services() .rooms @@ -326,12 +315,11 @@ pub(crate) async fn ban_user_route( redacts: None, }, sender_user, - &body.room_id, - &state_lock, + &room_token, ) .await?; - drop(state_lock); + drop(room_token); Ok(Ra(ban_user::v3::Response::new())) } @@ -365,16 +353,11 @@ pub(crate) async fn unban_user_route( event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; services() .rooms @@ -389,12 +372,11 @@ pub(crate) async fn unban_user_route( redacts: None, }, sender_user, - &body.room_id, - &state_lock, + &room_token, ) .await?; - drop(state_lock); + drop(room_token); Ok(Ra(unban_user::v3::Response::new())) } @@ -528,16 +510,11 @@ async fn join_room_by_id_helper( ) -> Result { let sender_user = sender_user.expect("user is authenticated"); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; // Ask a remote server if we are not participating in this room if services() @@ -601,10 +578,9 @@ async fn join_room_by_id_helper( { if user.server_name() == services().globals.server_name() && services().rooms.state_accessor.user_can_invite( - room_id, + &room_token, &user, sender_user, - &state_lock, ) { auth_user = Some(user); @@ -641,8 +617,7 @@ async fn join_room_by_id_helper( redacts: None, }, sender_user, - room_id, - &state_lock, + &room_token, ) .await { @@ -797,7 +772,7 @@ async fn join_room_by_id_helper( )); } - drop(state_lock); + drop(room_token); let pub_key_map = RwLock::new(BTreeMap::new()); services() .rooms @@ -1113,13 +1088,7 @@ async fn join_room_by_id_helper( services() .rooms .state - .force_state( - room_id, - statehash_before_join, - new, - removed, - &state_lock, - ) + .force_state(&room_token, statehash_before_join, new, removed) .await?; info!("Updating joined counts for new room"); @@ -1139,7 +1108,7 @@ async fn join_room_by_id_helper( &parsed_join_pdu, join_event, vec![(*parsed_join_pdu.event_id).to_owned()], - &state_lock, + &room_token, ) .await?; @@ -1147,11 +1116,10 @@ async fn join_room_by_id_helper( // We set the room state after inserting the pdu, so that we never have // a moment in time where events in the current room state do // not exist - services().rooms.state.set_room_state( - room_id, - statehash_after_join, - &state_lock, - )?; + services() + .rooms + .state + .set_room_state(&room_token, statehash_after_join)?; } Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) @@ -1306,16 +1274,11 @@ pub(crate) async fn invite_helper( ) -> Result<()> { if user_id.server_name() != services().globals.server_name() { let (pdu, pdu_json, invite_room_state) = { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, @@ -1339,14 +1302,13 @@ pub(crate) async fn invite_helper( redacts: None, }, sender_user, - room_id, - &state_lock, + &room_token, )?; let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; - drop(state_lock); + drop(room_token); (pdu, pdu_json, invite_room_state) }; @@ -1447,16 +1409,11 @@ pub(crate) async fn invite_helper( )); } - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; services() .rooms @@ -1480,12 +1437,11 @@ pub(crate) async fn invite_helper( redacts: None, }, sender_user, - room_id, - &state_lock, + &room_token, ) .await?; - drop(state_lock); + drop(room_token); Ok(()) } @@ -1530,16 +1486,11 @@ pub(crate) async fn leave_room( .state_cache .server_in_room(services().globals.server_name(), room_id)? { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; let member_event = services().rooms.state_accessor.room_state_get( room_id, @@ -1587,8 +1538,7 @@ pub(crate) async fn leave_room( redacts: None, }, user_id, - room_id, - &state_lock, + &room_token, ) .await?; } else { diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 162c2d63..6f29205f 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, -}; +use std::collections::{BTreeMap, HashSet}; use ruma::{ api::client::{ @@ -32,16 +29,11 @@ pub(crate) async fn send_message_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; // Forbid m.room.encrypted if encryption is disabled if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() @@ -104,8 +96,7 @@ pub(crate) async fn send_message_event_route( redacts: None, }, sender_user, - &body.room_id, - &state_lock, + &room_token, ) .await?; @@ -116,7 +107,7 @@ pub(crate) async fn send_message_event_route( event_id.as_bytes(), )?; - drop(state_lock); + drop(room_token); Ok(Ra(send_message_event::v3::Response::new((*event_id).to_owned()))) } diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index 8557baab..8a84f743 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ruma::{ api::{ client::{ @@ -81,26 +79,16 @@ pub(crate) async fn set_displayname_route( .collect(); for (pdu_builder, room_id) in all_rooms_joined { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.clone()) + .await; if let Err(error) = services() .rooms .timeline - .build_and_append_pdu( - pdu_builder, - sender_user, - &room_id, - &state_lock, - ) + .build_and_append_pdu(pdu_builder, sender_user, &room_token) .await { warn!(%error, "failed to add PDU"); @@ -203,26 +191,16 @@ pub(crate) async fn set_avatar_url_route( .collect(); for (pdu_builder, room_id) in all_joined_rooms { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.clone()) + .await; if let Err(error) = services() .rooms .timeline - .build_and_append_pdu( - pdu_builder, - sender_user, - &room_id, - &state_lock, - ) + .build_and_append_pdu(pdu_builder, sender_user, &room_token) .await { warn!(%error, "failed to add PDU"); diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 5604bfb0..46af6d8c 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, @@ -19,16 +17,11 @@ pub(crate) async fn redact_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; let event_id = services() .rooms @@ -46,12 +39,11 @@ pub(crate) async fn redact_event_route( redacts: Some(body.event_id.into()), }, sender_user, - &body.room_id, - &state_lock, + &room_token, ) .await?; - drop(state_lock); + drop(room_token); let event_id = (*event_id).to_owned(); Ok(Ra(redact_event::v3::Response { diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 76ed6b9d..c02526b3 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,4 +1,4 @@ -use std::{cmp::max, collections::BTreeMap, sync::Arc}; +use std::{cmp::max, collections::BTreeMap}; use ruma::{ api::client::{ @@ -63,16 +63,8 @@ pub(crate) async fn create_room_route( services().rooms.short.get_or_create_shortroomid(&room_id)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = + services().globals.roomid_mutex_state.lock_key(room_id.clone()).await; if !services().globals.allow_room_creation() && body.appservice_info.is_none() @@ -250,8 +242,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -278,8 +269,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -338,8 +328,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -361,8 +350,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; } @@ -389,8 +377,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -410,8 +397,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -434,8 +420,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -463,12 +448,7 @@ pub(crate) async fn create_room_route( services() .rooms .timeline - .build_and_append_pdu( - pdu_builder, - sender_user, - &room_id, - &state_lock, - ) + .build_and_append_pdu(pdu_builder, sender_user, &room_token) .await?; } @@ -489,8 +469,7 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; } @@ -511,14 +490,13 @@ pub(crate) async fn create_room_route( redacts: None, }, sender_user, - &room_id, - &state_lock, + &room_token, ) .await?; } // 8. Events implied by invite (and TODO: invite_3pid) - drop(state_lock); + drop(room_token); for user_id in &body.invite { if let Err(error) = invite_helper(sender_user, user_id, &room_id, None, body.is_direct) @@ -635,16 +613,11 @@ pub(crate) async fn upgrade_room_route( let replacement_room = RoomId::new(services().globals.server_name()); services().rooms.short.get_or_create_shortroomid(&replacement_room)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let original_state_lock = mutex_state.lock().await; + let original_room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; // Send a m.room.tombstone event to the old room to indicate that it is not // intended to be used any further Fail if the sender does not have the @@ -665,22 +638,16 @@ pub(crate) async fn upgrade_room_route( redacts: None, }, sender_user, - &body.room_id, - &original_state_lock, + &original_room_token, ) .await?; // Change lock to replacement room - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(replacement_room.clone()) - .or_default(), - ); - let replacement_state_lock = mutex_state.lock().await; + let replacement_room_token = services() + .globals + .roomid_mutex_state + .lock_key(replacement_room.clone()) + .await; // Get the old room creation event let mut create_event_content = serde_json::from_str::( @@ -777,8 +744,7 @@ pub(crate) async fn upgrade_room_route( redacts: None, }, sender_user, - &replacement_room, - &replacement_state_lock, + &replacement_room_token, ) .await?; @@ -805,8 +771,7 @@ pub(crate) async fn upgrade_room_route( redacts: None, }, sender_user, - &replacement_room, - &replacement_state_lock, + &replacement_room_token, ) .await?; @@ -847,8 +812,7 @@ pub(crate) async fn upgrade_room_route( redacts: None, }, sender_user, - &replacement_room, - &replacement_state_lock, + &replacement_room_token, ) .await?; } @@ -910,13 +874,10 @@ pub(crate) async fn upgrade_room_route( redacts: None, }, sender_user, - &body.room_id, - &original_state_lock, + &original_room_token, ) .await?; - drop(replacement_state_lock); - // Return the replacement room id Ok(Ra(upgrade_room::v3::Response { replacement_room, diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 9da197b9..b0e5809a 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -240,16 +240,11 @@ async fn send_state_event_for_key_helper( } } - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; let event_id = services() .rooms @@ -264,8 +259,7 @@ async fn send_state_event_for_key_helper( redacts: None, }, sender_user, - room_id, - &state_lock, + &room_token, ) .await?; diff --git a/src/api/server_server.rs b/src/api/server_server.rs index e2841acb..2208c693 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1452,16 +1452,11 @@ pub(crate) async fn create_join_event_template_route( .event_handler .acl_check(sender_servername, &body.room_id)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(body.room_id.clone()) + .await; // TODO: Grapevine does not implement restricted join rules yet, we always // reject @@ -1529,11 +1524,10 @@ pub(crate) async fn create_join_event_template_route( redacts: None, }, &body.user_id, - &body.room_id, - &state_lock, + &room_token, )?; - drop(state_lock); + drop(room_token); pdu_json.remove("event_id"); diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 78863a75..a3246878 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,9 +1,13 @@ use std::{collections::HashSet, sync::Arc}; -use ruma::{EventId, OwnedEventId, RoomId}; -use tokio::sync::MutexGuard; +use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId}; -use crate::{database::KeyValueDatabase, service, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, + service::{self, globals::marker}, + utils::{self, on_demand_hashmap::KeyToken}, + Error, Result, +}; impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { @@ -21,10 +25,8 @@ impl service::rooms::state::Data for KeyValueDatabase { fn set_room_state( &self, - room_id: &RoomId, + room_id: &KeyToken, new_shortstatehash: u64, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; @@ -71,10 +73,8 @@ impl service::rooms::state::Data for KeyValueDatabase { fn set_forward_extremities( &self, - room_id: &RoomId, + room_id: &KeyToken, event_ids: Vec, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/service/admin.rs b/src/service/admin.rs index b1edf1a9..e18a2039 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -265,17 +265,11 @@ impl Service { } }; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(grapevine_room.clone()) - .or_default(), - ); - - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(grapevine_room.clone()) + .await; services() .rooms @@ -290,8 +284,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - grapevine_room, - &state_lock, + &room_token, ) .await .unwrap(); @@ -1220,16 +1213,11 @@ impl Service { services().rooms.short.get_or_create_shortroomid(&room_id)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.clone()) + .await; services().users.create(&services().globals.admin_bot_user_id, None)?; @@ -1268,8 +1256,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1298,8 +1285,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1323,8 +1309,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1344,8 +1329,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1367,8 +1351,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1388,8 +1371,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1411,8 +1393,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1434,8 +1415,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1458,8 +1438,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1495,16 +1474,11 @@ impl Service { displayname: String, ) -> Result<()> { if let Some(room_id) = services().admin.get_admin_room()? { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.clone()) + .await; // Use the server user to grant the new admin's power level // Invite and join the real user @@ -1530,8 +1504,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; services() @@ -1556,8 +1529,7 @@ impl Service { redacts: None, }, user_id, - &room_id, - &state_lock, + &room_token, ) .await?; @@ -1585,8 +1557,7 @@ impl Service { redacts: None, }, &services().globals.admin_bot_user_id, - &room_id, - &state_lock, + &room_token, ) .await?; } diff --git a/src/service/globals.rs b/src/service/globals.rs index bf82d4e1..e8fe08d1 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -33,8 +33,11 @@ use tracing::{error, Instrument}; use trust_dns_resolver::TokioAsyncResolver; use crate::{ - api::server_server::FedDest, observability::FilterReloadHandles, services, - utils::on_demand_hashmap::OnDemandHashMap, Config, Error, Result, + api::server_server::FedDest, + observability::FilterReloadHandles, + services, + utils::on_demand_hashmap::{OnDemandHashMap, TokenSet}, + Config, Error, Result, }; type WellKnownMap = HashMap; @@ -42,6 +45,15 @@ type TlsNameMap = HashMap, u16)>; // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32); +// Markers for +// [`Service::roomid_mutex_state`]/[`Service::roomid_mutex_insert`]/ +// [`Service::roomid_mutex_federation`] +pub(crate) mod marker { + pub(crate) enum State {} + pub(crate) enum Insert {} + pub(crate) enum Federation {} +} + pub(crate) struct Service { pub(crate) db: &'static dyn Data, pub(crate) reload_handles: FilterReloadHandles, @@ -69,7 +81,7 @@ pub(crate) struct Service { OnDemandHashMap, pub(crate) roomid_mutex_insert: RwLock>>>, - pub(crate) roomid_mutex_state: RwLock>>>, + pub(crate) roomid_mutex_state: TokenSet, // this lock will be held longer pub(crate) roomid_mutex_federation: @@ -266,7 +278,7 @@ impl Service { servername_ratelimiter: OnDemandHashMap::new( "servername_ratelimiter".to_owned(), ), - roomid_mutex_state: RwLock::new(HashMap::new()), + roomid_mutex_state: TokenSet::new("roomid_mutex_state".to_owned()), roomid_mutex_insert: RwLock::new(HashMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()), roomid_federationhandletime: RwLock::new(HashMap::new()), diff --git a/src/service/rooms/event_handler.rs b/src/service/rooms/event_handler.rs index be1c3276..3817fb90 100644 --- a/src/service/rooms/event_handler.rs +++ b/src/service/rooms/event_handler.rs @@ -992,16 +992,11 @@ impl Service { // 13. Use state resolution to find new room state // We start looking at current room state now, so lets lock the room - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let room_token = services() + .globals + .roomid_mutex_state + .lock_key(room_id.to_owned()) + .await; // Now we calculate the set of extremities this room has after the // incoming event has been applied. We start with the previous @@ -1070,7 +1065,7 @@ impl Service { services() .rooms .state - .force_state(room_id, sstatehash, new, removed, &state_lock) + .force_state(&room_token, sstatehash, new, removed) .await?; } @@ -1088,7 +1083,7 @@ impl Service { extremities.iter().map(|e| (**e).to_owned()).collect(), state_ids_compressed, soft_fail, - &state_lock, + &room_token, ) .await?; @@ -1121,14 +1116,14 @@ impl Service { extremities.iter().map(|e| (**e).to_owned()).collect(), state_ids_compressed, soft_fail, - &state_lock, + &room_token, ) .await?; debug!("Appended incoming pdu"); // Event has passed all auth/stateres checks - drop(state_lock); + drop(room_token); Ok(pdu_id) } diff --git a/src/service/rooms/state.rs b/src/service/rooms/state.rs index 709c70d2..f2ab7b36 100644 --- a/src/service/rooms/state.rs +++ b/src/service/rooms/state.rs @@ -13,16 +13,18 @@ use ruma::{ }, serde::Raw, state_res::{self, StateMap}, - EventId, OwnedEventId, RoomId, RoomVersionId, UserId, + EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; use serde::Deserialize; -use tokio::sync::MutexGuard; use tracing::warn; use super::state_compressor::CompressedStateEvent; use crate::{ + service::globals::marker, services, - utils::{calculate_hash, debug_slice_truncated}, + utils::{ + calculate_hash, debug_slice_truncated, on_demand_hashmap::KeyToken, + }, Error, PduEvent, Result, }; @@ -32,20 +34,13 @@ pub(crate) struct Service { impl Service { /// Set the room to the given statehash and update caches. - #[tracing::instrument(skip( - self, - statediffnew, - _statediffremoved, - state_lock - ))] + #[tracing::instrument(skip(self, statediffnew, _statediffremoved))] pub(crate) async fn force_state( &self, - room_id: &RoomId, + room_id: &KeyToken, shortstatehash: u64, statediffnew: Arc>, _statediffremoved: Arc>, - // Take mutex guard to make sure users get the room state mutex - state_lock: &MutexGuard<'_, ()>, ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { services() @@ -116,7 +111,7 @@ impl Service { services().rooms.state_cache.update_joined_count(room_id)?; - self.db.set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash)?; Ok(()) } @@ -325,12 +320,10 @@ impl Service { #[tracing::instrument(skip(self))] pub(crate) fn set_room_state( &self, - room_id: &RoomId, + room_id: &KeyToken, shortstatehash: u64, - // Take mutex guard to make sure users get the room state mutex - mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) + self.db.set_room_state(room_id, shortstatehash) } /// Returns the room's version. @@ -383,17 +376,15 @@ impl Service { } #[tracing::instrument( - skip(self, event_ids, state_lock), + skip(self, event_ids), fields(event_ids = debug_slice_truncated(&event_ids, 5)), )] pub(crate) fn set_forward_extremities( &self, - room_id: &RoomId, + room_id: &KeyToken, event_ids: Vec, - // Take mutex guard to make sure users get the room state mutex - state_lock: &MutexGuard<'_, ()>, ) -> Result<()> { - self.db.set_forward_extremities(room_id, event_ids, state_lock) + self.db.set_forward_extremities(room_id, event_ids) } /// This fetches auth events from the current state. diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 2ed50def..6f15b004 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,9 +1,10 @@ use std::{collections::HashSet, sync::Arc}; -use ruma::{EventId, OwnedEventId, RoomId}; -use tokio::sync::MutexGuard; +use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId}; -use crate::Result; +use crate::{ + service::globals::marker, utils::on_demand_hashmap::KeyToken, Result, +}; pub(crate) trait Data: Send + Sync { /// Returns the last state hash key added to the db for the given room. @@ -12,10 +13,8 @@ pub(crate) trait Data: Send + Sync { /// Set the state hash to a new version, but does not update `state_cache`. fn set_room_state( &self, - room_id: &RoomId, + room_id: &KeyToken, new_shortstatehash: u64, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()>; /// Associates a state with an event. @@ -34,9 +33,7 @@ pub(crate) trait Data: Send + Sync { /// Replace the forward extremities of the room. fn set_forward_extremities( &self, - room_id: &RoomId, + room_id: &KeyToken, event_ids: Vec, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()>; } diff --git a/src/service/rooms/state_accessor.rs b/src/service/rooms/state_accessor.rs index 99502700..b9e9f702 100644 --- a/src/service/rooms/state_accessor.rs +++ b/src/service/rooms/state_accessor.rs @@ -20,17 +20,18 @@ use ruma::{ StateEventType, }, state_res::Event, - EventId, JsOption, OwnedServerName, OwnedUserId, RoomId, ServerName, - UserId, + EventId, JsOption, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + ServerName, UserId, }; use serde_json::value::to_raw_value; -use tokio::sync::MutexGuard; use tracing::{error, warn}; use crate::{ observability::{FoundIn, Lookup, METRICS}, - service::pdu::PduBuilder, - services, Error, PduEvent, Result, + service::{globals::marker, pdu::PduBuilder}, + services, + utils::on_demand_hashmap::KeyToken, + Error, PduEvent, Result, }; pub(crate) struct Service { @@ -390,10 +391,9 @@ impl Service { #[tracing::instrument(skip(self), ret(level = "trace"))] pub(crate) fn user_can_invite( &self, - room_id: &RoomId, + room_id: &KeyToken, sender: &UserId, target_user: &UserId, - state_lock: &MutexGuard<'_, ()>, ) -> bool { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) @@ -410,7 +410,7 @@ impl Service { services() .rooms .timeline - .create_hash_and_sign_event(new_event, sender, room_id, state_lock) + .create_hash_and_sign_event(new_event, sender, room_id) .is_ok() } diff --git a/src/service/rooms/timeline.rs b/src/service/rooms/timeline.rs index f850057d..cc8e42d3 100644 --- a/src/service/rooms/timeline.rs +++ b/src/service/rooms/timeline.rs @@ -22,11 +22,12 @@ use ruma::{ push::{Action, Ruleset, Tweak}, state_res::{self, Event, RoomVersion}, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, - OwnedEventId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, + OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, + ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tokio::sync::{MutexGuard, RwLock}; +use tokio::sync::RwLock; use tracing::{error, info, warn}; use super::state_compressor::CompressedStateEvent; @@ -34,10 +35,12 @@ use crate::{ api::server_server, service::{ appservice::NamespaceRegex, - globals::SigningKeys, + globals::{marker, SigningKeys}, pdu::{EventHash, PduBuilder}, }, - services, utils, Error, PduEvent, Result, + services, + utils::{self, on_demand_hashmap::KeyToken}, + Error, PduEvent, Result, }; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] @@ -197,9 +200,10 @@ impl Service { pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, leaves: Vec, - // Take mutex guard to make sure users get the room state mutex - state_lock: &MutexGuard<'_, ()>, + room_id: &KeyToken, ) -> Result> { + assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed"); + let shortroomid = services() .rooms .short @@ -253,11 +257,7 @@ impl Service { .rooms .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services().rooms.state.set_forward_extremities( - &pdu.room_id, - leaves, - state_lock, - )?; + services().rooms.state.set_forward_extremities(room_id, leaves)?; let mutex_insert = Arc::clone( services() @@ -669,9 +669,7 @@ impl Service { &self, pdu_builder: PduBuilder, sender: &UserId, - room_id: &RoomId, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &MutexGuard<'_, ()>, + room_id: &KeyToken, ) -> Result<(PduEvent, CanonicalJsonObject)> { let PduBuilder { event_type, @@ -702,7 +700,7 @@ impl Service { } else { Err(Error::InconsistentRoomState( "non-create event for room of unknown version", - room_id.to_owned(), + (**room_id).clone(), )) } })?; @@ -751,7 +749,7 @@ impl Service { let mut pdu = PduEvent { event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), + room_id: (**room_id).clone(), sender: sender.to_owned(), origin_server_ts: utils::millis_since_unix_epoch() .try_into() @@ -855,24 +853,18 @@ impl Service { /// Creates a new persisted data unit and adds it to a room. This function /// takes a roomid_mutex_state, meaning that only this function is able /// to mutate the room state. - #[tracing::instrument(skip(self, state_lock))] + #[tracing::instrument(skip(self))] pub(crate) async fn build_and_append_pdu( &self, pdu_builder: PduBuilder, sender: &UserId, - room_id: &RoomId, - // Take mutex guard to make sure users get the room state mutex - state_lock: &MutexGuard<'_, ()>, + room_id: &KeyToken, ) -> Result> { - let (pdu, pdu_json) = self.create_hash_and_sign_event( - pdu_builder, - sender, - room_id, - state_lock, - )?; + let (pdu, pdu_json) = + self.create_hash_and_sign_event(pdu_builder, sender, room_id)?; if let Some(admin_room) = services().admin.get_admin_room()? { - if admin_room == room_id { + if admin_room == **room_id { match pdu.event_type() { TimelineEventType::RoomEncryption => { warn!("Encryption is not allowed in the admins room"); @@ -1052,18 +1044,14 @@ impl Service { // Since this PDU references all pdu_leaves we can update the // leaves of the room vec![(*pdu.event_id).to_owned()], - state_lock, + room_id, ) .await?; // We set the room state after inserting the pdu, so that we never have // a moment in time where events in the current room state do // not exist - services().rooms.state.set_room_state( - room_id, - statehashid, - state_lock, - )?; + services().rooms.state.set_room_state(room_id, statehashid)?; let mut servers: HashSet = services() .rooms @@ -1103,9 +1091,10 @@ impl Service { new_room_leaves: Vec, state_ids_compressed: Arc>, soft_fail: bool, - // Take mutex guard to make sure users get the room state mutex - state_lock: &MutexGuard<'_, ()>, + room_id: &KeyToken, ) -> Result>> { + assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed"); + // We append to state before appending the pdu, so we don't have a // moment in time with the pdu without it's state. This is okay // because append_pdu can't fail. @@ -1119,19 +1108,18 @@ impl Service { services() .rooms .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services().rooms.state.set_forward_extremities( - &pdu.room_id, - new_room_leaves, - state_lock, - )?; + .mark_as_referenced(room_id, &pdu.prev_events)?; + services() + .rooms + .state + .set_forward_extremities(room_id, new_room_leaves)?; return Ok(None); } let pdu_id = services() .rooms .timeline - .append_pdu(pdu, pdu_json, new_room_leaves, state_lock) + .append_pdu(pdu, pdu_json, new_room_leaves, room_id) .await?; Ok(Some(pdu_id))