//! Get room states by recomputing them use std::{ collections::{HashMap, HashSet}, error::Error, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, }; use petgraph::{ graph::NodeIndex, visit::{depth_first_search, DfsEvent, Reversed}, Direction, }; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use ruma::{ events::{StateEventType, TimelineEventType}, state_res::StateMap, EventId, RoomId, RoomVersionId, UserId, }; use serde::Deserialize; use tracing as t; use super::StateEvent; use crate::{database::KeyValueDatabase, utils, PduEvent}; /// A weightless unit type to use for graphs with unweighted edges struct WeightlessEdge; /// A directed graph with unweighted edges type DiGraph = petgraph::graph::DiGraph; /// Extracts the `.content.room_version` PDU field #[derive(Deserialize)] struct ExtractRoomVersion { /// Room version room_version: RoomVersionId, } /// Information about a state event node for state resolution #[derive(Clone)] struct StateResolutionNode { /// This event's ID event_id: Arc, /// This event's type kind: StateEventType, /// This event's `state_key` state_key: String, } /// Information about a state event's edges for state resolution struct StateResolutionEdges { /// This event's ID event_id: Arc, /// The ID of the room this event belongs to room_id: Arc, /// This event's `auth_events` auth_events: Vec>, } /// Get the state of all rooms pub(crate) fn get_room_states(db: &KeyValueDatabase) -> Vec { let state_resolution_edges = get_state_resolution_edges(db); let graphs = get_state_event_graphs(db, state_resolution_edges); let states = resolve_room_states(db, graphs); let mut serializable_state = Vec::new(); for (_, state) in states { serializable_state.extend(state.into_iter().filter_map( |((kind, key), event)| { let event = match get_either_pdu_by_event_id(db, &event) { Ok(Some(x)) => Arc::new(x), Ok(None) => { t::warn!(event_id = %event, "Unknown event, omitting"); return None; } Err(error) => { t::warn!( %error, event_id = %event, "Failed to get event, omitting", ); return None; } }; let x = StateEvent { kind, key, event, }; Some(x) }, )); } serializable_state.sort_unstable(); serializable_state } /// Resolve the current state of all rooms /// /// # Arguments /// /// * `db`: The database. /// * `graphs`: A map of room IDs to a directed graph of their state events. The /// edges in the graph must go from an event's `auth_events` to that event. #[t::instrument(skip(db, graphs))] fn resolve_room_states( db: &KeyValueDatabase, graphs: HashMap, DiGraph>, ) -> HashMap, StateMap>> { let span = t::Span::current(); let todo = AtomicUsize::new(graphs.len()); // State resolution is slow so parallelize it by room graphs .into_par_iter() .filter_map(|(room_id, graph)| { let _enter = span.enter(); let state_map = resolve_room_state(db, graph, &room_id); todo.fetch_sub(1, Ordering::SeqCst); t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining"); state_map.map(|x| (room_id, x)) }) .fold(HashMap::new, |mut acc, (room_id, state_map)| { acc.insert(room_id, state_map); acc }) .reduce(HashMap::new, |mut x, y| { // Room IDs should be unique per chunk so there should be no // collisions when combining this way x.extend(y); x }) } /// Find a single `m.room.create` event and return its room version fn get_room_version( db: &KeyValueDatabase, graph: &DiGraph, ) -> Option { let mut create_events = HashSet::new(); for i in graph.node_indices() { let n = &graph[i]; if n.kind != StateEventType::RoomCreate { continue; } create_events.insert(n.event_id.clone()); } if create_events.len() > 1 { t::warn!( "Multiple `m.room.create` events in graph, rejecting all of them", ); return None; } let Some(event_id) = create_events.into_iter().next() else { t::warn!("No `m.room.create` event in graph"); return None; }; let pdu = match get_either_pdu_by_event_id(db, &event_id) { Ok(Some(x)) => x, Ok(None) => { t::error!("No `m.room.create` PDU in database"); return None; } Err(error) => { t::error!( %error, "Failed to get `m.room.create` PDU from database", ); return None; } }; let content = match serde_json::from_str::(pdu.content.get()) { Ok(x) => x, Err(error) => { t::error!( %error, %event_id, "Failed to extract `.content.room_version` from PDU", ); return None; } }; Some(content.room_version) } /// Resolve the state of a single room /// /// # Arguments /// /// * `db`: The database. /// * `graphs`: A map of room IDs to a directed graph of their state events. The /// edges in the graph must go from an event's `auth_events` to that event. /// * `room_id`: The ID of the room being resolved. Only used for tracing. #[t::instrument(skip(db, graph))] fn resolve_room_state( db: &KeyValueDatabase, graph: DiGraph, room_id: &RoomId, ) -> Option>> { let Some(room_version) = get_room_version(db, &graph) else { t::error!("Couldn't get room version, skipping this room"); return None; }; let (state_sets, auth_chain_sets) = graph .node_indices() .map(|x| { let StateResolutionNode { event_id, kind, state_key, } = graph[x].clone(); // At this point, we don't know any valid groupings of state events, // so handle each state event by itself by making a separate map for // each one. let mut state_set = StateMap::new(); state_set.insert((kind, state_key), event_id); let mut auth_chain_set = HashSet::new(); depth_first_search(Reversed(&graph), [x], |event| { if let DfsEvent::Discover(node_index, _) = event { auth_chain_set.insert(graph[node_index].event_id.clone()); } }); (state_set, auth_chain_set) }) .collect::<(Vec<_>, Vec<_>)>(); let res = ruma::state_res::resolve( &room_version, &state_sets, auth_chain_sets, |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), ); match res { Ok(x) => { t::info!("Done"); Some(x) } Err(error) => { t::error!(%error, "Failed"); None } } } /// Look up an accepted [`PduEvent`] by its [`EventId`] in the database #[t::instrument(skip(db))] fn get_accepted_pdu_by_event_id( db: &KeyValueDatabase, event_id: &EventId, ) -> Result, Box> { let event_id_key = event_id.as_bytes(); let Some(pdu_key) = db.eventid_pduid.get(event_id_key).inspect_err(|error| { t::error!( %error, key = utils::u8_slice_to_hex(event_id_key), "Failed to get PDU ID for event ID", ); })? else { // This is normal, perhaps the requested event is an outlier t::debug!( key = utils::u8_slice_to_hex(event_id_key), "No PDU ID for event ID", ); return Ok(None); }; let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| { t::error!( %error, key = utils::u8_slice_to_hex(&pdu_key), "Failed to get PDU for PDU ID", ); })? else { t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID"); return Err("No PDU for PDU ID".into()); }; let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( |error| { t::error!( %error, key = utils::u8_slice_to_hex(&pdu_key), "Failed to deserialize PDU", ); }, )?; Ok(Some(pdu)) } /// Look up an outlier [`PduEvent`] by its [`EventId`] in the database #[t::instrument(skip(db))] fn get_outlier_pdu_by_event_id( db: &KeyValueDatabase, event_id: &EventId, ) -> Result, Box> { let event_id_key = event_id.as_bytes(); let Some(pdu_bytes) = db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| { t::error!( %error, key = utils::u8_slice_to_hex(event_id_key), "Failed to PDU for event ID", ); })? else { // This is normal, perhaps we just don't have this event t::debug!( key = utils::u8_slice_to_hex(event_id_key), "No PDU for event ID" ); return Ok(None); }; let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( |error| { t::error!( %error, key = utils::u8_slice_to_hex(event_id_key), "Failed to deserialize PDU", ); }, )?; Ok(Some(pdu)) } /// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database /// /// Tries searching accepted PDUs first, then outlier PDUs. #[t::instrument(skip(db))] fn get_either_pdu_by_event_id( db: &KeyValueDatabase, event_id: &EventId, ) -> Result, Box> { let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else { return get_outlier_pdu_by_event_id(db, event_id); }; Ok(Some(pdu)) } /// Get the edges in the graph of state events for state resolution #[t::instrument(skip(db))] fn get_state_resolution_edges( db: &KeyValueDatabase, ) -> impl Iterator + '_ { let filter_map = |key: Vec, value: Vec, map: &str| { let Ok(pdu) = serde_json::from_slice::(&value) else { t::error!( %map, key = utils::u8_slice_to_hex(&key), "Failed to deserialize PDU, skipping it", ); return None; }; let Some(state_key) = pdu.state_key else { // Filter out non-state events return None; }; // Ruma fails the entire state resolution if it sees an event whose // `state_key` it needs to parse as an ID and the parsing fails, so // we pre-emptively drop these events. Real-world cases where this // has happened are detailed on the issue linked below. // // if pdu.kind == TimelineEventType::RoomMember && <&UserId>::try_from(&*state_key).is_err() { t::warn!( %map, key = utils::u8_slice_to_hex(&key), event_id = %pdu.event_id, "Dropping event that could cause state resolution to fail", ); return None; } Some(StateResolutionEdges { event_id: pdu.event_id, room_id: >::from(pdu.room_id), auth_events: pdu.auth_events, }) }; std::iter::empty() .chain(db.pduid_pdu.iter().filter_map(move |(key, value)| { filter_map(key, value, "pduid_pdu") })) .chain(db.eventid_outlierpdu.iter().filter_map(move |(key, value)| { filter_map(key, value, "eventid_outlierpdu") })) } /// Get a directed graph of state events for all rooms /// /// The edges in the directed graph go from an event's `auth_events` to that /// event. #[t::instrument(skip(db, state_resolution_edges))] fn get_state_event_graphs( db: &KeyValueDatabase, state_resolution_edges: I, ) -> HashMap, DiGraph> where I: Iterator, { // Avoid inserting the same event ID twice; instead, mutate its edges let mut visited = HashMap::<(Arc, Arc), NodeIndex>::new(); // Graph of event IDs let mut graphs = HashMap::, DiGraph>>::new(); for StateResolutionEdges { event_id, room_id, auth_events, } in state_resolution_edges { let graph = graphs.entry(room_id.clone()).or_default(); let this = *visited .entry((room_id.clone(), event_id.clone())) .or_insert_with(|| graph.add_node(event_id)); for event_id in auth_events { let prev = *visited .entry((room_id.clone(), event_id.clone())) .or_insert_with(|| graph.add_node(event_id)); graph.add_edge(prev, this, WeightlessEdge); } } // Convert the graph of event IDs into one of `StateNode`s let graphs = graphs .into_iter() .map(|(room_id, graph)| { let graph = graph.filter_map( |index, event_id| { to_state_resolution_node(db, index, event_id, &graph) }, |_, &WeightlessEdge| Some(WeightlessEdge), ); (room_id, graph) }) .collect::>(); // Record some stats for fun t::info!( rooms = graphs.len(), total_nodes = graphs.iter().fold(0, |mut acc, x| { acc += x.1.node_indices().count(); acc }), total_edges = graphs.iter().fold(0, |mut acc, x| { acc += x.1.edge_indices().count(); acc }), "Done", ); graphs } /// `filter_map` a graph of [`EventId`] nodes to one of [`StateResolutionNode`]s fn to_state_resolution_node( db: &KeyValueDatabase, index: NodeIndex, event_id: &Arc, graph: &DiGraph>, ) -> Option { let pdu = match get_either_pdu_by_event_id(db, event_id) { Ok(Some(x)) => x, Ok(None) => { let mut dependents = Vec::new(); for i in graph.neighbors_directed(index, Direction::Outgoing) { dependents.push(&graph[i]); } t::warn!( ?dependents, missing = %event_id, "Missing `auth_event` depended on by one or more state events", ); return None; } Err(error) => { t::warn!(%error, %event_id, "Failed to get PDU by event ID"); return None; } }; let Some(state_key) = pdu.state_key else { let mut dependents = Vec::new(); for i in graph.neighbors_directed(index, Direction::Outgoing) { dependents.push(&graph[i]); } t::warn!( ?dependents, not_state = %event_id, "Event depended on by one or more state events isn't a state event", ); return None; }; Some(StateResolutionNode { event_id: event_id.clone(), kind: pdu.kind.to_string().into(), state_key, }) }