record pdus and edges in each room's auth chain

This changes the output format of the command drastically.
This commit is contained in:
Charles Hall 2024-11-03 18:42:22 -08:00
parent aa49b111ed
commit 052ab6dddd
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
3 changed files with 243 additions and 114 deletions

View file

@ -2,9 +2,9 @@
//! Implementation of the `get-room-states` command //! Implementation of the `get-room-states` command
use std::{cmp::Ordering, sync::Arc}; use std::{cmp::Ordering, collections::BTreeSet, sync::Arc};
use ruma::events::StateEventType; use ruma::{events::StateEventType, EventId};
use serde::Serialize; use serde::Serialize;
use super::GetRoomStatesArgs; use super::GetRoomStatesArgs;
@ -15,30 +15,52 @@ use crate::{
mod cache; mod cache;
mod recompute; mod recompute;
/// Serializable information about a state event /// A value in the output map
#[derive(Serialize, PartialEq, Eq)] #[derive(Serialize)]
struct StateEvent { struct Value {
/// The kind of state event /// The resolved state of the room
kind: StateEventType, resolved_state: BTreeSet<State>,
/// The `state_key` of the event /// The PDUs of the resolved state events and their auth chains
key: String, pdus: BTreeSet<Arc<PduEvent>>,
/// The event itself /// The edges in the state events' auth chains
event: Arc<PduEvent>, edges: BTreeSet<Edge>,
} }
impl Ord for StateEvent { /// An edge in the graph of `auth_events` of the resolved state
#[derive(Serialize, PartialEq, Eq, PartialOrd, Ord)]
struct Edge {
/// The ID of the event at the source node of this edge
///
/// According to the documentation of the graph builder, graph edges go from
/// an event's `auth_events` to that event, so this is presumably the auth
/// event being referenced.
from: Arc<EventId>,
/// The ID of the event at the target node of this edge
///
/// Following the same convention, this is presumably the event that lists
/// `from` in its `auth_events`.
to: Arc<EventId>,
}
/// Serializable information about a state event
///
/// One entry of a room's resolved state: the `(type, state_key)` pair taken
/// from the state map together with the ID of the event that currently
/// occupies that slot. The event itself is not embedded here, only its ID.
///
/// Values are ordered by `kind`, then `key`, then `event_id`.
#[derive(Serialize, PartialEq, Eq)]
struct State {
/// This event's `type`
kind: StateEventType,
/// This event's `state_key`
key: String,
/// The ID of this event
event_id: Arc<EventId>,
}
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
Ordering::Equal Ordering::Equal
.then_with(|| self.event.room_id.cmp(&other.event.room_id))
.then_with(|| self.kind.cmp(&other.kind)) .then_with(|| self.kind.cmp(&other.kind))
.then_with(|| self.key.cmp(&other.key)) .then_with(|| self.key.cmp(&other.key))
.then_with(|| self.event.event_id.cmp(&other.event.event_id)) .then_with(|| self.event_id.cmp(&other.event_id))
} }
} }
impl PartialOrd for StateEvent { impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other)) Some(self.cmp(other))
} }
@ -67,7 +89,7 @@ pub(crate) async fn run(
let room_states = if args.recompute { let room_states = if args.recompute {
recompute::get_room_states(db, &args.select) recompute::get_room_states(db, &args.select)
} else { } else {
cache::get_room_states(&args.select).await cache::get_room_states(db, &args.select).await
}; };
serde_json::to_writer(std::io::stdout(), &room_states)?; serde_json::to_writer(std::io::stdout(), &room_states)?;

View file

@ -1,60 +1,57 @@
//! Get room states from the caches //! Get room states from the caches
use std::sync::Arc; use std::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};
use ruma::{state_res::StateMap, OwnedRoomId}; use ruma::{OwnedRoomId, RoomId};
use tracing as t; use tracing as t;
use super::StateEvent; use super::{recompute, Edge, State, Value};
use crate::{services, PduEvent}; use crate::{database::KeyValueDatabase, services};
/// Get the state of all rooms, or the selected subset /// Get the state of all rooms, or the selected subset
#[t::instrument(skip(select))] #[t::instrument(skip(db, select))]
pub(crate) async fn get_room_states(select: &[OwnedRoomId]) -> Vec<StateEvent> { pub(crate) async fn get_room_states(
let mut serializable_state = Vec::new(); db: &KeyValueDatabase,
select: &[OwnedRoomId],
) -> BTreeMap<OwnedRoomId, Value> {
let mut values = BTreeMap::new();
for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) { for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) {
if !select.is_empty() && !select.contains(&room_id) { if !select.is_empty() && !select.contains(&room_id) {
continue; continue;
} }
let Some(state) = get_room_state(room_id).await else { if let Some(value) = get_room_value(db, &room_id).await {
continue; values.insert(room_id, value);
}; }
serializable_state.extend(state.into_iter().map(
|((kind, key), event)| StateEvent {
kind,
key,
event,
},
));
} }
serializable_state.sort_unstable(); values
serializable_state
} }
/// Get the state of the given room /// Get the [`Value`] of the given room
#[t::instrument] #[t::instrument(skip(db))]
async fn get_room_state( async fn get_room_value(
room_id: OwnedRoomId, db: &KeyValueDatabase,
) -> Option<StateMap<Arc<PduEvent>>> { room_id: &RoomId,
) -> Option<Value> {
let shortstatehash = let shortstatehash =
match services().rooms.state.get_room_shortstatehash(&room_id) { match services().rooms.state.get_room_shortstatehash(room_id) {
Ok(Some(x)) => x, Ok(Some(x)) => x,
Ok(None) => { Ok(None) => {
t::warn!("No shortstatehash for room"); t::warn!("Missing shortstatehash");
return None; return None;
} }
Err(error) => { Err(error) => {
t::warn!(%error, "Failed to get shortstatehash for room"); t::warn!(%error, "Failed to get shortstatehash");
return None; return None;
} }
}; };
let state = match services() let state_map = match services()
.rooms .rooms
.state_accessor .state_accessor
.state_full(shortstatehash) .state_full(shortstatehash)
@ -62,10 +59,82 @@ async fn get_room_state(
{ {
Ok(x) => x, Ok(x) => x,
Err(error) => { Err(error) => {
t::warn!(%error, "Failed to get full state for room"); t::warn!(%error, "Failed to get full state");
return None; return None;
} }
}; };
Some(state) let resolved_state = state_map
.iter()
.map(|((kind, key), event)| State {
kind: kind.clone(),
key: key.clone(),
event_id: event.event_id.clone(),
})
.collect();
let state_event_ids =
state_map.values().map(|x| x.event_id.clone()).collect::<Vec<_>>();
let auth_chain = match services()
.rooms
.auth_chain
.get_auth_chain(room_id, state_event_ids.clone())
.await
.map(Iterator::collect::<Vec<_>>)
{
Ok(x) => x,
Err(error) => {
t::warn!(%error, "Failed to get auth chain");
return None;
}
};
let pdus = std::iter::empty()
.chain(auth_chain.into_iter())
.chain(state_event_ids.into_iter())
.filter_map(|event_id| {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(x)) => Some(x),
Ok(None) => {
t::warn!(%event_id, "Missing PDU");
None
}
Err(error) => {
t::warn!(%error, %event_id, "Failed to get PDU");
None
}
}
})
.collect::<BTreeSet<_>>();
let edges = pdus.iter().map(|x| recompute::StateResolutionEdges {
event_id: x.event_id.clone(),
room_id: <Arc<RoomId>>::from(x.room_id.clone()),
auth_events: x.auth_events.clone(),
});
let graph = recompute::get_state_event_graphs(db, edges)
.remove(&<Arc<RoomId>>::from(room_id))
.expect("there should only be a graph for the current room");
let edges = graph
.raw_edges()
.iter()
.map(|x| {
let from = graph[x.source()].event_id.clone();
let to = graph[x.target()].event_id.clone();
Edge {
from,
to,
}
})
.collect();
Some(Value {
resolved_state,
pdus,
edges,
})
} }

