diff --git a/src/cli/get_room_states.rs b/src/cli/get_room_states.rs index 3a826bd8..eaf55457 100644 --- a/src/cli/get_room_states.rs +++ b/src/cli/get_room_states.rs @@ -2,9 +2,9 @@ //! Implementation of the `get-room-states` command -use std::{cmp::Ordering, sync::Arc}; +use std::{cmp::Ordering, collections::BTreeSet, sync::Arc}; -use ruma::events::StateEventType; +use ruma::{events::StateEventType, EventId}; use serde::Serialize; use super::GetRoomStatesArgs; @@ -15,30 +15,52 @@ use crate::{ mod cache; mod recompute; -/// Serializable information about a state event -#[derive(Serialize, PartialEq, Eq)] -struct StateEvent { - /// The kind of state event - kind: StateEventType, +/// A value in the output map +#[derive(Serialize)] +struct Value { + /// The resolved state of the room + resolved_state: BTreeSet, - /// The `state_key` of the event - key: String, + /// The PDUs of the resolved state events and their auth chains + pdus: BTreeSet>, - /// The event itself - event: Arc, + /// The edges in the state events' auth chains + edges: BTreeSet, } -impl Ord for StateEvent { +/// An edge in the graph of `auth_events` of the resolved state +#[derive(Serialize, PartialEq, Eq, PartialOrd, Ord)] +struct Edge { + /// The outgoing edge + from: Arc, + + /// The incoming edge + to: Arc, +} + +/// Serializable information about a state event +#[derive(Serialize, PartialEq, Eq)] +struct State { + /// This event's `type` + kind: StateEventType, + + /// This event's `state_key` + key: String, + + /// The ID of this event + event_id: Arc, +} + +impl Ord for State { fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal - .then_with(|| self.event.room_id.cmp(&other.event.room_id)) .then_with(|| self.kind.cmp(&other.kind)) .then_with(|| self.key.cmp(&other.key)) - .then_with(|| self.event.event_id.cmp(&other.event.event_id)) + .then_with(|| self.event_id.cmp(&other.event_id)) } } -impl PartialOrd for StateEvent { +impl PartialOrd for State { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } @@ -67,7 +89,7 @@ pub(crate) async fn run( let room_states = if args.recompute { recompute::get_room_states(db, &args.select) } else { - cache::get_room_states(&args.select).await + cache::get_room_states(db, &args.select).await }; serde_json::to_writer(std::io::stdout(), &room_states)?; diff --git a/src/cli/get_room_states/cache.rs b/src/cli/get_room_states/cache.rs index 23485834..756dc8e7 100644 --- a/src/cli/get_room_states/cache.rs +++ b/src/cli/get_room_states/cache.rs @@ -1,60 +1,57 @@ //! Get room states from the caches -use std::sync::Arc; +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::Arc, +}; -use ruma::{state_res::StateMap, OwnedRoomId}; +use ruma::{OwnedRoomId, RoomId}; use tracing as t; -use super::StateEvent; -use crate::{services, PduEvent}; +use super::{recompute, Edge, State, Value}; +use crate::{database::KeyValueDatabase, services}; /// Get the state of all rooms, or the selected subset -#[t::instrument(skip(select))] -pub(crate) async fn get_room_states(select: &[OwnedRoomId]) -> Vec { - let mut serializable_state = Vec::new(); +#[t::instrument(skip(db, select))] +pub(crate) async fn get_room_states( + db: &KeyValueDatabase, + select: &[OwnedRoomId], +) -> BTreeMap { + let mut values = BTreeMap::new(); for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) { if !select.is_empty() && !select.contains(&room_id) { continue; } - let Some(state) = get_room_state(room_id).await else { - continue; - }; - - serializable_state.extend(state.into_iter().map( - |((kind, key), event)| StateEvent { - kind, - key, - event, - }, - )); + if let Some(value) = get_room_value(db, &room_id).await { + values.insert(room_id, value); + } } - serializable_state.sort_unstable(); - - serializable_state + values } -/// Get the state of the given room -#[t::instrument] -async fn get_room_state( - room_id: OwnedRoomId, -) -> Option>> { +/// Get the [`Value`] of the given room +#[t::instrument(skip(db))] +async fn get_room_value( + db: &KeyValueDatabase, + room_id: &RoomId, +) -> Option { let shortstatehash = - match services().rooms.state.get_room_shortstatehash(&room_id) { + match services().rooms.state.get_room_shortstatehash(room_id) { Ok(Some(x)) => x, Ok(None) => { - t::warn!("No shortstatehash for room"); + t::warn!("Missing shortstatehash"); return None; } Err(error) => { - t::warn!(%error, "Failed to get shortstatehash for room"); + t::warn!(%error, "Failed to get shortstatehash"); return None; } }; - let state = match services() + let state_map = match services() .rooms .state_accessor .state_full(shortstatehash) @@ -62,10 +59,82 @@ async fn get_room_state( { Ok(x) => x, Err(error) => { - t::warn!(%error, "Failed to get full state for room"); + t::warn!(%error, "Failed to get full state"); return None; } }; - Some(state) + let resolved_state = state_map + .iter() + .map(|((kind, key), event)| State { + kind: kind.clone(), + key: key.clone(), + event_id: event.event_id.clone(), + }) + .collect(); + + let state_event_ids = + state_map.values().map(|x| x.event_id.clone()).collect::>(); + + let auth_chain = match services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_event_ids.clone()) + .await + .map(Iterator::collect::>) + { + Ok(x) => x, + Err(error) => { + t::warn!(%error, "Failed to get auth chain"); + return None; + } + }; + + let pdus = std::iter::empty() + .chain(auth_chain.into_iter()) + .chain(state_event_ids.into_iter()) + .filter_map(|event_id| { + match services().rooms.timeline.get_pdu(&event_id) { + Ok(Some(x)) => Some(x), + Ok(None) => { + t::warn!(%event_id, "Missing PDU"); + None + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU"); + None + } + } + }) + .collect::>(); + + let edges = pdus.iter().map(|x| recompute::StateResolutionEdges { + event_id: x.event_id.clone(), + room_id: >::from(x.room_id.clone()), + auth_events: x.auth_events.clone(), + }); + + let graph = recompute::get_state_event_graphs(db, edges) + .remove(&>::from(room_id)) + .expect("there should only be a graph for the current room"); + + let edges = graph + .raw_edges() + .iter() + .map(|x| { + let from = graph[x.source()].event_id.clone(); + let to = graph[x.target()].event_id.clone(); + + Edge { + from, + to, + } + }) + .collect(); + + Some(Value { + resolved_state, + pdus, + edges, + }) } diff --git a/src/cli/get_room_states/recompute.rs b/src/cli/get_room_states/recompute.rs index 480aed89..c58cb985 100644 --- a/src/cli/get_room_states/recompute.rs +++ b/src/cli/get_room_states/recompute.rs @@ -1,7 +1,7 @@ //! Get room states by recomputing them use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, error::Error, sync::{ atomic::{AtomicUsize, Ordering}, @@ -23,14 +23,14 @@ use ruma::{ use serde::Deserialize; use tracing as t; -use super::StateEvent; +use super::{Edge, State, Value}; use crate::{database::KeyValueDatabase, utils, PduEvent}; /// A weightless unit type to use for graphs with unweighted edges -struct WeightlessEdge; +pub(crate) struct WeightlessEdge; /// A directed graph with unweighted edges -type DiGraph = petgraph::graph::DiGraph; +pub(crate) type DiGraph = petgraph::graph::DiGraph; /// Extracts the `.content.room_version` PDU field #[derive(Deserialize)] @@ -41,74 +41,38 @@ struct ExtractRoomVersion { /// Information about a state event node for state resolution #[derive(Clone)] -struct StateResolutionNode { +pub(crate) struct StateResolutionNode { /// This event's ID - event_id: Arc, + pub(crate) event_id: Arc, /// This event's type - kind: StateEventType, + pub(crate) kind: StateEventType, /// This event's `state_key` - state_key: String, + pub(crate) state_key: String, } /// Information about a state event's edges for state resolution -struct StateResolutionEdges { +pub(crate) struct StateResolutionEdges { /// This event's ID - event_id: Arc, + pub(crate) event_id: Arc, /// The ID of the room this event belongs to - room_id: Arc, + pub(crate) room_id: Arc, /// This event's `auth_events` - auth_events: Vec>, + pub(crate) auth_events: Vec>, } /// Get the state of all rooms, or the selected subset pub(crate) fn get_room_states( db: &KeyValueDatabase, select: &[OwnedRoomId], -) -> Vec { +) -> BTreeMap { let state_resolution_edges = get_state_resolution_edges(db, select); let graphs = get_state_event_graphs(db, state_resolution_edges); - let states = resolve_room_states(db, graphs); - - let mut serializable_state = Vec::new(); - - for (_, state) in states { - serializable_state.extend(state.into_iter().filter_map( - |((kind, key), event)| { - let event = match get_either_pdu_by_event_id(db, &event) { - Ok(Some(x)) => Arc::new(x), - Ok(None) => { - t::warn!(event_id = %event, "Unknown event, omitting"); - return None; - } - Err(error) => { - t::warn!( - %error, - event_id = %event, - "Failed to get event, omitting", - ); - return None; - } - }; - - let x = StateEvent { - kind, - key, - event, - }; - - Some(x) - }, - )); - } - - serializable_state.sort_unstable(); - - serializable_state + resolve_room_states(db, graphs) } /// Resolve the current state of all rooms @@ -122,7 +86,7 @@ pub(crate) fn get_room_states( fn resolve_room_states( db: &KeyValueDatabase, graphs: HashMap, DiGraph>, -) -> HashMap, StateMap>> { +) -> BTreeMap { let span = t::Span::current(); let todo = AtomicUsize::new(graphs.len()); @@ -133,18 +97,18 @@ fn resolve_room_states( .filter_map(|(room_id, graph)| { let _enter = span.enter(); - let state_map = resolve_room_state(db, graph, &room_id); + let value = resolve_room_state(db, graph, &room_id); todo.fetch_sub(1, Ordering::SeqCst); t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining"); - state_map.map(|x| (room_id, x)) + value.map(|x| (OwnedRoomId::from(room_id), x)) }) - .fold(HashMap::new, |mut acc, (room_id, state_map)| { + .fold(BTreeMap::new, |mut acc, (room_id, state_map)| { acc.insert(room_id, state_map); acc }) - .reduce(HashMap::new, |mut x, y| { + .reduce(BTreeMap::new, |mut x, y| { // Room IDs should be unique per chunk so there should be no // collisions when combining this way x.extend(y); @@ -222,9 +186,9 @@ fn get_room_version( #[t::instrument(skip(db, graph))] fn resolve_room_state( db: &KeyValueDatabase, - graph: DiGraph, + mut graph: DiGraph, room_id: &RoomId, -) -> Option>> { +) -> Option { let Some(room_version) = get_room_version(db, &graph) else { t::error!("Couldn't get room version, skipping this room"); return None; @@ -263,16 +227,90 @@ fn resolve_room_state( |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), ); - match res { - Ok(x) => { - t::info!("Done"); - Some(x) - } + let state_map = match res { + Ok(x) => x, Err(error) => { - t::error!(%error, "Failed"); - None + t::error!(%error, "State resolution failed"); + return None; } - } + }; + + let resolved_state = state_map + .iter() + .map(|((kind, key), event_id)| State { + kind: kind.clone(), + key: key.clone(), + event_id: event_id.clone(), + }) + .collect(); + + let state_indices = state_map + .values() + .map(|x| { + graph + .node_indices() + .find(|i| x == &graph[*i].event_id) + .expect("event found by stateres should be in input map") + }) + .collect::>(); + + let mut auth_chain_indices = HashSet::new(); + depth_first_search( + Reversed(&graph), + state_indices.iter().copied(), + |event| { + if let DfsEvent::Discover(node_index, _) = event { + auth_chain_indices.insert(node_index); + } + }, + ); + + let accepted_indices = std::iter::empty() + .chain(auth_chain_indices.iter()) + .chain(state_indices.iter()) + .copied() + .collect::>(); + + // Remove rejected events from the graph + graph.retain_nodes(|_, i| accepted_indices.contains(&i)); + + let pdus = graph + .node_indices() + .filter_map(|x| { + let event_id = &graph[x].event_id; + match get_either_pdu_by_event_id(db, event_id) { + Ok(Some(x)) => Some(Arc::new(x)), + Ok(None) => { + t::warn!(%event_id, "Missing PDU"); + None + } + Err(error) => { + t::warn!(%error, %event_id, "Failed to get PDU"); + None + } + } + }) + .collect::>(); + + let edges = graph + .raw_edges() + .iter() + .map(|x| { + let from = graph[x.source()].event_id.clone(); + let to = graph[x.target()].event_id.clone(); + + Edge { + from, + to, + } + }) + .collect(); + + Some(Value { + resolved_state, + pdus, + edges, + }) } /// Look up an accepted [`PduEvent`] by its [`EventId`] in the database @@ -442,7 +480,7 @@ fn get_state_resolution_edges<'a>( /// The edges in the directed graph go from an event's `auth_events` to that /// event. #[t::instrument(skip(db, state_resolution_edges))] -fn get_state_event_graphs( +pub(crate) fn get_state_event_graphs( db: &KeyValueDatabase, state_resolution_edges: I, ) -> HashMap, DiGraph>