diff --git a/Cargo.lock b/Cargo.lock index d7ccc059..934c8c9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,6 +525,31 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crypto-common" version = "0.1.6" @@ -734,6 +759,12 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.35" @@ -915,12 +946,14 @@ dependencies = [ "opentelemetry-prometheus", "opentelemetry_sdk", "parking_lot", + "petgraph", "phf", "pin-project-lite", "predicates", "prometheus", "proxy-header", "rand", + "rayon", "regex", "reqwest", "ring", @@ -1982,6 +2015,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2299,6 +2342,26 @@ dependencies = [ "getrandom", ] 
+[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.7" diff --git a/Cargo.toml b/Cargo.toml index e5f30f91..19426912 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,11 +111,13 @@ opentelemetry-otlp = "0.17.0" opentelemetry-prometheus = "0.17.0" opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio"] } parking_lot = { version = "0.12.3", optional = true } +petgraph = "0.6.5" phf = { version = "0.11.2", features = ["macros"] } pin-project-lite = "0.2.15" prometheus = "0.13.4" proxy-header = { version = "0.1.2", features = ["tokio"] } rand = "0.8.5" +rayon = "1.10.0" regex = "1.11.1" reqwest = { version = "0.12.9", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } ring = "0.17.8" diff --git a/book/changelog.md b/book/changelog.md index 70af61f4..b54dbbb2 100644 --- a/book/changelog.md +++ b/book/changelog.md @@ -299,3 +299,6 @@ This will be the first release of Grapevine since it was forked from Conduit ([!121](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/121)) 25. Add configuration options to tune the value of each cache individually. ([!124](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/124)) +26. Add subcommand to get the state of all or a subset of rooms either from the + incrementally updated cache or by recomputing the state. 
+ ([!133](https://gitlab.computer.surgery/matrix/grapevine/-/merge_requests/133)) diff --git a/src/cli.rs b/src/cli.rs index 99641790..874b2db3 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -6,6 +6,7 @@ use std::path::PathBuf; use clap::{Parser, Subcommand}; +use ruma::OwnedRoomId; use crate::{ config::{default_tracing_filter, EnvFilterClone, LogFormat}, @@ -13,6 +14,7 @@ use crate::{ }; mod check_config; +mod get_room_states; mod serve; /// Command line arguments @@ -33,6 +35,11 @@ pub(crate) enum Command { /// Check the configuration file for syntax and semantic errors. CheckConfig(CheckConfigArgs), + + /// Write the state of all rooms as JSON to stdout. + /// + /// This is primarily useful for debugging. + GetRoomStates(GetRoomStatesArgs), } #[derive(clap::Args)] @@ -88,6 +95,29 @@ pub(crate) struct ServeArgs { pub(crate) config: ConfigArg, } +#[derive(clap::Args)] +pub(crate) struct GetRoomStatesArgs { + #[clap(flatten)] + pub(crate) config: ConfigArg, + + /// Recompute the room state instead of getting it from the caches + /// + /// Note that this can be VERY SLOW for some rooms. `#matrix:matrix.org` + /// has been seen to take around 30 minutes to solve on a 7950X, for + /// example. + #[clap(long, short)] + pub(crate) recompute: bool, + + /// Limit the output to a subset of rooms by their IDs + /// + /// This option can be specified multiple times. 
+ #[clap(long, short)] + pub(crate) select: Vec, + + #[clap(flatten)] + observability: ObservabilityArgs, +} + impl Args { pub(crate) async fn run(self) -> Result<(), error::Main> { if let Some((format, filter)) = self.command.cli_observability_args() { @@ -99,6 +129,7 @@ impl Args { Command::CheckConfig(args) => { check_config::run(args.config).await?; } + Command::GetRoomStates(args) => get_room_states::run(args).await?, } Ok(()) } @@ -113,6 +144,10 @@ impl Command { args.observability.log_format, args.observability.log_filter.clone(), )), + Command::GetRoomStates(args) => Some(( + args.observability.log_format, + args.observability.log_filter.clone(), + )), Command::Serve(_) => None, } } diff --git a/src/cli/get_room_states.rs b/src/cli/get_room_states.rs new file mode 100644 index 00000000..00eafb2e --- /dev/null +++ b/src/cli/get_room_states.rs @@ -0,0 +1,107 @@ +#![warn(clippy::missing_docs_in_private_items)] + +//! Implementation of the `get-room-states` command + +use std::{cmp::Ordering, collections::BTreeSet, sync::Arc}; + +use ruma::{events::StateEventType, EventId}; +use serde::Serialize; +use tokio::time::Instant; +use tracing as t; + +use super::GetRoomStatesArgs; +use crate::{ + config, database::KeyValueDatabase, error, services, PduEvent, Services, +}; + +mod cache; +mod recompute; + +/// A value in the output map +#[derive(Serialize)] +struct Value { + /// The resolved state of the room + resolved_state: BTreeSet, + + /// The PDUs of the resolved state events and their auth chains + pdus: BTreeSet>, + + /// The edges in the state events' auth chains + edges: BTreeSet, + + /// A graphviz representation of the resolved state and auth events + graphviz: String, +} + +/// An edge in the graph of `auth_events` of the resolved state +#[derive(Serialize, PartialEq, Eq, PartialOrd, Ord)] +struct Edge { + /// The ID of the event this edge starts at (one of `to`'s `auth_events`) + from: Arc, + + /// The ID of the event this edge points to + to: Arc, +} + +/// Serializable information about a state event +#[derive(Serialize, 
PartialEq, Eq)] +struct State { + /// This event's `type` + kind: StateEventType, + + /// This event's `state_key` + key: String, + + /// The ID of this event + event_id: Arc, +} + +impl Ord for State { + fn cmp(&self, other: &Self) -> Ordering { + Ordering::Equal + .then_with(|| self.kind.cmp(&other.kind)) + .then_with(|| self.key.cmp(&other.key)) + .then_with(|| self.event_id.cmp(&other.event_id)) + } +} + +impl PartialOrd for State { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Entrypoint of the `get-room-states` subcommand +pub(crate) async fn run( + args: GetRoomStatesArgs, +) -> Result<(), error::DumpStateCommand> { + use error::DumpStateCommand as Error; + + let config = config::load(args.config.config.as_ref()).await?; + + let db = Box::leak(Box::new( + KeyValueDatabase::load_or_create(&config).map_err(Error::Database)?, + )); + + Services::new(db, config, None) + .map_err(Error::InitializeServices)? + .install(); + + services().globals.err_if_server_name_changed()?; + + db.apply_migrations().await.map_err(Error::Database)?; + + let start = Instant::now(); + + let room_states = if args.recompute { + recompute::get_room_states(db, &args.select) + } else { + cache::get_room_states(db, &args.select).await + }; + + t::info!(duration = ?start.elapsed(), "Finished getting states"); + + serde_json::to_writer(std::io::stdout(), &room_states)?; + + Ok(()) +} diff --git a/src/cli/get_room_states/cache.rs b/src/cli/get_room_states/cache.rs new file mode 100644 index 00000000..d950a38a --- /dev/null +++ b/src/cli/get_room_states/cache.rs @@ -0,0 +1,144 @@ +//! 
Get room states from the caches + +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::Arc, +}; + +use petgraph::dot::Dot; +use ruma::{OwnedRoomId, RoomId}; +use tracing as t; + +use super::{recompute, Edge, State, Value}; +use crate::{database::KeyValueDatabase, services}; + +/// Get the state of all rooms, or only those in `select` if it is non-empty +#[t::instrument(skip(db, select))] +pub(crate) async fn get_room_states( + db: &KeyValueDatabase, + select: &[OwnedRoomId], +) -> BTreeMap { + let mut values = BTreeMap::new(); + + for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) { + if !select.is_empty() && !select.contains(&room_id) { + continue; + } + + if let Some(value) = get_room_value(db, &room_id).await { + values.insert(room_id, value); + } + } + + values +} + +/// Get the [`Value`] of the given room, or `None` if a required lookup fails +#[t::instrument(skip(db))] +async fn get_room_value( + db: &KeyValueDatabase, + room_id: &RoomId, +) -> Option { + let shortstatehash = + match services().rooms.state.get_room_shortstatehash(room_id) { + Ok(Some(x)) => x, + Ok(None) => { + t::warn!("Missing shortstatehash"); + return None; + } + Err(error) => { + t::warn!(%error, "Failed to get shortstatehash"); + return None; + } + }; + + let state_map = match services() + .rooms + .state_accessor + .state_full(shortstatehash) + .await + { + Ok(x) => x, + Err(error) => { + t::warn!(%error, "Failed to get full state"); + return None; + } + }; + + let resolved_state = state_map + .iter() + .map(|((kind, key), event)| State { + kind: kind.clone(), + key: key.clone(), + event_id: event.event_id.clone(), + }) + .collect(); + + let state_event_ids = + state_map.values().map(|x| x.event_id.clone()).collect::>(); + + let auth_chain = match services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_event_ids.clone()) + .await + .map(Iterator::collect::>) + { + Ok(x) => x, + Err(error) => { + t::warn!(%error, "Failed to get auth chain"); + return None; + } + }; + + let pdus = std::iter::empty() + 
.chain(auth_chain.into_iter()) + .chain(state_event_ids.into_iter()) + .filter_map(|event_id| { + match services().rooms.timeline.get_pdu(&event_id) { + Ok(Some(x)) => Some(x), + Ok(None) => { + t::warn!(%event_id, "Missing PDU"); + None + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU"); + None + } + } + }) + .collect::>(); + + let edges = pdus.iter().map(|x| recompute::StateResolutionEdges { + event_id: x.event_id.clone(), + room_id: >::from(x.room_id.clone()), + auth_events: x.auth_events.clone(), + }); + + let graph = recompute::get_state_event_graphs(db, edges) + .remove(&>::from(room_id)) + .expect("there should only be a graph for the current room"); + + let edges = graph + .raw_edges() + .iter() + .map(|x| { + let from = graph[x.source()].event_id.clone(); + let to = graph[x.target()].event_id.clone(); + + Edge { + from, + to, + } + }) + .collect(); + + let graphviz = Dot::new(&graph).to_string(); + + Some(Value { + resolved_state, + pdus, + edges, + graphviz, + }) +} diff --git a/src/cli/get_room_states/recompute.rs b/src/cli/get_room_states/recompute.rs new file mode 100644 index 00000000..990e6422 --- /dev/null +++ b/src/cli/get_room_states/recompute.rs @@ -0,0 +1,613 @@ +//! 
Get room states by recomputing them + +use std::{ + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + error::Error, + fmt, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use petgraph::{ + dot::Dot, + graph::NodeIndex, + visit::{depth_first_search, DfsEvent, Reversed}, + Direction, +}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use ruma::{ + events::{StateEventType, TimelineEventType}, + state_res::StateMap, + EventId, OwnedRoomId, RoomId, RoomVersionId, UserId, +}; +use serde::Deserialize; +use tracing as t; + +use super::{Edge, State, Value}; +use crate::{database::KeyValueDatabase, utils, PduEvent}; + +/// A weightless unit type to use for graphs with unweighted edges +pub(crate) struct WeightlessEdge; + +impl fmt::Display for WeightlessEdge { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +/// A directed graph with unweighted edges +pub(crate) type DiGraph = petgraph::graph::DiGraph; + +/// Extracts the `.content.room_version` PDU field +#[derive(Deserialize)] +struct ExtractRoomVersion { + /// Room version + room_version: RoomVersionId, +} + +/// Information about a state event node for state resolution +#[derive(Clone, Debug)] +pub(crate) struct StateResolutionNode { + /// This event's ID + pub(crate) event_id: Arc, + + /// This event's type + pub(crate) kind: StateEventType, + + /// This event's `state_key` + pub(crate) state_key: String, +} + +impl fmt::Display for StateResolutionNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self:#?}") + } +} + +/// Information about a state event's edges for state resolution +pub(crate) struct StateResolutionEdges { + /// This event's ID + pub(crate) event_id: Arc, + + /// The ID of the room this event belongs to + pub(crate) room_id: Arc, + + /// This event's `auth_events` + pub(crate) auth_events: Vec>, +} + +/// Recompute the state of all rooms, or only those in `select` if non-empty +pub(crate) fn get_room_states( + db: 
&KeyValueDatabase, + select: &[OwnedRoomId], +) -> BTreeMap { + let state_resolution_edges = get_state_resolution_edges(db, select); + let graphs = get_state_event_graphs(db, state_resolution_edges); + + resolve_room_states(db, graphs) +} + +/// Resolve the current state of all rooms +/// +/// # Arguments +/// +/// * `db`: The database. +/// * `graphs`: A map of room IDs to a directed graph of their state events. The +/// edges in the graph must go from an event's `auth_events` to that event. +#[t::instrument(skip(db, graphs))] +fn resolve_room_states( + db: &KeyValueDatabase, + graphs: HashMap, DiGraph>, +) -> BTreeMap { + let span = t::Span::current(); + + let todo = AtomicUsize::new(graphs.len()); + + // State resolution is slow so parallelize it by room + graphs + .into_par_iter() + .filter_map(|(room_id, graph)| { + let _enter = span.enter(); + + let value = resolve_room_state(db, graph, &room_id); + + let remaining = todo.fetch_sub(1, Ordering::SeqCst) - 1; + t::info!(count = remaining, "Rooms remaining"); + + value.map(|x| (OwnedRoomId::from(room_id), x)) + }) + .fold(BTreeMap::new, |mut acc, (room_id, state_map)| { + acc.insert(room_id, state_map); + acc + }) + .reduce(BTreeMap::new, |mut x, y| { + // Room IDs should be unique per chunk so there should be no + // collisions when combining this way + x.extend(y); + x + }) +} + +/// Find a single `m.room.create` event and return its room version +fn get_room_version( + db: &KeyValueDatabase, + graph: &DiGraph, +) -> Option { + let mut create_events = HashSet::new(); + for i in graph.node_indices() { + let n = &graph[i]; + + if n.kind != StateEventType::RoomCreate { + continue; + } + + create_events.insert(n.event_id.clone()); + } + + if create_events.len() > 1 { + t::warn!( + "Multiple `m.room.create` events in graph, rejecting all of them", + ); + return None; + } + + let Some(event_id) = create_events.into_iter().next() else { + t::warn!("No `m.room.create` event in graph"); + return None; + }; + + let pdu = 
match get_either_pdu_by_event_id(db, &event_id) { + Ok(Some(x)) => x, + Ok(None) => { + t::error!("No `m.room.create` PDU in database"); + return None; + } + Err(error) => { + t::error!( + %error, + "Failed to get `m.room.create` PDU from database", + ); + return None; + } + }; + + let content = + match serde_json::from_str::(pdu.content.get()) { + Ok(x) => x, + Err(error) => { + t::error!( + %error, + %event_id, + "Failed to extract `.content.room_version` from PDU", + ); + return None; + } + }; + + Some(content.room_version) +} + +/// Resolve the state of a single room +/// +/// # Arguments +/// +/// * `db`: The database. +/// * `graph`: A directed graph of this room's state events. The +/// edges in the graph must go from an event's `auth_events` to that event. +/// * `room_id`: The ID of the room being resolved. Only used for tracing. +#[t::instrument(skip(db, graph))] +fn resolve_room_state( + db: &KeyValueDatabase, + mut graph: DiGraph, + room_id: &RoomId, +) -> Option { + let Some(room_version) = get_room_version(db, &graph) else { + t::error!("Couldn't get room version, skipping this room"); + return None; + }; + + let (state_sets, auth_chain_sets) = graph + .node_indices() + .map(|x| { + let StateResolutionNode { + event_id, + kind, + state_key, + } = graph[x].clone(); + + // At this point, we don't know any valid groupings of state events, + // so handle each state event by itself by making a separate map for + // each one. 
+ let mut state_set = StateMap::new(); + state_set.insert((kind, state_key), event_id); + + let mut auth_chain_set = HashSet::new(); + depth_first_search(Reversed(&graph), [x], |event| { + if let DfsEvent::Discover(node_index, _) = event { + auth_chain_set.insert(graph[node_index].event_id.clone()); + } + }); + + (state_set, auth_chain_set) + }) + .collect::<(Vec<_>, Vec<_>)>(); + + let res = ruma::state_res::resolve( + &room_version, + &state_sets, + auth_chain_sets, + |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), + ); + + let state_map = match res { + Ok(x) => x, + Err(error) => { + t::error!(%error, "State resolution failed"); + return None; + } + }; + + let resolved_state = state_map + .iter() + .map(|((kind, key), event_id)| State { + kind: kind.clone(), + key: key.clone(), + event_id: event_id.clone(), + }) + .collect(); + + let state_indices = state_map + .values() + .map(|x| { + graph + .node_indices() + .find(|i| x == &graph[*i].event_id) + .expect("event found by stateres should be in input map") + }) + .collect::>(); + + let mut auth_chain_indices = HashSet::new(); + depth_first_search( + Reversed(&graph), + state_indices.iter().copied(), + |event| { + if let DfsEvent::Discover(node_index, _) = event { + auth_chain_indices.insert(node_index); + } + }, + ); + + let accepted_indices = std::iter::empty() + .chain(auth_chain_indices.iter()) + .chain(state_indices.iter()) + .copied() + .collect::>(); + + // Remove events that are neither resolved state nor in its auth chain + graph.retain_nodes(|_, i| accepted_indices.contains(&i)); + + let pdus = graph + .node_indices() + .filter_map(|x| { + let event_id = &graph[x].event_id; + match get_either_pdu_by_event_id(db, event_id) { + Ok(Some(x)) => Some(Arc::new(x)), + Ok(None) => { + t::warn!(%event_id, "Missing PDU"); + None + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU"); + None + } + } + }) + .collect::>(); + + let edges = graph + .raw_edges() + .iter() + .map(|x| { + let from = 
graph[x.source()].event_id.clone(); + let to = graph[x.target()].event_id.clone(); + + Edge { + from, + to, + } + }) + .collect(); + + let graphviz = Dot::new(&graph).to_string(); + + Some(Value { + resolved_state, + pdus, + edges, + graphviz, + }) +} + +/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_accepted_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_key) = + db.eventid_pduid.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to get PDU ID for event ID", + ); + })? + else { + // This is normal, perhaps the requested event is an outlier + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU ID for event ID", + ); + return Ok(None); + }; + + let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to get PDU for PDU ID", + ); + })? + else { + t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID"); + return Err("No PDU for PDU ID".into()); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an outlier [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_outlier_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_bytes) = + db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to get PDU for event ID", + ); + })? 
+ else { + // This is normal, perhaps we just don't have this event + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU for event ID" + ); + return Ok(None); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database +/// +/// Tries accepted PDUs first; outlier PDUs are only tried if none is found. +#[t::instrument(skip(db))] +fn get_either_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else { + return get_outlier_pdu_by_event_id(db, event_id); + }; + + Ok(Some(pdu)) +} + +/// Get the edges in the graph of state events for state resolution +#[t::instrument(skip(db, select))] +fn get_state_resolution_edges<'a>( + db: &'a KeyValueDatabase, + select: &'a [OwnedRoomId], +) -> impl Iterator + 'a { + let filter_map = |key: Vec, value: Vec, map: &str| { + let Ok(pdu) = serde_json::from_slice::(&value) else { + t::error!( + %map, + key = utils::u8_slice_to_hex(&key), + "Failed to deserialize PDU, skipping it", + ); + return None; + }; + + if !select.is_empty() && !select.contains(&pdu.room_id) { + return None; + } + + let Some(state_key) = pdu.state_key else { + // Filter out non-state events + return None; + }; + + // Ruma fails the entire state resolution if it sees an event whose + // `state_key` it needs to parse as an ID and the parsing fails, so + // we pre-emptively drop these events. Real-world cases where this + // has happened are detailed on the issue linked below. 
+ // + // + if pdu.kind == TimelineEventType::RoomMember + && <&UserId>::try_from(&*state_key).is_err() + { + t::warn!( + %map, + key = utils::u8_slice_to_hex(&key), + event_id = %pdu.event_id, + "Dropping event that could cause state resolution to fail", + ); + return None; + } + + Some(StateResolutionEdges { + event_id: pdu.event_id, + room_id: >::from(pdu.room_id), + auth_events: pdu.auth_events, + }) + }; + + std::iter::empty() + .chain(db.pduid_pdu.iter().filter_map(move |(key, value)| { + filter_map(key, value, "pduid_pdu") + })) + .chain(db.eventid_outlierpdu.iter().filter_map(move |(key, value)| { + filter_map(key, value, "eventid_outlierpdu") + })) +} + +/// Get a directed graph of state events for all rooms +/// +/// The edges in the directed graph go from an event's `auth_events` to that +/// event. +#[t::instrument(skip(db, state_resolution_edges))] +pub(crate) fn get_state_event_graphs( + db: &KeyValueDatabase, + state_resolution_edges: I, +) -> HashMap, DiGraph> +where + I: Iterator, +{ + // Avoid inserting the same event ID twice; instead, mutate its edges + let mut visited = HashMap::<(Arc, Arc), NodeIndex>::new(); + + // Graph of event IDs + let mut graphs = HashMap::, DiGraph>>::new(); + + for StateResolutionEdges { + event_id, + room_id, + auth_events, + } in state_resolution_edges + { + let graph = graphs.entry(room_id.clone()).or_default(); + + let this = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + for event_id in auth_events { + let prev = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + graph.add_edge(prev, this, WeightlessEdge); + } + } + + // Convert the graph of event IDs into one of `StateResolutionNode`s + let graphs = graphs + .into_iter() + .map(|(room_id, graph)| { + let graph = graph.filter_map( + |index, event_id| { + to_state_resolution_node(db, index, event_id, &graph) + }, + |_, &WeightlessEdge| Some(WeightlessEdge), + 
); + + (room_id, graph) + }) + .collect::>(); + + // Record some stats for fun + t::info!( + rooms = graphs.len(), + total_nodes = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.node_indices().count(); + acc + }), + total_edges = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.edge_indices().count(); + acc + }), + "Done", + ); + + graphs +} + +/// Map an [`EventId`] node to a [`StateResolutionNode`], or `None` to drop it +fn to_state_resolution_node( + db: &KeyValueDatabase, + index: NodeIndex, + event_id: &Arc, + graph: &DiGraph>, +) -> Option { + let pdu = match get_either_pdu_by_event_id(db, event_id) { + Ok(Some(x)) => x, + Ok(None) => { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + + t::warn!( + ?dependents, + missing = %event_id, + "Missing `auth_event` depended on by one or more state events", + ); + + return None; + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU by event ID"); + return None; + } + }; + + let Some(state_key) = pdu.state_key else { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + t::warn!( + ?dependents, + not_state = %event_id, + "Event depended on by one or more state events isn't a state event", + ); + return None; + }; + + Some(StateResolutionNode { + event_id: event_id.clone(), + kind: pdu.kind.to_string().into(), + state_key, + }) +} diff --git a/src/error.rs b/src/error.rs index ce8ebb37..1b30f24a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,6 +43,9 @@ pub(crate) enum Main { #[error(transparent)] ServeCommand(#[from] ServeCommand), + #[error(transparent)] + DumpStateCommand(#[from] DumpStateCommand), + #[error("failed to install global default tracing subscriber")] SetSubscriber(#[from] tracing::subscriber::SetGlobalDefaultError), @@ -85,6 +88,28 @@ pub(crate) enum CheckConfigCommand { Config(#[from] Config), } +/// Errors 
returned from the `get-room-states` CLI subcommand. +// Missing docs are allowed here since that kind of information should be +// encoded in the error messages themselves anyway. +#[allow(missing_docs)] +#[derive(Error, Debug)] +pub(crate) enum DumpStateCommand { + #[error("failed to load configuration")] + Config(#[from] Config), + + #[error("failed to initialize services")] + InitializeServices(#[source] crate::utils::error::Error), + + #[error("failed to load or create the database")] + Database(#[source] crate::utils::error::Error), + + #[error("`server_name` change check failed")] + ServerNameChanged(#[from] ServerNameChanged), + + #[error("failed to serialize state to stdout")] + Serialize(#[from] serde_json::Error), +} + /// Error generated if `server_name` has changed or if checking this failed // Missing docs are allowed here since that kind of information should be // encoded in the error messages themselves anyway.