From 0b4cc6a1b140c44c21a4a46f83976ba86eba2cd5 Mon Sep 17 00:00:00 2001 From: Charles Hall Date: Thu, 10 Oct 2024 12:02:54 -0700 Subject: [PATCH] add subcmd to repair some persistent state This version does state resolution from scratch instead of trusting the caches in the database. --- Cargo.lock | 63 ++++++ Cargo.toml | 2 + src/cli.rs | 23 ++ src/cli/repair.rs | 561 ++++++++++++++++++++++++++++++++++++++++++++++ src/error.rs | 3 + 5 files changed, 652 insertions(+) create mode 100644 src/cli/repair.rs diff --git a/Cargo.lock b/Cargo.lock index 591033a0..e1adc754 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -530,6 +530,31 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crypto-common" version = "0.1.6" @@ -713,6 +738,12 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.33" @@ -895,12 +926,14 @@ dependencies = [ "opentelemetry-prometheus", "opentelemetry_sdk", "parking_lot", + "petgraph", "phf", "pin-project-lite", "predicates", "prometheus", "proxy-header", "rand", + "rayon", "regex", "reqwest", "ring", @@ -1815,6 +1848,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.5.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2128,6 +2171,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.4" diff --git a/Cargo.toml b/Cargo.toml index d340dca5..e4cacbd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,11 +112,13 @@ opentelemetry-otlp = "0.17.0" opentelemetry-prometheus = "0.17.0" opentelemetry_sdk = { version = "0.24.0", features = ["rt-tokio"] } parking_lot = { version = "0.12.3", optional = true } +petgraph = "0.6.5" phf = { version = "0.11.2", features = ["macros"] } pin-project-lite = "0.2.14" prometheus = "0.13.4" proxy-header = { version = "0.1.2", features = ["tokio"] } rand = "0.8.5" +rayon = "1.10.0" regex = "1.10.6" reqwest = { version = "0.12.7", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } ring = "0.17.8" diff --git a/src/cli.rs b/src/cli.rs index 99641790..acdf63b3 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -13,6 +13,7 @@ use crate::{ }; mod check_config; +mod repair; mod serve; /// Command line arguments @@ -33,6 +34,14 @@ pub(crate) enum Command { /// Check the configuration file for syntax and semantic errors. CheckConfig(CheckConfigArgs), + + /// Repair (some) persistent state + /// + /// Currently fixes the following issues: + /// + /// * + /// * + Repair(RepairArgs), } #[derive(clap::Args)] @@ -88,6 +97,15 @@ pub(crate) struct ServeArgs { pub(crate) config: ConfigArg, } +#[derive(clap::Args)] +pub(crate) struct RepairArgs { + #[clap(flatten)] + observability: ObservabilityArgs, + + #[clap(flatten)] + pub(crate) config: ConfigArg, +} + impl Args { pub(crate) async fn run(self) -> Result<(), error::Main> { if let Some((format, filter)) = self.command.cli_observability_args() { @@ -99,6 +117,7 @@ impl Args { Command::CheckConfig(args) => { check_config::run(args.config).await?; } + Command::Repair(args) => repair::run(args).await?, } Ok(()) } @@ -113,6 +132,10 @@ impl Command { args.observability.log_format, args.observability.log_filter.clone(), )), + Command::Repair(args) => Some(( + args.observability.log_format, + args.observability.log_filter.clone(), + )), Command::Serve(_) => None, } } diff --git a/src/cli/repair.rs b/src/cli/repair.rs new file mode 100644 index 00000000..ab085a00 --- /dev/null +++ b/src/cli/repair.rs @@ -0,0 +1,561 @@ +#![warn(clippy::missing_docs_in_private_items)] + +//! Implementation of the `repair` subcommand + +use std::{ + collections::{HashMap, HashSet}, + error::Error, + sync::{ + atomic::{AtomicUsize, Ordering as AtomicOrdering}, + Arc, + }, +}; + +use petgraph::{ + graph::{DiGraph, NodeIndex}, + visit::{depth_first_search, DfsEvent, Reversed}, + Direction, +}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use ruma::{ + events::{room::member::MembershipState, StateEventType}, + state_res::StateMap, + EventId, RoomId, RoomVersionId, UserId, +}; +use serde::Deserialize; +use tracing as t; + +use super::RepairArgs; +use crate::{config, database::KeyValueDatabase, utils, PduEvent}; + +/// Extractor to read just the membership state +#[derive(Deserialize)] +struct ExtractMembership { + /// Membership state + membership: MembershipState, +} + +/// Extractor to read just the `room_version` field +#[derive(Deserialize)] +struct ExtractRoomVersion { + /// Room version + room_version: RoomVersionId, +} + +/// Node in a graph of state events whose edges are based on `auth_events` +#[derive(Clone)] +struct StateNode { + /// The event ID + id: Arc, + + /// The state event type + kind: StateEventType, + + /// The state key + key: String, +} + +/// Subcommand entrypoint +pub(crate) async fn run(args: RepairArgs) -> Result<(), Box> { + t::info!("Repairing persistent state"); + + let config = config::load(args.config.config).await?; + + let db = KeyValueDatabase::load_or_create(&config)?; + + let graphs = get_state_event_graphs(&db); + + let states = resolve_room_states(&db, graphs); + + repair_roomuserid_joined(&db, &states)?; + + t::info!("Done"); + + Ok(()) +} + +/// Repair the `roomuserid_joined` map +#[t::instrument(skip(db, states))] +fn repair_roomuserid_joined( + db: &KeyValueDatabase, + states: &HashMap, StateMap>>, +) -> Result<(), Box> { + t::info!( + length = db.roomuserid_joined.iter().count(), + "Original length of map", + ); + + db.roomuserid_joined.clear()?; + + for (room_id, state) in states { + let states = + state.iter().filter_map(|((kind, key), id)| { + if *kind != StateEventType::RoomMember { + return None; + } + + let user_id = UserId::parse(key).inspect_err(|error| { + t::error!( + %error, + %key, + event_id = %id, + "Failed to parse m.room.member state key as user ID", + ); + }).ok()?; + + let pdu = get_accepted_pdu_by_event_id(db, id).ok()??; + + let membership = match serde_json::from_str(pdu.content.get()) { + Ok(ExtractMembership { + membership, + }) => membership, + Err(error) => { + t::warn!( + %error, + event_id = %id, + "Failed to deserialize PDU content", + ); + return None; + } + }; + + Some((user_id, membership)) + }); + + for (user_id, membership) in states { + let is_joined = match membership { + MembershipState::Invite + | MembershipState::Ban + | MembershipState::Leave => false, + + MembershipState::Join => true, + + state => { + t::warn!( + %room_id, + %user_id, + %state, + "Unsupported membership state", + ); + false + } + }; + + if !is_joined { + continue; + } + + let key = [room_id.as_bytes(), user_id.as_bytes()].join(&0xFF); + + db.roomuserid_joined.insert(&key, &[])?; + } + } + + t::info!(length = db.roomuserid_joined.iter().count(), "New length of map"); + + Ok(()) +} + +/// Resolve the current state of all rooms +/// +/// # Arguments +/// +/// * `db`: The database. +/// * `graphs`: A map of room IDs to a directed graph of their state events. The +/// edges in the graph must go from an event's `auth_events` to that event. +#[t::instrument(skip(db, graphs))] +fn resolve_room_states( + db: &KeyValueDatabase, + graphs: HashMap, DiGraph>, +) -> HashMap, StateMap>> { + let span = t::Span::current(); + + let todo = AtomicUsize::new(graphs.len()); + + // State resolution is slow so parallelize it by room + graphs + .into_par_iter() + .filter_map(|(room_id, graph)| { + let _enter = span.enter(); + + let state_map = resolve_room_state(db, graph, &room_id); + + todo.fetch_sub(1, AtomicOrdering::SeqCst); + t::info!( + count = todo.load(AtomicOrdering::SeqCst), + "Rooms remaining", + ); + + state_map.map(|x| (room_id, x)) + }) + .fold(HashMap::new, |mut acc, (room_id, state_map)| { + acc.insert(room_id, state_map); + acc + }) + .reduce(HashMap::new, |mut x, y| { + // Room IDs should be unique per chunk so there should be no + // collisions when combining this way + x.extend(y); + x + }) +} + +/// Find a single `m.room.create` event and return its room version +fn get_room_version( + db: &KeyValueDatabase, + graph: &DiGraph, +) -> Option { + let mut create_events = HashSet::new(); + for i in graph.node_indices() { + let n = &graph[i]; + + if n.kind != StateEventType::RoomCreate { + continue; + } + + create_events.insert(n.id.clone()); + } + + if create_events.len() > 1 { + t::warn!( + "Multiple `m.room.create` events in graph, rejecting all of them", + ); + return None; + } + + let Some(event_id) = create_events.into_iter().next() else { + t::warn!("No `m.room.create` event in graph"); + return None; + }; + + let pdu = match get_either_pdu_by_event_id(db, &event_id) { + Ok(Some(x)) => x, + Ok(None) => { + t::error!("No `m.room.create` PDU in database"); + return None; + } + Err(error) => { + t::error!( + %error, + "Failed to get `m.room.create` PDU from database", + ); + return None; + } + }; + + let content = + match serde_json::from_str::(pdu.content.get()) { + Ok(x) => x, + Err(error) => { + t::error!( + %error, + %event_id, + "Failed to extract `.content.room_version` from PDU", + ); + return None; + } + }; + + Some(content.room_version) +} + +#[t::instrument(skip(db, graph))] +fn resolve_room_state( + db: &KeyValueDatabase, + graph: DiGraph, + + // Only used by the tracing span + room_id: &RoomId, +) -> Option>> { + let Some(room_version) = get_room_version(db, &graph) else { + t::error!("Couldn't get room version, skipping this room"); + return None; + }; + + let (state_sets, auth_chain_sets) = graph + .node_indices() + .map(|x| { + let StateNode { + id, + kind, + key, + } = graph[x].clone(); + + // At this point, we don't know any valid groupings of state events, + // so handle each state event by itself by making a separate map for + // each one. + let mut state_set = StateMap::new(); + state_set.insert((kind, key), id); + + let mut auth_chain_set = HashSet::new(); + depth_first_search(Reversed(&graph), [x], |event| { + if let DfsEvent::Discover(node_index, _) = event { + auth_chain_set.insert(graph[node_index].id.clone()); + } + }); + + (state_set, auth_chain_set) + }) + .collect::<(Vec<_>, Vec<_>)>(); + + let res = ruma::state_res::resolve( + &room_version, + &state_sets, + auth_chain_sets, + |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), + ); + + match res { + Ok(x) => { + t::info!("Done"); + Some(x) + } + Err(error) => { + t::error!(%error, "Failed"); + None + } + } +} + +/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_accepted_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_key) = + db.eventid_pduid.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to get PDU ID for event ID", + ); + })? + else { + // This is normal, perhaps the requested event is an outlier + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU ID for event ID", + ); + return Ok(None); + }; + + let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to get PDU for PDU ID", + ); + })? + else { + t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID"); + return Err("No PDU for PDU ID".into()); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(&pdu_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an outlier [`PduEvent`] by its [`EventId`] in the database +#[t::instrument(skip(db))] +fn get_outlier_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let event_id_key = event_id.as_bytes(); + + let Some(pdu_bytes) = + db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to PDU for event ID", + ); + })? + else { + // This is normal, perhaps we just don't have this event + t::debug!( + key = utils::u8_slice_to_hex(event_id_key), + "No PDU for event ID" + ); + return Ok(None); + }; + + let pdu = serde_json::from_slice::(&pdu_bytes).inspect_err( + |error| { + t::error!( + %error, + key = utils::u8_slice_to_hex(event_id_key), + "Failed to deserialize PDU", + ); + }, + )?; + + Ok(Some(pdu)) +} + +/// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database +/// +/// Tries searching accepted PDUs first, then outlier PDUs. +#[t::instrument(skip(db))] +fn get_either_pdu_by_event_id( + db: &KeyValueDatabase, + event_id: &EventId, +) -> Result, Box> { + let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else { + return get_outlier_pdu_by_event_id(db, event_id); + }; + + Ok(Some(pdu)) +} + +/// Get per-room directed graphs of state events +/// +/// The edges in the directed graph go from an event's `auth_events` to that +/// event. +#[t::instrument(skip(db))] +fn get_state_event_graphs( + db: &KeyValueDatabase, +) -> HashMap, DiGraph> { + t::info!("Getting state event graphs for all rooms (slow)"); + + let extract_relationship_data = + |key: Vec, value: Vec, map: &str| { + let Ok(pdu) = serde_json::from_slice::(&value) else { + t::error!( + %map, + key = utils::u8_slice_to_hex(&key), + "Failed to deserialize PDU, skipping it", + ); + return None; + }; + + // Filter out non-state events + pdu.state_key?; + + Some(( + >::from(pdu.room_id), + pdu.auth_events, + pdu.event_id, + )) + }; + + let event_relationships = std::iter::empty() + .chain(db.pduid_pdu.iter().filter_map(|(key, value)| { + extract_relationship_data(key, value, "pduid_pdu") + })) + .chain(db.eventid_outlierpdu.iter().filter_map(|(key, value)| { + extract_relationship_data(key, value, "eventid_outlierpdu") + })); + + // Avoid inserting the same event ID twice; instead, mutate its edges + let mut visited = HashMap::<(Arc, Arc), NodeIndex>::new(); + + // Graph of event IDs + let mut graphs = HashMap::, DiGraph, ()>>::new(); + + for (room_id, auth_events, event_id) in event_relationships { + let graph = graphs.entry(room_id.clone()).or_default(); + + let this = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + for event_id in auth_events { + let prev = *visited + .entry((room_id.clone(), event_id.clone())) + .or_insert_with(|| graph.add_node(event_id)); + + graph.add_edge(prev, this, ()); + } + } + + // Convert the graph of event IDs into one of `StateNode`s + let graphs = graphs + .into_iter() + .map(|(room_id, graph)| { + let graph = graph.filter_map( + |index, event_id| to_state_node(db, index, event_id, &graph), + |_, &()| Some(()), + ); + + (room_id, graph) + }) + .collect::>(); + + // Record some stats for fun + t::info!( + rooms = graphs.len(), + total_nodes = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.node_indices().count(); + acc + }), + total_edges = graphs.iter().fold(0, |mut acc, x| { + acc += x.1.edge_indices().count(); + acc + }), + "Done", + ); + + graphs +} + +/// For `filter_map`ping a graph of event ID nodes to state nodes +fn to_state_node( + db: &KeyValueDatabase, + index: NodeIndex, + event_id: &Arc, + graph: &DiGraph, ()>, +) -> Option { + let pdu = match get_either_pdu_by_event_id(db, event_id) { + Ok(Some(x)) => x, + Ok(None) => { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + t::debug!( + ?dependents, + missing = %event_id, + "Missing `auth_event` depended on by one or more \ + state events", + ); + return None; + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU by event ID"); + return None; + } + }; + + let Some(state_key) = pdu.state_key else { + let mut dependents = Vec::new(); + for i in graph.neighbors_directed(index, Direction::Outgoing) { + dependents.push(&graph[i]); + } + t::warn!( + ?dependents, + not_state = %event_id, + "Event depended on by one or more state events is \ + not a state event", + ); + return None; + }; + + Some(StateNode { + id: event_id.clone(), + kind: pdu.kind.to_string().into(), + key: state_key, + }) +} diff --git a/src/error.rs b/src/error.rs index ce8ebb37..01aefe80 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,6 +43,9 @@ pub(crate) enum Main { #[error(transparent)] ServeCommand(#[from] ServeCommand), + #[error(transparent)] + RepairCommand(#[from] Box), + #[error("failed to install global default tracing subscriber")] SetSubscriber(#[from] tracing::subscriber::SetGlobalDefaultError),