use std::{ collections::{HashMap, HashSet}, iter, sync::Arc, }; use ruma::{ api::client::error::ErrorKind, events::{ room::{create::PreviousRoom, member::MembershipState}, AnyStrippedStateEvent, StateEventType, TimelineEventType, }, room::RoomType, serde::Raw, state_res::{self, StateMap}, EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; use serde::{de::DeserializeOwned, Deserialize}; use tracing::warn; use super::{short::ShortStateHash, state_compressor::CompressedStateEvent}; use crate::{ service::globals::marker, services, utils::{ calculate_hash, debug_slice_truncated, on_demand_hashmap::KeyToken, }, Error, PduEvent, Result, }; mod data; pub(crate) use data::Data; pub(crate) trait ExtractCreateContent: DeserializeOwned { type Extract; fn extract(self) -> Self::Extract; } /// Extract the `room_version` from an `m.room.create` event #[derive(Deserialize)] pub(crate) struct ExtractVersion { room_version: RoomVersionId, } impl ExtractCreateContent for ExtractVersion { type Extract = RoomVersionId; fn extract(self) -> Self::Extract { self.room_version } } /// Extract the `type` from an `m.room.create` event #[derive(Deserialize)] pub(crate) struct ExtractType { #[serde(rename = "type")] kind: Option, } impl ExtractCreateContent for ExtractType { type Extract = Option; fn extract(self) -> Self::Extract { self.kind } } #[derive(Deserialize)] pub(crate) struct ExtractPredecessor { predecessor: Option, } impl ExtractCreateContent for ExtractPredecessor { type Extract = Option; fn extract(self) -> Self::Extract { self.predecessor } } pub(crate) struct Service { pub(crate) db: &'static dyn Data, } impl Service { /// Set the room to the given statehash and update caches. #[tracing::instrument(skip(self, statediffnew, _statediffremoved))] pub(crate) async fn force_state( &self, room_id: &KeyToken, shortstatehash: ShortStateHash, statediffnew: Arc>, _statediffremoved: Arc>, ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { services() .rooms .state_compressor .parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else { continue; }; let pdu: PduEvent = match serde_json::from_str( &serde_json::to_string(&pdu) .expect("CanonicalJsonObj can be serialized to JSON"), ) { Ok(pdu) => pdu, Err(_) => continue, }; match pdu.kind { TimelineEventType::RoomMember => { #[derive(Deserialize)] struct ExtractMembership { membership: MembershipState, } let membership = match serde_json::from_str::( pdu.content.get(), ) { Ok(e) => e.membership, Err(_) => continue, }; let Some(state_key) = pdu.state_key else { continue; }; let Ok(user_id) = UserId::parse(state_key) else { continue; }; services().rooms.state_cache.update_membership( room_id, &user_id, membership, &pdu.sender, None, false, )?; } TimelineEventType::SpaceChild => { services() .rooms .spaces .invalidate_cache(&pdu.room_id) .await; } _ => continue, } } services().rooms.state_cache.update_joined_count(room_id)?; self.db.set_room_state(room_id, shortstatehash)?; Ok(()) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed))] pub(crate) fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let state_hash = calculate_hash( state_ids_compressed.iter().map(CompressedStateEvent::as_bytes), ); let (shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), |p| { services() .rooms .state_compressor .load_shortstatehash_info(p) }, )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo .full_state .difference(&state_ids_compressed) .copied() .collect(); (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { (state_ids_compressed, Arc::new(HashSet::new())) }; services().rooms.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, // high number because no state will be based on this one 1_000_000, states_parents, )?; } self.db.set_event_state(shorteventid, shortstatehash)?; Ok(shortstatehash) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu))] pub(crate) fn append_to_state( &self, new_pdu: &PduEvent, ) -> Result { let shorteventid = services() .rooms .short .get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.db.set_event_state(shorteventid, p)?; } if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), |p| { services() .rooms .state_compressor .load_shortstatehash_info(p) }, )?; let shortstatekey = services().rooms.short.get_or_create_shortstatekey( &new_pdu.kind.to_string().into(), state_key, )?; let new = services() .rooms .state_compressor .compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() .map(|info| { info.full_state .iter() .find(|compressed| compressed.state == shortstatekey) }) .unwrap_or_default(); if Some(&new) == replaces { return Ok(previous_shortstatehash.expect("must exist")); } // TODO: statehash with deterministic inputs let shortstatehash = ShortStateHash::new(services().globals.next_count()?); let mut statediffnew = HashSet::new(); statediffnew.insert(new); let mut statediffremoved = HashSet::new(); if let Some(replaces) = replaces { statediffremoved.insert(*replaces); } services().rooms.state_compressor.save_state_from_diff( shortstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), 2, states_parents, )?; Ok(shortstatehash) } else { Ok(previous_shortstatehash .expect("first event in room must be a state event")) } } /// Gather events to help the invited user identify the room /// /// Also includes the invite event itself. #[tracing::instrument(skip(self, invite_event))] pub(crate) fn get_helpful_invite_events( &self, invite_event: &PduEvent, ) -> Result>> { let helpful_events = [ (StateEventType::RoomCreate, ""), (StateEventType::RoomJoinRules, ""), (StateEventType::RoomCanonicalAlias, ""), (StateEventType::RoomAvatar, ""), (StateEventType::RoomName, ""), (StateEventType::RoomMember, invite_event.sender.as_str()), ]; let helpful_events = helpful_events.into_iter().filter_map(|(state_type, state_key)| { let state = services().rooms.state_accessor.room_state_get( &invite_event.room_id, &state_type, state_key, ); match state { Ok(Some(x)) => Some(Ok(x.to_stripped_state_event())), Err(x) => Some(Err(x)), Ok(None) => None, } }); let actual_event = iter::once(Ok(invite_event.to_stripped_state_event())); helpful_events.chain(actual_event).collect() } /// Set the state hash to a new version, but does not update state_cache. #[tracing::instrument(skip(self))] pub(crate) fn set_room_state( &self, room_id: &KeyToken, shortstatehash: ShortStateHash, ) -> Result<()> { self.db.set_room_state(room_id, shortstatehash) } /// Returns the value of a field of an `m.room.create` event's `content`. #[tracing::instrument(skip(self))] pub(crate) fn get_create_content( &self, room_id: &RoomId, ) -> Result { let create_event = services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomCreate, "", )?; let content_field = create_event .as_ref() .map(|create_event| { serde_json::from_str::(create_event.content.get()).map_err( |error| { warn!(%error, "Invalid create event"); Error::BadDatabase("Invalid create event in db.") }, ) }) .transpose()? .ok_or_else(|| { Error::BadRequest( ErrorKind::InvalidParam, "No create event found", ) })?; Ok(content_field.extract()) } #[tracing::instrument(skip(self))] pub(crate) fn get_room_shortstatehash( &self, room_id: &RoomId, ) -> Result> { self.db.get_room_shortstatehash(room_id) } #[tracing::instrument(skip(self))] pub(crate) fn get_forward_extremities( &self, room_id: &RoomId, ) -> Result>> { self.db.get_forward_extremities(room_id) } #[tracing::instrument( skip(self, event_ids), fields(event_ids = debug_slice_truncated(&event_ids, 5)), )] pub(crate) fn set_forward_extremities( &self, room_id: &KeyToken, event_ids: Vec, ) -> Result<()> { self.db.set_forward_extremities(room_id, event_ids) } /// This fetches auth events from the current state. #[tracing::instrument(skip(self))] pub(crate) fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? else { return Ok(HashMap::new()); }; let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) .expect("content is a valid JSON object"); let mut sauthevents = auth_events .into_iter() .filter_map(|(event_type, state_key)| { services() .rooms .short .get_shortstatekey( &event_type.to_string().into(), &state_key, ) .ok() .flatten() .map(|s| (s, (event_type, state_key))) }) .collect::>(); let full_state = services() .rooms .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") .full_state; Ok(full_state .iter() .filter_map(|compressed| { services() .rooms .state_compressor .parse_compressed_state_event(compressed) .ok() }) .filter_map(|(shortstatekey, event_id)| { sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) }) .filter_map(|(k, event_id)| { services() .rooms .timeline .get_pdu(&event_id) .ok() .flatten() .map(|pdu| (k, pdu)) }) .collect()) } }