View file

@ -1,7 +1,7 @@
//! Get room states by recomputing them //! Get room states by recomputing them
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{BTreeMap, BTreeSet, HashMap, HashSet},
error::Error, error::Error,
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
@ -23,14 +23,14 @@ use ruma::{
use serde::Deserialize; use serde::Deserialize;
use tracing as t; use tracing as t;
use super::StateEvent; use super::{Edge, State, Value};
use crate::{database::KeyValueDatabase, utils, PduEvent}; use crate::{database::KeyValueDatabase, utils, PduEvent};
/// A weightless unit type to use for graphs with unweighted edges /// A weightless unit type to use for graphs with unweighted edges
struct WeightlessEdge; pub(crate) struct WeightlessEdge;
/// A directed graph with unweighted edges /// A directed graph with unweighted edges
type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>; pub(crate) type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>;
/// Extracts the `.content.room_version` PDU field /// Extracts the `.content.room_version` PDU field
#[derive(Deserialize)] #[derive(Deserialize)]
@ -41,74 +41,38 @@ struct ExtractRoomVersion {
/// Information about a state event node for state resolution /// Information about a state event node for state resolution
#[derive(Clone)] #[derive(Clone)]
struct StateResolutionNode { pub(crate) struct StateResolutionNode {
/// This event's ID /// This event's ID
event_id: Arc<EventId>, pub(crate) event_id: Arc<EventId>,
/// This event's type /// This event's type
kind: StateEventType, pub(crate) kind: StateEventType,
/// This event's `state_key` /// This event's `state_key`
state_key: String, pub(crate) state_key: String,
} }
/// Information about a state event's edges for state resolution /// Information about a state event's edges for state resolution
struct StateResolutionEdges { pub(crate) struct StateResolutionEdges {
/// This event's ID /// This event's ID
event_id: Arc<EventId>, pub(crate) event_id: Arc<EventId>,
/// The ID of the room this event belongs to /// The ID of the room this event belongs to
room_id: Arc<RoomId>, pub(crate) room_id: Arc<RoomId>,
/// This event's `auth_events` /// This event's `auth_events`
auth_events: Vec<Arc<EventId>>, pub(crate) auth_events: Vec<Arc<EventId>>,
} }
/// Get the state of all rooms, or the selected subset /// Get the state of all rooms, or the selected subset
pub(crate) fn get_room_states( pub(crate) fn get_room_states(
db: &KeyValueDatabase, db: &KeyValueDatabase,
select: &[OwnedRoomId], select: &[OwnedRoomId],
) -> Vec<StateEvent> { ) -> BTreeMap<OwnedRoomId, Value> {
let state_resolution_edges = get_state_resolution_edges(db, select); let state_resolution_edges = get_state_resolution_edges(db, select);
let graphs = get_state_event_graphs(db, state_resolution_edges); let graphs = get_state_event_graphs(db, state_resolution_edges);
let states = resolve_room_states(db, graphs); resolve_room_states(db, graphs)
let mut serializable_state = Vec::new();
for (_, state) in states {
serializable_state.extend(state.into_iter().filter_map(
|((kind, key), event)| {
let event = match get_either_pdu_by_event_id(db, &event) {
Ok(Some(x)) => Arc::new(x),
Ok(None) => {
t::warn!(event_id = %event, "Unknown event, omitting");
return None;
}
Err(error) => {
t::warn!(
%error,
event_id = %event,
"Failed to get event, omitting",
);
return None;
}
};
let x = StateEvent {
kind,
key,
event,
};
Some(x)
},
));
}
serializable_state.sort_unstable();
serializable_state
} }
/// Resolve the current state of all rooms /// Resolve the current state of all rooms
@ -122,7 +86,7 @@ pub(crate) fn get_room_states(
fn resolve_room_states( fn resolve_room_states(
db: &KeyValueDatabase, db: &KeyValueDatabase,
graphs: HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>, graphs: HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>,
) -> HashMap<Arc<RoomId>, StateMap<Arc<EventId>>> { ) -> BTreeMap<OwnedRoomId, Value> {
let span = t::Span::current(); let span = t::Span::current();
let todo = AtomicUsize::new(graphs.len()); let todo = AtomicUsize::new(graphs.len());
@ -133,18 +97,18 @@ fn resolve_room_states(
.filter_map(|(room_id, graph)| { .filter_map(|(room_id, graph)| {
let _enter = span.enter(); let _enter = span.enter();
let state_map = resolve_room_state(db, graph, &room_id); let value = resolve_room_state(db, graph, &room_id);
todo.fetch_sub(1, Ordering::SeqCst); todo.fetch_sub(1, Ordering::SeqCst);
t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining"); t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining");
state_map.map(|x| (room_id, x)) value.map(|x| (OwnedRoomId::from(room_id), x))
}) })
.fold(HashMap::new, |mut acc, (room_id, state_map)| { .fold(BTreeMap::new, |mut acc, (room_id, state_map)| {
acc.insert(room_id, state_map); acc.insert(room_id, state_map);
acc acc
}) })
.reduce(HashMap::new, |mut x, y| { .reduce(BTreeMap::new, |mut x, y| {
// Room IDs should be unique per chunk so there should be no // Room IDs should be unique per chunk so there should be no
// collisions when combining this way // collisions when combining this way
x.extend(y); x.extend(y);
@ -222,9 +186,9 @@ fn get_room_version(
#[t::instrument(skip(db, graph))] #[t::instrument(skip(db, graph))]
fn resolve_room_state( fn resolve_room_state(
db: &KeyValueDatabase, db: &KeyValueDatabase,
graph: DiGraph<StateResolutionNode>, mut graph: DiGraph<StateResolutionNode>,
room_id: &RoomId, room_id: &RoomId,
) -> Option<StateMap<Arc<EventId>>> { ) -> Option<Value> {
let Some(room_version) = get_room_version(db, &graph) else { let Some(room_version) = get_room_version(db, &graph) else {
t::error!("Couldn't get room version, skipping this room"); t::error!("Couldn't get room version, skipping this room");
return None; return None;
@ -263,16 +227,90 @@ fn resolve_room_state(
|event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(), |event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(),
); );
match res { let state_map = match res {
Ok(x) => { Ok(x) => x,
t::info!("Done");
Some(x)
}
Err(error) => { Err(error) => {
t::error!(%error, "Failed"); t::error!(%error, "State resolution failed");
None return None;
} }
} };
let resolved_state = state_map
.iter()
.map(|((kind, key), event_id)| State {
kind: kind.clone(),
key: key.clone(),
event_id: event_id.clone(),
})
.collect();
let state_indices = state_map
.values()
.map(|x| {
graph
.node_indices()
.find(|i| x == &graph[*i].event_id)
.expect("event found by stateres should be in input map")
})
.collect::<HashSet<_>>();
let mut auth_chain_indices = HashSet::new();
depth_first_search(
Reversed(&graph),
state_indices.iter().copied(),
|event| {
if let DfsEvent::Discover(node_index, _) = event {
auth_chain_indices.insert(node_index);
}
},
);
let accepted_indices = std::iter::empty()
.chain(auth_chain_indices.iter())
.chain(state_indices.iter())
.copied()
.collect::<HashSet<_>>();
// Remove rejected events from the graph
graph.retain_nodes(|_, i| accepted_indices.contains(&i));
let pdus = graph
.node_indices()
.filter_map(|x| {
let event_id = &graph[x].event_id;
match get_either_pdu_by_event_id(db, event_id) {
Ok(Some(x)) => Some(Arc::new(x)),
Ok(None) => {
t::warn!(%event_id, "Missing PDU");
None
}
Err(error) => {
t::warn!(%error, %event_id, "Failed to get PDU");
None
}
}
})
.collect::<BTreeSet<_>>();
let edges = graph
.raw_edges()
.iter()
.map(|x| {
let from = graph[x.source()].event_id.clone();
let to = graph[x.target()].event_id.clone();
Edge {
from,
to,
}
})
.collect();
Some(Value {
resolved_state,
pdus,
edges,
})
} }
/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database /// Look up an accepted [`PduEvent`] by its [`EventId`] in the database
@ -442,7 +480,7 @@ fn get_state_resolution_edges<'a>(
/// The edges in the directed graph go from an event's `auth_events` to that /// The edges in the directed graph go from an event's `auth_events` to that
/// event. /// event.
#[t::instrument(skip(db, state_resolution_edges))] #[t::instrument(skip(db, state_resolution_edges))]
fn get_state_event_graphs<I>( pub(crate) fn get_state_event_graphs<I>(
db: &KeyValueDatabase, db: &KeyValueDatabase,
state_resolution_edges: I, state_resolution_edges: I,
) -> HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>> ) -> HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>