// Mirror of https://gitlab.computer.surgery/matrix/grapevine.git (synced 2025-12-16 23:31:24 +01:00); 550 lines, 16 KiB, Rust
//! Get room states by recomputing them
|
|
|
|
use std::{
|
|
collections::{HashMap, HashSet},
|
|
error::Error,
|
|
sync::{
|
|
atomic::{AtomicUsize, Ordering},
|
|
Arc,
|
|
},
|
|
};
|
|
|
|
use petgraph::{
|
|
graph::NodeIndex,
|
|
visit::{depth_first_search, DfsEvent, Reversed},
|
|
Direction,
|
|
};
|
|
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
|
use ruma::{
|
|
events::{StateEventType, TimelineEventType},
|
|
state_res::StateMap,
|
|
EventId, RoomId, RoomVersionId, UserId,
|
|
};
|
|
use serde::Deserialize;
|
|
use tracing as t;
|
|
|
|
use super::StateEvent;
|
|
use crate::{database::KeyValueDatabase, utils, PduEvent};
|
|
|
|
/// A weightless unit type to use for graphs with unweighted edges
///
/// It has no fields, so it carries no data; it exists only to fill the edge
/// weight slot of a graph. The derives let graphs that use it be cloned and
/// debug-printed.
#[derive(Clone, Copy, Debug, Default)]
struct WeightlessEdge;
|
|
|
|
/// A directed graph with unweighted edges
|
|
type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>;
|
|
|
|
/// Extracts the `.content.room_version` PDU field
|
|
#[derive(Deserialize)]
|
|
struct ExtractRoomVersion {
|
|
/// Room version
|
|
room_version: RoomVersionId,
|
|
}
|
|
|
|
/// Information about a state event node for state resolution
|
|
#[derive(Clone)]
|
|
struct StateResolutionNode {
|
|
/// This event's ID
|
|
event_id: Arc<EventId>,
|
|
|
|
/// This event's type
|
|
kind: StateEventType,
|
|
|
|
/// This event's `state_key`
|
|
state_key: String,
|
|
}
|
|
|
|
/// Information about a state event's edges for state resolution
|
|
struct StateResolutionEdges {
|
|
/// This event's ID
|
|
event_id: Arc<EventId>,
|
|
|
|
/// The ID of the room this event belongs to
|
|
room_id: Arc<RoomId>,
|
|
|
|
/// This event's `auth_events`
|
|
auth_events: Vec<Arc<EventId>>,
|
|
}
|
|
|
|
/// Get the state of all rooms
|
|
pub(crate) fn get_room_states(db: &KeyValueDatabase) -> Vec<StateEvent> {
|
|
let state_resolution_edges = get_state_resolution_edges(db);
|
|
let graphs = get_state_event_graphs(db, state_resolution_edges);
|
|
|
|
let states = resolve_room_states(db, graphs);
|
|
|
|
let mut serializable_state = Vec::new();
|
|
|
|
for (_, state) in states {
|
|
serializable_state.extend(state.into_iter().filter_map(
|
|
|((kind, key), event)| {
|
|
let event = match get_either_pdu_by_event_id(db, &event) {
|
|
Ok(Some(x)) => Arc::new(x),
|
|
Ok(None) => {
|
|
t::warn!(event_id = %event, "Unknown event, omitting");
|
|
return None;
|
|
}
|
|
Err(error) => {
|
|
t::warn!(
|
|
%error,
|
|
event_id = %event,
|
|
"Failed to get event, omitting",
|
|
);
|
|
return None;
|
|
}
|
|
};
|
|
|
|
let x = StateEvent {
|
|
kind,
|
|
key,
|
|
event,
|
|
};
|
|
|
|
Some(x)
|
|
},
|
|
));
|
|
}
|
|
|
|
serializable_state.sort_unstable();
|
|
|
|
serializable_state
|
|
}
|
|
|
|
/// Resolve the current state of all rooms
|
|
///
|
|
/// # Arguments
|
|
///
|
|
/// * `db`: The database.
|
|
/// * `graphs`: A map of room IDs to a directed graph of their state events. The
|
|
/// edges in the graph must go from an event's `auth_events` to that event.
|
|
#[t::instrument(skip(db, graphs))]
|
|
fn resolve_room_states(
|
|
db: &KeyValueDatabase,
|
|
graphs: HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>,
|
|
) -> HashMap<Arc<RoomId>, StateMap<Arc<EventId>>> {
|
|
let span = t::Span::current();
|
|
|
|
let todo = AtomicUsize::new(graphs.len());
|
|
|
|
// State resolution is slow so parallelize it by room
|
|
graphs
|
|
.into_par_iter()
|
|
.filter_map(|(room_id, graph)| {
|
|
let _enter = span.enter();
|
|
|
|
let state_map = resolve_room_state(db, graph, &room_id);
|
|
|
|
todo.fetch_sub(1, Ordering::SeqCst);
|
|
t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining");
|
|
|
|
state_map.map(|x| (room_id, x))
|
|
})
|
|
.fold(HashMap::new, |mut acc, (room_id, state_map)| {
|
|
acc.insert(room_id, state_map);
|
|
acc
|
|
})
|
|
.reduce(HashMap::new, |mut x, y| {
|
|
// Room IDs should be unique per chunk so there should be no
|
|
// collisions when combining this way
|
|
x.extend(y);
|
|
x
|
|
})
|
|
}
|
|
|
|
/// Find a single `m.room.create` event and return its room version
|
|
fn get_room_version(
|
|
db: &KeyValueDatabase,
|
|
graph: &DiGraph<StateResolutionNode>,
|
|
) -> Option<RoomVersionId> {
|
|
let mut create_events = HashSet::new();
|
|
for i in graph.node_indices() {
|
|
let n = &graph[i];
|
|
|
|
if n.kind != StateEventType::RoomCreate {
|
|
continue;
|
|
}
|
|
|
|
create_events.insert(n.event_id.clone());
|
|
}
|
|
|
|
if create_events.len() > 1 {
|
|
t::warn!(
|
|
"Multiple `m.room.create` events in graph, rejecting all of them",
|
|
);
|
|
return None;
|
|
}
|
|
|
|
let Some(event_id) = create_events.into_iter().next() else {
|
|
t::warn!("No `m.room.create` event in graph");
|
|
return None;
|
|
};
|
|
|
|
let pdu = match get_either_pdu_by_event_id(db, &event_id) {
|
|
Ok(Some(x)) => x,
|
|
Ok(None) => {
|
|
t::error!("No `m.room.create` PDU in database");
|
|
return None;
|
|
}
|
|
Err(error) => {
|
|
t::error!(
|
|
%error,
|
|
"Failed to get `m.room.create` PDU from database",
|
|
);
|
|
return None;
|
|
}
|
|
};
|
|
|
|
let content =
|
|
match serde_json::from_str::<ExtractRoomVersion>(pdu.content.get()) {
|
|
Ok(x) => x,
|
|
Err(error) => {
|
|
t::error!(
|
|
%error,
|
|
%event_id,
|
|
"Failed to extract `.content.room_version` from PDU",
|
|
);
|
|
return None;
|
|
}
|
|
};
|
|
|
|
Some(content.room_version)
|
|
}
|
|
|
|
/// Resolve the state of a single room
|
|
///
|
|
/// # Arguments
|
|
///
|
|
/// * `db`: The database.
|
|
/// * `graphs`: A map of room IDs to a directed graph of their state events. The
|
|
/// edges in the graph must go from an event's `auth_events` to that event.
|
|
/// * `room_id`: The ID of the room being resolved. Only used for tracing.
|
|
#[t::instrument(skip(db, graph))]
|
|
fn resolve_room_state(
|
|
db: &KeyValueDatabase,
|
|
graph: DiGraph<StateResolutionNode>,
|
|
room_id: &RoomId,
|
|
) -> Option<StateMap<Arc<EventId>>> {
|
|
let Some(room_version) = get_room_version(db, &graph) else {
|
|
t::error!("Couldn't get room version, skipping this room");
|
|
return None;
|
|
};
|
|
|
|
let (state_sets, auth_chain_sets) = graph
|
|
.node_indices()
|
|
.map(|x| {
|
|
let StateResolutionNode {
|
|
event_id,
|
|
kind,
|
|
state_key,
|
|
} = graph[x].clone();
|
|
|
|
// At this point, we don't know any valid groupings of state events,
|
|
// so handle each state event by itself by making a separate map for
|
|
// each one.
|
|
let mut state_set = StateMap::new();
|
|
state_set.insert((kind, state_key), event_id);
|
|
|
|
let mut auth_chain_set = HashSet::new();
|
|
depth_first_search(Reversed(&graph), [x], |event| {
|
|
if let DfsEvent::Discover(node_index, _) = event {
|
|
auth_chain_set.insert(graph[node_index].event_id.clone());
|
|
}
|
|
});
|
|
|
|
(state_set, auth_chain_set)
|
|
})
|
|
.collect::<(Vec<_>, Vec<_>)>();
|
|
|
|
let res = ruma::state_res::resolve(
|
|
&room_version,
|
|
&state_sets,
|
|
auth_chain_sets,
|
|
|event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(),
|
|
);
|
|
|
|
match res {
|
|
Ok(x) => {
|
|
t::info!("Done");
|
|
Some(x)
|
|
}
|
|
Err(error) => {
|
|
t::error!(%error, "Failed");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database
|
|
#[t::instrument(skip(db))]
|
|
fn get_accepted_pdu_by_event_id(
|
|
db: &KeyValueDatabase,
|
|
event_id: &EventId,
|
|
) -> Result<Option<PduEvent>, Box<dyn Error>> {
|
|
let event_id_key = event_id.as_bytes();
|
|
|
|
let Some(pdu_key) =
|
|
db.eventid_pduid.get(event_id_key).inspect_err(|error| {
|
|
t::error!(
|
|
%error,
|
|
key = utils::u8_slice_to_hex(event_id_key),
|
|
"Failed to get PDU ID for event ID",
|
|
);
|
|
})?
|
|
else {
|
|
// This is normal, perhaps the requested event is an outlier
|
|
t::debug!(
|
|
key = utils::u8_slice_to_hex(event_id_key),
|
|
"No PDU ID for event ID",
|
|
);
|
|
return Ok(None);
|
|
};
|
|
|
|
let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| {
|
|
t::error!(
|
|
%error,
|
|
key = utils::u8_slice_to_hex(&pdu_key),
|
|
"Failed to get PDU for PDU ID",
|
|
);
|
|
})?
|
|
else {
|
|
t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID");
|
|
return Err("No PDU for PDU ID".into());
|
|
};
|
|
|
|
let pdu = serde_json::from_slice::<PduEvent>(&pdu_bytes).inspect_err(
|
|
|error| {
|
|
t::error!(
|
|
%error,
|
|
key = utils::u8_slice_to_hex(&pdu_key),
|
|
"Failed to deserialize PDU",
|
|
);
|
|
},
|
|
)?;
|
|
|
|
Ok(Some(pdu))
|
|
}
|
|
|
|
/// Look up an outlier [`PduEvent`] by its [`EventId`] in the database
|
|
#[t::instrument(skip(db))]
|
|
fn get_outlier_pdu_by_event_id(
|
|
db: &KeyValueDatabase,
|
|
event_id: &EventId,
|
|
) -> Result<Option<PduEvent>, Box<dyn Error>> {
|
|
let event_id_key = event_id.as_bytes();
|
|
|
|
let Some(pdu_bytes) =
|
|
db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| {
|
|
t::error!(
|
|
%error,
|
|
key = utils::u8_slice_to_hex(event_id_key),
|
|
"Failed to PDU for event ID",
|
|
);
|
|
})?
|
|
else {
|
|
// This is normal, perhaps we just don't have this event
|
|
t::debug!(
|
|
key = utils::u8_slice_to_hex(event_id_key),
|
|
"No PDU for event ID"
|
|
);
|
|
return Ok(None);
|
|
};
|
|
|
|
let pdu = serde_json::from_slice::<PduEvent>(&pdu_bytes).inspect_err(
|
|
|error| {
|
|
t::error!(
|
|
%error,
|
|
key = utils::u8_slice_to_hex(event_id_key),
|
|
"Failed to deserialize PDU",
|
|
);
|
|
},
|
|
)?;
|
|
|
|
Ok(Some(pdu))
|
|
}
|
|
|
|
/// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database
|
|
///
|
|
/// Tries searching accepted PDUs first, then outlier PDUs.
|
|
#[t::instrument(skip(db))]
|
|
fn get_either_pdu_by_event_id(
|
|
db: &KeyValueDatabase,
|
|
event_id: &EventId,
|
|
) -> Result<Option<PduEvent>, Box<dyn Error>> {
|
|
let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else {
|
|
return get_outlier_pdu_by_event_id(db, event_id);
|
|
};
|
|
|
|
Ok(Some(pdu))
|
|
}
|
|
|
|
/// Get the edges in the graph of state events for state resolution
|
|
#[t::instrument(skip(db))]
|
|
fn get_state_resolution_edges(
|
|
db: &KeyValueDatabase,
|
|
) -> impl Iterator<Item = StateResolutionEdges> + '_ {
|
|
let filter_map = |key: Vec<u8>, value: Vec<u8>, map: &str| {
|
|
let Ok(pdu) = serde_json::from_slice::<PduEvent>(&value) else {
|
|
t::error!(
|
|
%map,
|
|
key = utils::u8_slice_to_hex(&key),
|
|
"Failed to deserialize PDU, skipping it",
|
|
);
|
|
return None;
|
|
};
|
|
|
|
let Some(state_key) = pdu.state_key else {
|
|
// Filter out non-state events
|
|
return None;
|
|
};
|
|
|
|
// Ruma fails the entire state resolution if it sees an event whose
|
|
// `state_key` it needs to parse as an ID and the parsing fails, so
|
|
// we pre-emptively drop these events. Real-world cases where this
|
|
// has happened are detailed on the issue linked below.
|
|
//
|
|
// <https://github.com/ruma/ruma/issues/1944>
|
|
if pdu.kind == TimelineEventType::RoomMember
|
|
&& <&UserId>::try_from(&*state_key).is_err()
|
|
{
|
|
t::warn!(
|
|
%map,
|
|
key = utils::u8_slice_to_hex(&key),
|
|
event_id = %pdu.event_id,
|
|
"Dropping event that could cause state resolution to fail",
|
|
);
|
|
return None;
|
|
}
|
|
|
|
Some(StateResolutionEdges {
|
|
event_id: pdu.event_id,
|
|
room_id: <Arc<RoomId>>::from(pdu.room_id),
|
|
auth_events: pdu.auth_events,
|
|
})
|
|
};
|
|
|
|
std::iter::empty()
|
|
.chain(db.pduid_pdu.iter().filter_map(move |(key, value)| {
|
|
filter_map(key, value, "pduid_pdu")
|
|
}))
|
|
.chain(db.eventid_outlierpdu.iter().filter_map(move |(key, value)| {
|
|
filter_map(key, value, "eventid_outlierpdu")
|
|
}))
|
|
}
|
|
|
|
/// Get a directed graph of state events for all rooms
|
|
///
|
|
/// The edges in the directed graph go from an event's `auth_events` to that
|
|
/// event.
|
|
#[t::instrument(skip(db, state_resolution_edges))]
|
|
fn get_state_event_graphs<I>(
|
|
db: &KeyValueDatabase,
|
|
state_resolution_edges: I,
|
|
) -> HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>
|
|
where
|
|
I: Iterator<Item = StateResolutionEdges>,
|
|
{
|
|
// Avoid inserting the same event ID twice; instead, mutate its edges
|
|
let mut visited = HashMap::<(Arc<RoomId>, Arc<EventId>), NodeIndex>::new();
|
|
|
|
// Graph of event IDs
|
|
let mut graphs = HashMap::<Arc<RoomId>, DiGraph<Arc<EventId>>>::new();
|
|
|
|
for StateResolutionEdges {
|
|
event_id,
|
|
room_id,
|
|
auth_events,
|
|
} in state_resolution_edges
|
|
{
|
|
let graph = graphs.entry(room_id.clone()).or_default();
|
|
|
|
let this = *visited
|
|
.entry((room_id.clone(), event_id.clone()))
|
|
.or_insert_with(|| graph.add_node(event_id));
|
|
|
|
for event_id in auth_events {
|
|
let prev = *visited
|
|
.entry((room_id.clone(), event_id.clone()))
|
|
.or_insert_with(|| graph.add_node(event_id));
|
|
|
|
graph.add_edge(prev, this, WeightlessEdge);
|
|
}
|
|
}
|
|
|
|
// Convert the graph of event IDs into one of `StateNode`s
|
|
let graphs = graphs
|
|
.into_iter()
|
|
.map(|(room_id, graph)| {
|
|
let graph = graph.filter_map(
|
|
|index, event_id| {
|
|
to_state_resolution_node(db, index, event_id, &graph)
|
|
},
|
|
|_, &WeightlessEdge| Some(WeightlessEdge),
|
|
);
|
|
|
|
(room_id, graph)
|
|
})
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
// Record some stats for fun
|
|
t::info!(
|
|
rooms = graphs.len(),
|
|
total_nodes = graphs.iter().fold(0, |mut acc, x| {
|
|
acc += x.1.node_indices().count();
|
|
acc
|
|
}),
|
|
total_edges = graphs.iter().fold(0, |mut acc, x| {
|
|
acc += x.1.edge_indices().count();
|
|
acc
|
|
}),
|
|
"Done",
|
|
);
|
|
|
|
graphs
|
|
}
|
|
|
|
/// `filter_map` a graph of [`EventId`] nodes to one of [`StateResolutionNode`]s
|
|
fn to_state_resolution_node(
|
|
db: &KeyValueDatabase,
|
|
index: NodeIndex,
|
|
event_id: &Arc<EventId>,
|
|
graph: &DiGraph<Arc<EventId>>,
|
|
) -> Option<StateResolutionNode> {
|
|
let pdu = match get_either_pdu_by_event_id(db, event_id) {
|
|
Ok(Some(x)) => x,
|
|
Ok(None) => {
|
|
let mut dependents = Vec::new();
|
|
for i in graph.neighbors_directed(index, Direction::Outgoing) {
|
|
dependents.push(&graph[i]);
|
|
}
|
|
|
|
t::warn!(
|
|
?dependents,
|
|
missing = %event_id,
|
|
"Missing `auth_event` depended on by one or more state events",
|
|
);
|
|
|
|
return None;
|
|
}
|
|
Err(error) => {
|
|
t::warn!(%error, %event_id, "Failed to get PDU by event ID");
|
|
return None;
|
|
}
|
|
};
|
|
|
|
let Some(state_key) = pdu.state_key else {
|
|
let mut dependents = Vec::new();
|
|
for i in graph.neighbors_directed(index, Direction::Outgoing) {
|
|
dependents.push(&graph[i]);
|
|
}
|
|
t::warn!(
|
|
?dependents,
|
|
not_state = %event_id,
|
|
"Event depended on by one or more state events isn't a state event",
|
|
);
|
|
return None;
|
|
};
|
|
|
|
Some(StateResolutionNode {
|
|
event_id: event_id.clone(),
|
|
kind: pdu.kind.to_string().into(),
|
|
state_key,
|
|
})
|
|
}
|