From b1e14fad5c4863b349b56fcb5e056a1a9adacc04 Mon Sep 17 00:00:00 2001 From: Charles Hall Date: Sat, 2 Nov 2024 18:26:00 -0700 Subject: [PATCH] add option to recompute room state Hopefully we can make use of this state resolution code for more stuff in the future. --- Cargo.lock | 63 +++ Cargo.toml | 2 + src/cli.rs | 8 + src/cli/get_room_states.rs | 7 +- src/cli/get_room_states/recompute.rs | 550 +++++++++++++++++++++++++++ 5 files changed, 629 insertions(+), 1 deletion(-) create mode 100644 src/cli/get_room_states/recompute.rs diff --git a/Cargo.lock b/Cargo.lock index d7ccc059..934c8c9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,6 +525,31 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crypto-common" version = "0.1.6" @@ -734,6 +759,12 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.35" @@ -915,12 +946,14 @@ dependencies = [ "opentelemetry-prometheus", "opentelemetry_sdk", "parking_lot", + "petgraph", "phf", "pin-project-lite", "predicates", "prometheus", "proxy-header", "rand", + "rayon", "regex", "reqwest", "ring", @@ -1982,6 +2015,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2299,6 +2342,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.7" diff --git a/Cargo.toml b/Cargo.toml index 053f605a..3b2dec61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,11 +111,13 @@ opentelemetry-otlp = "0.17.0" opentelemetry-prometheus = "0.17.0" opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio"] } parking_lot = { version = "0.12.3", optional = true } +petgraph = "0.6.5" phf = { version = "0.11.2", features = ["macros"] } pin-project-lite = "0.2.15" prometheus = "0.13.4" proxy-header = { version = "0.1.2", features = ["tokio"] } rand = "0.8.5" +rayon = "1.10.0" regex = "1.11.1" reqwest = { version = "0.12.9", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } ring = "0.17.8" diff --git a/src/cli.rs b/src/cli.rs index 5742986c..110359ae 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -99,6 +99,14 @@ pub(crate) struct GetRoomStatesArgs { #[clap(flatten)] pub(crate) config: ConfigArg, + /// Recompute the room state instead of getting it from the caches + /// + /// Note that this can be VERY SLOW for some rooms. `#matrix:matrix.org` + /// has been seen to take around 30 minutes to solve on a 7950X, for + /// example. + #[clap(long, short)] + pub(crate) recompute: bool, + #[clap(flatten)] observability: ObservabilityArgs, } diff --git a/src/cli/get_room_states.rs b/src/cli/get_room_states.rs index b65d5673..e82a2f34 100644 --- a/src/cli/get_room_states.rs +++ b/src/cli/get_room_states.rs @@ -13,6 +13,7 @@ use crate::{ }; mod cache; +mod recompute; /// Serializable information about a state event #[derive(Serialize, PartialEq, Eq)] @@ -63,7 +64,11 @@ pub(crate) async fn run( db.apply_migrations().await.map_err(Error::Database)?; - let room_states = cache::get_room_states().await; + let room_states = if args.recompute { + recompute::get_room_states(db) + } else { + cache::get_room_states().await + }; serde_json::to_writer(std::io::stdout(), &room_states)?; diff --git a/src/cli/get_room_states/recompute.rs b/src/cli/get_room_states/recompute.rs new file mode 100644 index 00000000..c5f19b54 --- /dev/null +++ b/src/cli/get_room_states/recompute.rs @@ -0,0 +1,550 @@ +//! Get room states by recomputing them + +use std::{ + collections::{HashMap, HashSet}, + error::Error, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use petgraph::{ + graph::NodeIndex, + visit::{depth_first_search, DfsEvent, Reversed}, + Direction, +}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use ruma::{ + events::{StateEventType, TimelineEventType}, + state_res::StateMap, + EventId, RoomId, RoomVersionId, UserId, +}; +use serde::Deserialize; +use tracing as t; + +use super::StateEvent; +use crate::{database::KeyValueDatabase, utils, PduEvent}; + +/// A weightless unit type to use for graphs with unweighted edges +struct WeightlessEdge; + +/// A directed graph with unweighted edges +type DiGraph = petgraph::graph::DiGraph; + +/// Extracts the `.content.room_version` PDU field +#[derive(Deserialize)] +struct ExtractRoomVersion { + /// Room version + room_version: RoomVersionId, +} + +/// Information about a state event node for state resolution +#[derive(Clone)] +struct StateResolutionNode { + /// This event's ID + event_id: Arc, + + /// This event's type + kind: StateEventType, + + /// This event's `state_key` + state_key: String, +} + +/// Information about a state event's edges for state resolution +struct StateResolutionEdges { + /// This event's ID + event_id: Arc, + + /// The ID of the room this event belongs to + room_id: Arc, + + /// This event's `auth_events` + auth_events: Vec>, +} + +/// Get the state of all rooms +pub(crate) fn get_room_states(db: &KeyValueDatabase) -> Vec { + let state_resolution_edges = get_state_resolution_edges(db); + let graphs = get_state_event_graphs(db, state_resolution_edges); + + let states = resolve_room_states(db, graphs); + + let mut serializable_state = Vec::new(); + + for (_, state) in states { + serializable_state.extend(state.into_iter().filter_map( + |((kind, key), event)| { + let event = match get_either_pdu_by_event_id(db, &event) { + Ok(Some(x)) => Arc::new(x), + Ok(None) => { + t::warn!(event_id = %event, "Unknown event, omitting"); + return None; + } + Err(error) => { + t::warn!( + %error, + event_id = %event, + "Failed to get event, omitting", + ); + return None; + } + }; + + let x = StateEvent { + kind, + key, + event, + }; + + Some(x) + }, + )); + } + + serializable_state.sort_unstable(); + + serializable_state +} + +/// Resolve the current state of all rooms +/// +/// # Arguments +/// +/// * `db`: The database. +/// * `graphs`: A map of room IDs to a directed graph of their state events. The +/// edges in the graph must go from an event's `auth_events` to that event. +#[t::instrument(skip(db, graphs))] +fn resolve_room_states( + db: &KeyValueDatabase, + graphs: HashMap, DiGraph>, +) -> HashMap, StateMap>> { + let span = t::Span::current(); + + let todo = AtomicUsize::new(graphs.len()); + + // State resolution is slow so parallelize it by room + graphs + .into_par_iter() + .filter_map(|(room_id, graph)| { + let _enter = span.enter(); + + let state_map = resolve_room_state(db, graph, &room_id); + + todo.fetch_sub(1, Ordering::SeqCst); + t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining"); + + state_map.map(|x| (room_id, x)) + }) + .fold(HashMap::new, |mut acc, (room_id, state_map)| { + acc.insert(room_id, state_map); + acc + }) + .reduce(HashMap::new, |mut x, y| { + // Room IDs should be unique per chunk so there should be no + // collisions when combining this way + x.extend(y); + x + }) +} + +/// Find a single `m.room.create` event and return its room version +fn get_room_version( + db: &KeyValueDatabase, + graph: &DiGraph, +) -> Option { + let mut create_events = HashSet::new(); + for i in graph.node_indices() { + let n = &graph[i]; + + if n.kind != StateEventType::RoomCreate { + continue; + } + + create_events.insert(n.event_id.clone()); + } + + if create_events.len() > 1 { + t::warn!( + "Multiple `m.room.create` events in graph, rejecting all of them", + ); + return None; + } + + let Some(event_id) = create_events.into_iter().next() else { + t::warn!("No `m.room.create` event in graph"); + return None; + }; + + let pdu = match get_either_pdu_by_event_id(db, &event_id) { + Ok(Some(x)) => x, + Ok(None) => { + t::error!("No `m.room.create` PDU in database"); + return None; + } + Err(error) => { + t::error!( + %error, + "Failed to get `m.room.create` PDU from database", + ); + return None; + } + }; + + let content = + match serde_json::from_str::(pdu.content.get()) { + Ok(x) => x, + Err(error) => { + t::error!( + %error, + %event_id, + "Failed to extract `.content.room_version` from PDU", + ); + return None; + } + }; + + Some(content.room_version) +} + +/// Resolve the state of a single room +/// +/// # Arguments +/// +/// * `db`: The database. +/// * `graphs`: A map of room IDs to a directed graph of their state events. The +/// edges in the graph must go from an event's `auth_events` to that event. +/// * `room_id`: The ID of the room being resolved. Only used for tracing. +#[t::instrument(skip(db, graph))] +fn resolve_room_state( + db: &KeyValueDatabase, + graph: DiGraph, + room_id: &RoomId, +) -> Option>> { + let Some(room_version) = get_room_version(db, &graph) else { + t::error!("Couldn't get room version, skipping this room"); + return None; + }; + + let (state_sets, auth_chain_sets) = graph + .node_indices() + .map(|x| { + let StateResolutionNode { + event_id, + kind, + state_key, + } = graph[x].clone(); + + // At this point, we don't know any valid groupings of state events, + // so handle each state event by itself by making a separate map for + // each one. + let mut state_set = StateMap::new(); + state_set.insert((kind, state_key), event_id); + + let mut auth_chain_set = HashSet::new(); + depth_first_search(Reversed(&graph), [x], |event| { + if let DfsEvent::Discover(node_index, _) = event { + auth_chain_set.insert(graph[node_index].event_id.clone()); + } + }); + + (state_set, auth_chain_set) + }) + .collect::<(Vec<_>, Vec<_>)>(); + + let res = ruma::state_res::resolve( + &room_version, + &state_sets, + auth_chain_sets, + |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), + ); + + match res { + Ok(x) => { + t::info!("Done"); + Some(x) + } + Err(error) => { + t::error!(%error, "Failed"); + None + } + } +} + +/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_accepted_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_key) = + db.eventid_pduid.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to get PDU ID for event ID", + ); + })? + else { + // This is normal, perhaps the requested event is an outlier + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU ID for event ID", + ); + return Ok(None); + }; + + let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to get PDU for PDU ID", + ); + })? + else { + t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID"); + return Err("No PDU for PDU ID".into()); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an outlier [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_outlier_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_bytes) = + db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to PDU for event ID", + ); + })? + else { + // This is normal, perhaps we just don't have this event + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU for event ID" + ); + return Ok(None); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database +/// +/// Tries searching accepted PDUs first, then outlier PDUs. +#[t::instrument(skip(db))] +fn get_either_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else { + return get_outlier_pdu_by_event_id(db, event_id); + }; + + Ok(Some(pdu)) +} + +/// Get the edges in the graph of state events for state resolution +#[t::instrument(skip(db))] +fn get_state_resolution_edges( + db: &KeyValueDatabase, +) -> impl Iterator + '_ { + let filter_map = |key: Vec, value: Vec, map: &str| { + let Ok(pdu) = serde_json::from_slice::(&value) else { + t::error!( + %map, + key = utils::u8_slice_to_hex(&key), + "Failed to deserialize PDU, skipping it", + ); + return None; + }; + + let Some(state_key) = pdu.state_key else { + // Filter out non-state events + return None; + }; + + // Ruma fails the entire state resolution if it sees an event whose + // `state_key` it needs to parse as an ID and the parsing fails, so + // we pre-emptively drop these events. Real-world cases where this + // has happened are detailed on the issue linked below. + // + // + if pdu.kind == TimelineEventType::RoomMember + && <&UserId>::try_from(&*state_key).is_err() + { + t::warn!( + %map, + key = utils::u8_slice_to_hex(&key), + event_id = %pdu.event_id, + "Dropping event that could cause state resolution to fail", + ); + return None; + } + + Some(StateResolutionEdges { + event_id: pdu.event_id, + room_id: >::from(pdu.room_id), + auth_events: pdu.auth_events, + }) + }; + + std::iter::empty() + .chain(db.pduid_pdu.iter().filter_map(move |(key, value)| { + filter_map(key, value, "pduid_pdu") + })) + .chain(db.eventid_outlierpdu.iter().filter_map(move |(key, value)| { + filter_map(key, value, "eventid_outlierpdu") + })) +} + +/// Get a directed graph of state events for all rooms +/// +/// The edges in the directed graph go from an event's `auth_events` to that +/// event. +#[t::instrument(skip(db, state_resolution_edges))] +fn get_state_event_graphs( + db: &KeyValueDatabase, + state_resolution_edges: I, +) -> HashMap, DiGraph> +where + I: Iterator, +{ + // Avoid inserting the same event ID twice; instead, mutate its edges + let mut visited = HashMap::<(Arc, Arc), NodeIndex>::new(); + + // Graph of event IDs + let mut graphs = HashMap::, DiGraph>>::new(); + + for StateResolutionEdges { + event_id, + room_id, + auth_events, + } in state_resolution_edges + { + let graph = graphs.entry(room_id.clone()).or_default(); + + let this = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + for event_id in auth_events { + let prev = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + graph.add_edge(prev, this, WeightlessEdge); + } + } + + // Convert the graph of event IDs into one of `StateNode`s + let graphs = graphs + .into_iter() + .map(|(room_id, graph)| { + let graph = graph.filter_map( + |index, event_id| { + to_state_resolution_node(db, index, event_id, &graph) + }, + |_, &WeightlessEdge| Some(WeightlessEdge), + ); + + (room_id, graph) + }) + .collect::>(); + + // Record some stats for fun + t::info!( + rooms = graphs.len(), + total_nodes = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.node_indices().count(); + acc + }), + total_edges = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.edge_indices().count(); + acc + }), + "Done", + ); + + graphs +} + +/// `filter_map` a graph of [`EventId`] nodes to one of [`StateResolutionNode`]s +fn to_state_resolution_node( + db: &KeyValueDatabase, + index: NodeIndex, + event_id: &Arc, + graph: &DiGraph>, +) -> Option { + let pdu = match get_either_pdu_by_event_id(db, event_id) { + Ok(Some(x)) => x, + Ok(None) => { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + + t::warn!( + ?dependents, + missing = %event_id, + "Missing `auth_event` depended on by one or more state events", + ); + + return None; + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU by event ID"); + return None; + } + }; + + let Some(state_key) = pdu.state_key else { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + t::warn!( + ?dependents, + not_state = %event_id, + "Event depended on by one or more state events isn't a state event", + ); + return None; + }; + + Some(StateResolutionNode { + event_id: event_id.clone(), + kind: pdu.kind.to_string().into(), + state_key, + }) +}