record pdus and edges in each room's auth chain

This changes the output format of the command drastically.
This commit is contained in:
Charles Hall 2024-11-03 18:42:22 -08:00
parent aa49b111ed
commit 052ab6dddd
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
3 changed files with 243 additions and 114 deletions

View file

@ -2,9 +2,9 @@
//! Implementation of the `get-room-states` command
use std::{cmp::Ordering, sync::Arc};
use std::{cmp::Ordering, collections::BTreeSet, sync::Arc};
use ruma::events::StateEventType;
use ruma::{events::StateEventType, EventId};
use serde::Serialize;
use super::GetRoomStatesArgs;
@ -15,30 +15,52 @@ use crate::{
mod cache;
mod recompute;
/// Serializable information about a state event
#[derive(Serialize, PartialEq, Eq)]
struct StateEvent {
/// The kind of state event
kind: StateEventType,
/// A value in the output map
#[derive(Serialize)]
struct Value {
/// The resolved state of the room
resolved_state: BTreeSet<State>,
/// The `state_key` of the event
key: String,
/// The PDUs of the resolved state events and their auth chains
pdus: BTreeSet<Arc<PduEvent>>,
/// The event itself
event: Arc<PduEvent>,
/// The edges in the state events' auth chains
edges: BTreeSet<Edge>,
}
impl Ord for StateEvent {
/// An edge in the graph of `auth_events` of the resolved state
#[derive(Serialize, PartialEq, Eq, PartialOrd, Ord)]
struct Edge {
/// The outgoing edge
from: Arc<EventId>,
/// The incoming edge
to: Arc<EventId>,
}
/// Serializable information about a state event
#[derive(Serialize, PartialEq, Eq)]
struct State {
/// This event's `type`
kind: StateEventType,
/// This event's `state_key`
key: String,
/// The ID of this event
event_id: Arc<EventId>,
}
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering {
Ordering::Equal
.then_with(|| self.event.room_id.cmp(&other.event.room_id))
.then_with(|| self.kind.cmp(&other.kind))
.then_with(|| self.key.cmp(&other.key))
.then_with(|| self.event.event_id.cmp(&other.event.event_id))
.then_with(|| self.event_id.cmp(&other.event_id))
}
}
impl PartialOrd for StateEvent {
impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
@ -67,7 +89,7 @@ pub(crate) async fn run(
let room_states = if args.recompute {
recompute::get_room_states(db, &args.select)
} else {
cache::get_room_states(&args.select).await
cache::get_room_states(db, &args.select).await
};
serde_json::to_writer(std::io::stdout(), &room_states)?;

View file

@ -1,60 +1,57 @@
//! Get room states from the caches
use std::sync::Arc;
use std::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};
use ruma::{state_res::StateMap, OwnedRoomId};
use ruma::{OwnedRoomId, RoomId};
use tracing as t;
use super::StateEvent;
use crate::{services, PduEvent};
use super::{recompute, Edge, State, Value};
use crate::{database::KeyValueDatabase, services};
/// Get the state of all rooms, or the selected subset
#[t::instrument(skip(select))]
pub(crate) async fn get_room_states(select: &[OwnedRoomId]) -> Vec<StateEvent> {
let mut serializable_state = Vec::new();
#[t::instrument(skip(db, select))]
pub(crate) async fn get_room_states(
db: &KeyValueDatabase,
select: &[OwnedRoomId],
) -> BTreeMap<OwnedRoomId, Value> {
let mut values = BTreeMap::new();
for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) {
if !select.is_empty() && !select.contains(&room_id) {
continue;
}
let Some(state) = get_room_state(room_id).await else {
continue;
};
serializable_state.extend(state.into_iter().map(
|((kind, key), event)| StateEvent {
kind,
key,
event,
},
));
if let Some(value) = get_room_value(db, &room_id).await {
values.insert(room_id, value);
}
}
serializable_state.sort_unstable();
serializable_state
values
}
/// Get the state of the given room
#[t::instrument]
async fn get_room_state(
room_id: OwnedRoomId,
) -> Option<StateMap<Arc<PduEvent>>> {
/// Get the [`Value`] of the given room
#[t::instrument(skip(db))]
async fn get_room_value(
db: &KeyValueDatabase,
room_id: &RoomId,
) -> Option<Value> {
let shortstatehash =
match services().rooms.state.get_room_shortstatehash(&room_id) {
match services().rooms.state.get_room_shortstatehash(room_id) {
Ok(Some(x)) => x,
Ok(None) => {
t::warn!("No shortstatehash for room");
t::warn!("Missing shortstatehash");
return None;
}
Err(error) => {
t::warn!(%error, "Failed to get shortstatehash for room");
t::warn!(%error, "Failed to get shortstatehash");
return None;
}
};
let state = match services()
let state_map = match services()
.rooms
.state_accessor
.state_full(shortstatehash)
@ -62,10 +59,82 @@ async fn get_room_state(
{
Ok(x) => x,
Err(error) => {
t::warn!(%error, "Failed to get full state for room");
t::warn!(%error, "Failed to get full state");
return None;
}
};
Some(state)
let resolved_state = state_map
.iter()
.map(|((kind, key), event)| State {
kind: kind.clone(),
key: key.clone(),
event_id: event.event_id.clone(),
})
.collect();
let state_event_ids =
state_map.values().map(|x| x.event_id.clone()).collect::<Vec<_>>();
let auth_chain = match services()
.rooms
.auth_chain
.get_auth_chain(room_id, state_event_ids.clone())
.await
.map(Iterator::collect::<Vec<_>>)
{
Ok(x) => x,
Err(error) => {
t::warn!(%error, "Failed to get auth chain");
return None;
}
};
let pdus = std::iter::empty()
.chain(auth_chain.into_iter())
.chain(state_event_ids.into_iter())
.filter_map(|event_id| {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(x)) => Some(x),
Ok(None) => {
t::warn!(%event_id, "Missing PDU");
None
}
Err(error) => {
t::warn!(%error, %event_id, "Failed to get PDU");
None
}
}
})
.collect::<BTreeSet<_>>();
let edges = pdus.iter().map(|x| recompute::StateResolutionEdges {
event_id: x.event_id.clone(),
room_id: <Arc<RoomId>>::from(x.room_id.clone()),
auth_events: x.auth_events.clone(),
});
let graph = recompute::get_state_event_graphs(db, edges)
.remove(&<Arc<RoomId>>::from(room_id))
.expect("there should only be a graph for the current room");
let edges = graph
.raw_edges()
.iter()
.map(|x| {
let from = graph[x.source()].event_id.clone();
let to = graph[x.target()].event_id.clone();
Edge {
from,
to,
}
})
.collect();
Some(Value {
resolved_state,
pdus,
edges,
})
}

View file

@ -1,7 +1,7 @@
//! Get room states by recomputing them
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
error::Error,
sync::{
atomic::{AtomicUsize, Ordering},
@ -23,14 +23,14 @@ use ruma::{
use serde::Deserialize;
use tracing as t;
use super::StateEvent;
use super::{Edge, State, Value};
use crate::{database::KeyValueDatabase, utils, PduEvent};
/// A weightless unit type to use for graphs with unweighted edges
struct WeightlessEdge;
pub(crate) struct WeightlessEdge;
/// A directed graph with unweighted edges
type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>;
pub(crate) type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>;
/// Extracts the `.content.room_version` PDU field
#[derive(Deserialize)]
@ -41,74 +41,38 @@ struct ExtractRoomVersion {
/// Information about a state event node for state resolution
#[derive(Clone)]
struct StateResolutionNode {
pub(crate) struct StateResolutionNode {
/// This event's ID
event_id: Arc<EventId>,
pub(crate) event_id: Arc<EventId>,
/// This event's type
kind: StateEventType,
pub(crate) kind: StateEventType,
/// This event's `state_key`
state_key: String,
pub(crate) state_key: String,
}
/// Information about a state event's edges for state resolution
struct StateResolutionEdges {
pub(crate) struct StateResolutionEdges {
/// This event's ID
event_id: Arc<EventId>,
pub(crate) event_id: Arc<EventId>,
/// The ID of the room this event belongs to
room_id: Arc<RoomId>,
pub(crate) room_id: Arc<RoomId>,
/// This event's `auth_events`
auth_events: Vec<Arc<EventId>>,
pub(crate) auth_events: Vec<Arc<EventId>>,
}
/// Get the state of all rooms, or the selected subset
pub(crate) fn get_room_states(
db: &KeyValueDatabase,
select: &[OwnedRoomId],
) -> Vec<StateEvent> {
) -> BTreeMap<OwnedRoomId, Value> {
let state_resolution_edges = get_state_resolution_edges(db, select);
let graphs = get_state_event_graphs(db, state_resolution_edges);
let states = resolve_room_states(db, graphs);
let mut serializable_state = Vec::new();
for (_, state) in states {
serializable_state.extend(state.into_iter().filter_map(
|((kind, key), event)| {
let event = match get_either_pdu_by_event_id(db, &event) {
Ok(Some(x)) => Arc::new(x),
Ok(None) => {
t::warn!(event_id = %event, "Unknown event, omitting");
return None;
}
Err(error) => {
t::warn!(
%error,
event_id = %event,
"Failed to get event, omitting",
);
return None;
}
};
let x = StateEvent {
kind,
key,
event,
};
Some(x)
},
));
}
serializable_state.sort_unstable();
serializable_state
resolve_room_states(db, graphs)
}
/// Resolve the current state of all rooms
@ -122,7 +86,7 @@ pub(crate) fn get_room_states(
fn resolve_room_states(
db: &KeyValueDatabase,
graphs: HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>,
) -> HashMap<Arc<RoomId>, StateMap<Arc<EventId>>> {
) -> BTreeMap<OwnedRoomId, Value> {
let span = t::Span::current();
let todo = AtomicUsize::new(graphs.len());
@ -133,18 +97,18 @@ fn resolve_room_states(
.filter_map(|(room_id, graph)| {
let _enter = span.enter();
let state_map = resolve_room_state(db, graph, &room_id);
let value = resolve_room_state(db, graph, &room_id);
todo.fetch_sub(1, Ordering::SeqCst);
t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining");
state_map.map(|x| (room_id, x))
value.map(|x| (OwnedRoomId::from(room_id), x))
})
.fold(HashMap::new, |mut acc, (room_id, state_map)| {
.fold(BTreeMap::new, |mut acc, (room_id, state_map)| {
acc.insert(room_id, state_map);
acc
})
.reduce(HashMap::new, |mut x, y| {
.reduce(BTreeMap::new, |mut x, y| {
// Room IDs should be unique per chunk so there should be no
// collisions when combining this way
x.extend(y);
@ -222,9 +186,9 @@ fn get_room_version(
#[t::instrument(skip(db, graph))]
fn resolve_room_state(
db: &KeyValueDatabase,
graph: DiGraph<StateResolutionNode>,
mut graph: DiGraph<StateResolutionNode>,
room_id: &RoomId,
) -> Option<StateMap<Arc<EventId>>> {
) -> Option<Value> {
let Some(room_version) = get_room_version(db, &graph) else {
t::error!("Couldn't get room version, skipping this room");
return None;
@ -263,16 +227,90 @@ fn resolve_room_state(
|event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(),
);
match res {
Ok(x) => {
t::info!("Done");
Some(x)
}
let state_map = match res {
Ok(x) => x,
Err(error) => {
t::error!(%error, "Failed");
None
t::error!(%error, "State resolution failed");
return None;
}
}
};
let resolved_state = state_map
.iter()
.map(|((kind, key), event_id)| State {
kind: kind.clone(),
key: key.clone(),
event_id: event_id.clone(),
})
.collect();
let state_indices = state_map
.values()
.map(|x| {
graph
.node_indices()
.find(|i| x == &graph[*i].event_id)
.expect("event found by stateres should be in input map")
})
.collect::<HashSet<_>>();
let mut auth_chain_indices = HashSet::new();
depth_first_search(
Reversed(&graph),
state_indices.iter().copied(),
|event| {
if let DfsEvent::Discover(node_index, _) = event {
auth_chain_indices.insert(node_index);
}
},
);
let accepted_indices = std::iter::empty()
.chain(auth_chain_indices.iter())
.chain(state_indices.iter())
.copied()
.collect::<HashSet<_>>();
// Remove rejected events from the graph
graph.retain_nodes(|_, i| accepted_indices.contains(&i));
let pdus = graph
.node_indices()
.filter_map(|x| {
let event_id = &graph[x].event_id;
match get_either_pdu_by_event_id(db, event_id) {
Ok(Some(x)) => Some(Arc::new(x)),
Ok(None) => {
t::warn!(%event_id, "Missing PDU");
None
}
Err(error) => {
t::warn!(%error, %event_id, "Failed to get PDU");
None
}
}
})
.collect::<BTreeSet<_>>();
let edges = graph
.raw_edges()
.iter()
.map(|x| {
let from = graph[x.source()].event_id.clone();
let to = graph[x.target()].event_id.clone();
Edge {
from,
to,
}
})
.collect();
Some(Value {
resolved_state,
pdus,
edges,
})
}
/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database
@ -442,7 +480,7 @@ fn get_state_resolution_edges<'a>(
/// The edges in the directed graph go from an event's `auth_events` to that
/// event.
#[t::instrument(skip(db, state_resolution_edges))]
fn get_state_event_graphs<I>(
pub(crate) fn get_state_event_graphs<I>(
db: &KeyValueDatabase,
state_resolution_edges: I,
) -> HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>