use std::{ array, collections::HashSet, mem::size_of, sync::{Arc, Mutex}, }; use lru_cache::LruCache; use ruma::{EventId, RoomId}; use crate::{ observability::{FoundIn, Lookup, METRICS}, services, utils, Result, }; pub(crate) mod data; pub(crate) use data::Data; use data::StateDiff; use super::short::{ShortEventId, ShortStateHash, ShortStateKey}; #[derive(Clone)] pub(crate) struct CompressedStateLayer { pub(crate) shortstatehash: ShortStateHash, pub(crate) full_state: Arc>, pub(crate) added: Arc>, pub(crate) removed: Arc>, } pub(crate) struct Service { pub(crate) db: &'static dyn Data, #[allow(clippy::type_complexity)] pub(crate) stateinfo_cache: Mutex>>, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) struct CompressedStateEvent { pub(crate) state: ShortStateKey, pub(crate) event: ShortEventId, } impl CompressedStateEvent { pub(crate) fn as_bytes( &self, ) -> [u8; size_of::() + size_of::()] { let mut bytes = self .state .get() .to_be_bytes() .into_iter() .chain(self.event.get().to_be_bytes()); array::from_fn(|_| bytes.next().unwrap()) } pub(crate) fn from_bytes( bytes: [u8; size_of::() + size_of::()], ) -> Self { let state = ShortStateKey::new(u64::from_be_bytes( bytes[0..8].try_into().unwrap(), )); let event = ShortEventId::new(u64::from_be_bytes( bytes[8..16].try_into().unwrap(), )); Self { state, event, } } } impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. #[allow(clippy::type_complexity)] #[tracing::instrument(skip(self))] pub(crate) fn load_shortstatehash_info( &self, shortstatehash: ShortStateHash, ) -> Result> { let lookup = Lookup::StateInfo; if let Some(r) = self.stateinfo_cache.lock().unwrap().get_mut(&shortstatehash) { METRICS.record_lookup(lookup, FoundIn::Cache); return Ok(r.clone()); } let StateDiff { parent, added, removed, } = self.db.get_statediff(shortstatehash)?; let response = if let Some(parent) = parent { let mut response = self.load_shortstatehash_info(parent)?; let mut state = (*response.last().unwrap().full_state).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { state.remove(r); } response.push(CompressedStateLayer { shortstatehash, full_state: Arc::new(state), added, removed: Arc::new(removed), }); response } else { vec![CompressedStateLayer { shortstatehash, full_state: added.clone(), added, removed, }] }; METRICS.record_lookup(lookup, FoundIn::Database); self.stateinfo_cache .lock() .unwrap() .insert(shortstatehash, response.clone()); Ok(response) } // Allowed because this function uses `services()` #[allow(clippy::unused_self)] pub(crate) fn compress_state_event( &self, shortstatekey: ShortStateKey, event_id: &EventId, ) -> Result { Ok(CompressedStateEvent { state: shortstatekey, event: services() .rooms .short .get_or_create_shorteventid(event_id)?, }) } /// Returns shortstatekey, event id // Allowed because this function uses `services()` #[allow(clippy::unused_self)] pub(crate) fn parse_compressed_state_event( &self, compressed_event: &CompressedStateEvent, ) -> Result<(ShortStateKey, Arc)> { Ok(( compressed_event.state, services() .rooms .short .get_eventid_from_short(compressed_event.event)?, )) } /// Creates a new shortstatehash that often is just a diff to an already /// existing shortstatehash and therefore very efficient. /// /// There are multiple layers of diffs. The bottom layer 0 always contains /// the full state. Layer 1 contains diffs to states of layer 0, layer 2 /// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be /// combined with layer n-1 to create a new diff on layer n-1 that's /// based on layer n-2. If that layer is also too big, it will recursively /// fix above layers too. /// /// * `shortstatehash` - Shortstatehash of this state /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid /// * `statediffremoved` - Removed from base. Each vec is /// shortstatekey+shorteventid /// * `diff_to_sibling` - Approximately how much the diff grows each time /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer #[allow(clippy::type_complexity)] #[tracing::instrument(skip( self, statediffnew, statediffremoved, diff_to_sibling, parent_states ))] pub(crate) fn save_state_from_diff( &self, shortstatehash: ShortStateHash, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: Vec, ) -> Result<()> { let diffsum = statediffnew.len() + statediffremoved.len(); if parent_states.len() > 3 { // Number of layers // To many layers, we have to go deeper let parent = parent_states.pop().unwrap(); let mut parent_new = (*parent.added).clone(); let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { // It was not added in the parent and we removed it parent_removed.insert(*removed); } // Else it was added in the parent and we removed it again. We // can forget this change } for new in statediffnew.iter() { if !parent_removed.remove(new) { // It was not touched in the parent and we added it parent_new.insert(*new); } // Else it was removed in the parent and we added it again. We // can forget this change } self.save_state_from_diff( shortstatehash, Arc::new(parent_new), Arc::new(parent_removed), diffsum, parent_states, )?; return Ok(()); } if parent_states.is_empty() { // There is no parent layer, create a new state self.db.save_statediff( shortstatehash, StateDiff { parent: None, added: statediffnew, removed: statediffremoved, }, )?; return Ok(()); }; // Else we have two options. // 1. We add the current diff on top of the parent layer. // 2. We replace a layer above let parent = parent_states.pop().unwrap(); let parent_diff = parent.added.len() + parent.removed.len(); if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { // Diff too big, we replace above layer(s) let mut parent_new = (*parent.added).clone(); let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { // It was not added in the parent and we removed it parent_removed.insert(*removed); } // Else it was added in the parent and we removed it again. We // can forget this change } for new in statediffnew.iter() { if !parent_removed.remove(new) { // It was not touched in the parent and we added it parent_new.insert(*new); } // Else it was removed in the parent and we added it again. We // can forget this change } self.save_state_from_diff( shortstatehash, Arc::new(parent_new), Arc::new(parent_removed), diffsum, parent_states, )?; } else { // Diff small enough, we add diff as layer on top of parent self.db.save_statediff( shortstatehash, StateDiff { parent: Some(parent.shortstatehash), added: statediffnew, removed: statediffremoved, }, )?; } Ok(()) } /// Returns the new shortstatehash, and the state diff from the previous /// room state #[allow(clippy::type_complexity)] #[tracing::instrument(skip(self, new_state_ids_compressed))] pub(crate) fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, ) -> Result<( ShortStateHash, Arc>, Arc>, )> { let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; let state_hash = utils::calculate_hash( new_state_ids_compressed.iter().map(CompressedStateEvent::as_bytes), ); let (new_shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok(( new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()), )); } let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), |p| self.load_shortstatehash_info(p), )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo .full_state .difference(&new_state_ids_compressed) .copied() .collect(); (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { (new_state_ids_compressed, Arc::new(HashSet::new())) }; if !already_existed { self.save_state_from_diff( new_shortstatehash, statediffnew.clone(), statediffremoved.clone(), // every state change is 2 event changes on average 2, states_parents, )?; }; Ok((new_shortstatehash, statediffnew, statediffremoved)) } }