Use TokenSet for roomid_mutex_state

This commit is contained in:
Lambda 2024-06-23 18:39:56 +00:00
parent 07b5233980
commit 34ccb2cd06
15 changed files with 243 additions and 429 deletions

View file

@ -265,17 +265,11 @@ impl Service {
}
};
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(grapevine_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(grapevine_room.clone())
.await;
services()
.rooms
@ -290,8 +284,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
grapevine_room,
&state_lock,
&room_token,
)
.await
.unwrap();
@ -1220,16 +1213,11 @@ impl Service {
services().rooms.short.get_or_create_shortroomid(&room_id)?;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
services().users.create(&services().globals.admin_bot_user_id, None)?;
@ -1268,8 +1256,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1298,8 +1285,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1323,8 +1309,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1344,8 +1329,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1367,8 +1351,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1388,8 +1371,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1411,8 +1393,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1434,8 +1415,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1458,8 +1438,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1495,16 +1474,11 @@ impl Service {
displayname: String,
) -> Result<()> {
if let Some(room_id) = services().admin.get_admin_room()? {
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
// Use the server user to grant the new admin's power level
// Invite and join the real user
@ -1530,8 +1504,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
services()
@ -1556,8 +1529,7 @@ impl Service {
redacts: None,
},
user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
@ -1585,8 +1557,7 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_id,
&state_lock,
&room_token,
)
.await?;
}

View file

@ -33,8 +33,11 @@ use tracing::{error, Instrument};
use trust_dns_resolver::TokioAsyncResolver;
use crate::{
api::server_server::FedDest, observability::FilterReloadHandles, services,
utils::on_demand_hashmap::OnDemandHashMap, Config, Error, Result,
api::server_server::FedDest,
observability::FilterReloadHandles,
services,
utils::on_demand_hashmap::{OnDemandHashMap, TokenSet},
Config, Error, Result,
};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
@ -42,6 +45,15 @@ type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
// Time if last failed try, number of failed tries
type RateLimitState = (Instant, u32);
// Markers for
// [`Service::roomid_mutex_state`]/[`Service::roomid_mutex_insert`]/
// [`Service::roomid_mutex_federation`]
pub(crate) mod marker {
pub(crate) enum State {}
pub(crate) enum Insert {}
pub(crate) enum Federation {}
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub(crate) reload_handles: FilterReloadHandles,
@ -69,7 +81,7 @@ pub(crate) struct Service {
OnDemandHashMap<OwnedServerName, Semaphore>,
pub(crate) roomid_mutex_insert:
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: TokenSet<OwnedRoomId, marker::State>,
// this lock will be held longer
pub(crate) roomid_mutex_federation:
@ -266,7 +278,7 @@ impl Service {
servername_ratelimiter: OnDemandHashMap::new(
"servername_ratelimiter".to_owned(),
),
roomid_mutex_state: RwLock::new(HashMap::new()),
roomid_mutex_state: TokenSet::new("roomid_mutex_state".to_owned()),
roomid_mutex_insert: RwLock::new(HashMap::new()),
roomid_mutex_federation: RwLock::new(HashMap::new()),
roomid_federationhandletime: RwLock::new(HashMap::new()),

View file

@ -992,16 +992,11 @@ impl Service {
// 13. Use state resolution to find new room state
// We start looking at current room state now, so lets lock the room
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
// Now we calculate the set of extremities this room has after the
// incoming event has been applied. We start with the previous
@ -1070,7 +1065,7 @@ impl Service {
services()
.rooms
.state
.force_state(room_id, sstatehash, new, removed, &state_lock)
.force_state(&room_token, sstatehash, new, removed)
.await?;
}
@ -1088,7 +1083,7 @@ impl Service {
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
&room_token,
)
.await?;
@ -1121,14 +1116,14 @@ impl Service {
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
&room_token,
)
.await?;
debug!("Appended incoming pdu");
// Event has passed all auth/stateres checks
drop(state_lock);
drop(room_token);
Ok(pdu_id)
}

View file

@ -13,16 +13,18 @@ use ruma::{
},
serde::Raw,
state_res::{self, StateMap},
EventId, OwnedEventId, RoomId, RoomVersionId, UserId,
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
};
use serde::Deserialize;
use tokio::sync::MutexGuard;
use tracing::warn;
use super::state_compressor::CompressedStateEvent;
use crate::{
service::globals::marker,
services,
utils::{calculate_hash, debug_slice_truncated},
utils::{
calculate_hash, debug_slice_truncated, on_demand_hashmap::KeyToken,
},
Error, PduEvent, Result,
};
@ -32,20 +34,13 @@ pub(crate) struct Service {
impl Service {
/// Set the room to the given statehash and update caches.
#[tracing::instrument(skip(
self,
statediffnew,
_statediffremoved,
state_lock
))]
#[tracing::instrument(skip(self, statediffnew, _statediffremoved))]
pub(crate) async fn force_state(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| {
services()
@ -116,7 +111,7 @@ impl Service {
services().rooms.state_cache.update_joined_count(room_id)?;
self.db.set_room_state(room_id, shortstatehash, state_lock)?;
self.db.set_room_state(room_id, shortstatehash)?;
Ok(())
}
@ -325,12 +320,10 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn set_room_state(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
shortstatehash: u64,
// Take mutex guard to make sure users get the room state mutex
mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
self.db.set_room_state(room_id, shortstatehash, mutex_lock)
self.db.set_room_state(room_id, shortstatehash)
}
/// Returns the room's version.
@ -383,17 +376,15 @@ impl Service {
}
#[tracing::instrument(
skip(self, event_ids, state_lock),
skip(self, event_ids),
fields(event_ids = debug_slice_truncated(&event_ids, 5)),
)]
pub(crate) fn set_forward_extremities(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
event_ids: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
self.db.set_forward_extremities(room_id, event_ids, state_lock)
self.db.set_forward_extremities(room_id, event_ids)
}
/// This fetches auth events from the current state.

View file

@ -1,9 +1,10 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId};
use crate::Result;
use crate::{
service::globals::marker, utils::on_demand_hashmap::KeyToken, Result,
};
pub(crate) trait Data: Send + Sync {
/// Returns the last state hash key added to the db for the given room.
@ -12,10 +13,8 @@ pub(crate) trait Data: Send + Sync {
/// Set the state hash to a new version, but does not update `state_cache`.
fn set_room_state(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
new_shortstatehash: u64,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()>;
/// Associates a state with an event.
@ -34,9 +33,7 @@ pub(crate) trait Data: Send + Sync {
/// Replace the forward extremities of the room.
fn set_forward_extremities(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
event_ids: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()>;
}

View file

@ -20,17 +20,18 @@ use ruma::{
StateEventType,
},
state_res::Event,
EventId, JsOption, OwnedServerName, OwnedUserId, RoomId, ServerName,
UserId,
EventId, JsOption, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::MutexGuard;
use tracing::{error, warn};
use crate::{
observability::{FoundIn, Lookup, METRICS},
service::pdu::PduBuilder,
services, Error, PduEvent, Result,
service::{globals::marker, pdu::PduBuilder},
services,
utils::on_demand_hashmap::KeyToken,
Error, PduEvent, Result,
};
pub(crate) struct Service {
@ -390,10 +391,9 @@ impl Service {
#[tracing::instrument(skip(self), ret(level = "trace"))]
pub(crate) fn user_can_invite(
&self,
room_id: &RoomId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
sender: &UserId,
target_user: &UserId,
state_lock: &MutexGuard<'_, ()>,
) -> bool {
let content =
to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite))
@ -410,7 +410,7 @@ impl Service {
services()
.rooms
.timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.create_hash_and_sign_event(new_event, sender, room_id)
.is_ok()
}

View file

@ -22,11 +22,12 @@ use ruma::{
push::{Action, Ruleset, Tweak},
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId,
OwnedEventId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId,
ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::{MutexGuard, RwLock};
use tokio::sync::RwLock;
use tracing::{error, info, warn};
use super::state_compressor::CompressedStateEvent;
@ -34,10 +35,12 @@ use crate::{
api::server_server,
service::{
appservice::NamespaceRegex,
globals::SigningKeys,
globals::{marker, SigningKeys},
pdu::{EventHash, PduBuilder},
},
services, utils, Error, PduEvent, Result,
services,
utils::{self, on_demand_hashmap::KeyToken},
Error, PduEvent, Result,
};
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
@ -197,9 +200,10 @@ impl Service {
pdu: &PduEvent,
mut pdu_json: CanonicalJsonObject,
leaves: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<Vec<u8>> {
assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed");
let shortroomid = services()
.rooms
.short
@ -253,11 +257,7 @@ impl Service {
.rooms
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services().rooms.state.set_forward_extremities(
&pdu.room_id,
leaves,
state_lock,
)?;
services().rooms.state.set_forward_extremities(room_id, leaves)?;
let mutex_insert = Arc::clone(
services()
@ -669,9 +669,7 @@ impl Service {
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<(PduEvent, CanonicalJsonObject)> {
let PduBuilder {
event_type,
@ -702,7 +700,7 @@ impl Service {
} else {
Err(Error::InconsistentRoomState(
"non-create event for room of unknown version",
room_id.to_owned(),
(**room_id).clone(),
))
}
})?;
@ -751,7 +749,7 @@ impl Service {
let mut pdu = PduEvent {
event_id: ruma::event_id!("$thiswillbefilledinlater").into(),
room_id: room_id.to_owned(),
room_id: (**room_id).clone(),
sender: sender.to_owned(),
origin_server_ts: utils::millis_since_unix_epoch()
.try_into()
@ -855,24 +853,18 @@ impl Service {
/// Creates a new persisted data unit and adds it to a room. This function
/// takes a roomid_mutex_state, meaning that only this function is able
/// to mutate the room state.
#[tracing::instrument(skip(self, state_lock))]
#[tracing::instrument(skip(self))]
pub(crate) async fn build_and_append_pdu(
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = self.create_hash_and_sign_event(
pdu_builder,
sender,
room_id,
state_lock,
)?;
let (pdu, pdu_json) =
self.create_hash_and_sign_event(pdu_builder, sender, room_id)?;
if let Some(admin_room) = services().admin.get_admin_room()? {
if admin_room == room_id {
if admin_room == **room_id {
match pdu.event_type() {
TimelineEventType::RoomEncryption => {
warn!("Encryption is not allowed in the admins room");
@ -1052,18 +1044,14 @@ impl Service {
// Since this PDU references all pdu_leaves we can update the
// leaves of the room
vec![(*pdu.event_id).to_owned()],
state_lock,
room_id,
)
.await?;
// We set the room state after inserting the pdu, so that we never have
// a moment in time where events in the current room state do
// not exist
services().rooms.state.set_room_state(
room_id,
statehashid,
state_lock,
)?;
services().rooms.state.set_room_state(room_id, statehashid)?;
let mut servers: HashSet<OwnedServerName> = services()
.rooms
@ -1103,9 +1091,10 @@ impl Service {
new_room_leaves: Vec<OwnedEventId>,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
soft_fail: bool,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<Option<Vec<u8>>> {
assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed");
// We append to state before appending the pdu, so we don't have a
// moment in time with the pdu without it's state. This is okay
// because append_pdu can't fail.
@ -1119,19 +1108,18 @@ impl Service {
services()
.rooms
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services().rooms.state.set_forward_extremities(
&pdu.room_id,
new_room_leaves,
state_lock,
)?;
.mark_as_referenced(room_id, &pdu.prev_events)?;
services()
.rooms
.state
.set_forward_extremities(room_id, new_room_leaves)?;
return Ok(None);
}
let pdu_id = services()
.rooms
.timeline
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
.append_pdu(pdu, pdu_json, new_room_leaves, room_id)
.await?;
Ok(Some(pdu_id))