add option to recompute room state

Hopefully we can make use of this state resolution code for more stuff
in the future.
This commit is contained in:
Charles Hall 2024-11-02 18:26:00 -07:00
parent 33598a79b7
commit b1e14fad5c
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
5 changed files with 629 additions and 1 deletions

63
Cargo.lock generated
View file

@ -525,6 +525,31 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crypto-common"
version = "0.1.6"
@ -734,6 +759,12 @@ version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "fixedbitset"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.35"
@ -915,12 +946,14 @@ dependencies = [
"opentelemetry-prometheus",
"opentelemetry_sdk",
"parking_lot",
"petgraph",
"phf",
"pin-project-lite",
"predicates",
"prometheus",
"proxy-header",
"rand",
"rayon",
"regex",
"reqwest",
"ring",
@ -1982,6 +2015,16 @@ dependencies = [
"sha2",
]
[[package]]
name = "petgraph"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.6.0",
]
[[package]]
name = "phf"
version = "0.11.2"
@ -2299,6 +2342,26 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "redox_syscall"
version = "0.5.7"

View file

@ -111,11 +111,13 @@ opentelemetry-otlp = "0.17.0"
opentelemetry-prometheus = "0.17.0"
opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio"] }
parking_lot = { version = "0.12.3", optional = true }
petgraph = "0.6.5"
phf = { version = "0.11.2", features = ["macros"] }
pin-project-lite = "0.2.15"
prometheus = "0.13.4"
proxy-header = { version = "0.1.2", features = ["tokio"] }
rand = "0.8.5"
rayon = "1.10.0"
regex = "1.11.1"
reqwest = { version = "0.12.9", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }
ring = "0.17.8"

View file

@ -99,6 +99,14 @@ pub(crate) struct GetRoomStatesArgs {
#[clap(flatten)]
pub(crate) config: ConfigArg,
/// Recompute the room state instead of getting it from the caches
///
/// Note that this can be VERY SLOW for some rooms. `#matrix:matrix.org`
/// has been seen to take around 30 minutes to solve on a 7950X, for
/// example.
#[clap(long, short)]
pub(crate) recompute: bool,
#[clap(flatten)]
observability: ObservabilityArgs,
}

View file

@ -13,6 +13,7 @@ use crate::{
};
mod cache;
mod recompute;
/// Serializable information about a state event
#[derive(Serialize, PartialEq, Eq)]
@ -63,7 +64,11 @@ pub(crate) async fn run(
db.apply_migrations().await.map_err(Error::Database)?;
let room_states = cache::get_room_states().await;
let room_states = if args.recompute {
recompute::get_room_states(db)
} else {
cache::get_room_states().await
};
serde_json::to_writer(std::io::stdout(), &room_states)?;

View file

@ -0,0 +1,550 @@
//! Get room states by recomputing them
use std::{
collections::{HashMap, HashSet},
error::Error,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use petgraph::{
graph::NodeIndex,
visit::{depth_first_search, DfsEvent, Reversed},
Direction,
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use ruma::{
events::{StateEventType, TimelineEventType},
state_res::StateMap,
EventId, RoomId, RoomVersionId, UserId,
};
use serde::Deserialize;
use tracing as t;
use super::StateEvent;
use crate::{database::KeyValueDatabase, utils, PduEvent};
/// A weightless unit type to use for graphs with unweighted edges
struct WeightlessEdge;
/// A directed graph with unweighted edges
type DiGraph<T> = petgraph::graph::DiGraph<T, WeightlessEdge>;
/// Extracts the `.content.room_version` PDU field
#[derive(Deserialize)]
struct ExtractRoomVersion {
/// Room version
room_version: RoomVersionId,
}
/// Information about a state event node for state resolution
#[derive(Clone)]
struct StateResolutionNode {
/// This event's ID
event_id: Arc<EventId>,
/// This event's type
kind: StateEventType,
/// This event's `state_key`
state_key: String,
}
/// Information about a state event's edges for state resolution
struct StateResolutionEdges {
/// This event's ID
event_id: Arc<EventId>,
/// The ID of the room this event belongs to
room_id: Arc<RoomId>,
/// This event's `auth_events`
auth_events: Vec<Arc<EventId>>,
}
/// Get the state of all rooms
pub(crate) fn get_room_states(db: &KeyValueDatabase) -> Vec<StateEvent> {
let state_resolution_edges = get_state_resolution_edges(db);
let graphs = get_state_event_graphs(db, state_resolution_edges);
let states = resolve_room_states(db, graphs);
let mut serializable_state = Vec::new();
for (_, state) in states {
serializable_state.extend(state.into_iter().filter_map(
|((kind, key), event)| {
let event = match get_either_pdu_by_event_id(db, &event) {
Ok(Some(x)) => Arc::new(x),
Ok(None) => {
t::warn!(event_id = %event, "Unknown event, omitting");
return None;
}
Err(error) => {
t::warn!(
%error,
event_id = %event,
"Failed to get event, omitting",
);
return None;
}
};
let x = StateEvent {
kind,
key,
event,
};
Some(x)
},
));
}
serializable_state.sort_unstable();
serializable_state
}
/// Resolve the current state of all rooms
///
/// # Arguments
///
/// * `db`: The database.
/// * `graphs`: A map of room IDs to a directed graph of their state events. The
/// edges in the graph must go from an event's `auth_events` to that event.
#[t::instrument(skip(db, graphs))]
fn resolve_room_states(
db: &KeyValueDatabase,
graphs: HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>,
) -> HashMap<Arc<RoomId>, StateMap<Arc<EventId>>> {
let span = t::Span::current();
let todo = AtomicUsize::new(graphs.len());
// State resolution is slow so parallelize it by room
graphs
.into_par_iter()
.filter_map(|(room_id, graph)| {
let _enter = span.enter();
let state_map = resolve_room_state(db, graph, &room_id);
todo.fetch_sub(1, Ordering::SeqCst);
t::info!(count = todo.load(Ordering::SeqCst), "Rooms remaining");
state_map.map(|x| (room_id, x))
})
.fold(HashMap::new, |mut acc, (room_id, state_map)| {
acc.insert(room_id, state_map);
acc
})
.reduce(HashMap::new, |mut x, y| {
// Room IDs should be unique per chunk so there should be no
// collisions when combining this way
x.extend(y);
x
})
}
/// Find a single `m.room.create` event and return its room version
fn get_room_version(
db: &KeyValueDatabase,
graph: &DiGraph<StateResolutionNode>,
) -> Option<RoomVersionId> {
let mut create_events = HashSet::new();
for i in graph.node_indices() {
let n = &graph[i];
if n.kind != StateEventType::RoomCreate {
continue;
}
create_events.insert(n.event_id.clone());
}
if create_events.len() > 1 {
t::warn!(
"Multiple `m.room.create` events in graph, rejecting all of them",
);
return None;
}
let Some(event_id) = create_events.into_iter().next() else {
t::warn!("No `m.room.create` event in graph");
return None;
};
let pdu = match get_either_pdu_by_event_id(db, &event_id) {
Ok(Some(x)) => x,
Ok(None) => {
t::error!("No `m.room.create` PDU in database");
return None;
}
Err(error) => {
t::error!(
%error,
"Failed to get `m.room.create` PDU from database",
);
return None;
}
};
let content =
match serde_json::from_str::<ExtractRoomVersion>(pdu.content.get()) {
Ok(x) => x,
Err(error) => {
t::error!(
%error,
%event_id,
"Failed to extract `.content.room_version` from PDU",
);
return None;
}
};
Some(content.room_version)
}
/// Resolve the state of a single room
///
/// # Arguments
///
/// * `db`: The database.
/// * `graphs`: A map of room IDs to a directed graph of their state events. The
/// edges in the graph must go from an event's `auth_events` to that event.
/// * `room_id`: The ID of the room being resolved. Only used for tracing.
#[t::instrument(skip(db, graph))]
fn resolve_room_state(
db: &KeyValueDatabase,
graph: DiGraph<StateResolutionNode>,
room_id: &RoomId,
) -> Option<StateMap<Arc<EventId>>> {
let Some(room_version) = get_room_version(db, &graph) else {
t::error!("Couldn't get room version, skipping this room");
return None;
};
let (state_sets, auth_chain_sets) = graph
.node_indices()
.map(|x| {
let StateResolutionNode {
event_id,
kind,
state_key,
} = graph[x].clone();
// At this point, we don't know any valid groupings of state events,
// so handle each state event by itself by making a separate map for
// each one.
let mut state_set = StateMap::new();
state_set.insert((kind, state_key), event_id);
let mut auth_chain_set = HashSet::new();
depth_first_search(Reversed(&graph), [x], |event| {
if let DfsEvent::Discover(node_index, _) = event {
auth_chain_set.insert(graph[node_index].event_id.clone());
}
});
(state_set, auth_chain_set)
})
.collect::<(Vec<_>, Vec<_>)>();
let res = ruma::state_res::resolve(
&room_version,
&state_sets,
auth_chain_sets,
|event_id| get_either_pdu_by_event_id(db, event_id).ok().flatten(),
);
match res {
Ok(x) => {
t::info!("Done");
Some(x)
}
Err(error) => {
t::error!(%error, "Failed");
None
}
}
}
/// Look up an accepted [`PduEvent`] by its [`EventId`] in the database
#[t::instrument(skip(db))]
fn get_accepted_pdu_by_event_id(
db: &KeyValueDatabase,
event_id: &EventId,
) -> Result<Option<PduEvent>, Box<dyn Error>> {
let event_id_key = event_id.as_bytes();
let Some(pdu_key) =
db.eventid_pduid.get(event_id_key).inspect_err(|error| {
t::error!(
%error,
key = utils::u8_slice_to_hex(event_id_key),
"Failed to get PDU ID for event ID",
);
})?
else {
// This is normal, perhaps the requested event is an outlier
t::debug!(
key = utils::u8_slice_to_hex(event_id_key),
"No PDU ID for event ID",
);
return Ok(None);
};
let Some(pdu_bytes) = db.pduid_pdu.get(&pdu_key).inspect_err(|error| {
t::error!(
%error,
key = utils::u8_slice_to_hex(&pdu_key),
"Failed to get PDU for PDU ID",
);
})?
else {
t::error!(key = utils::u8_slice_to_hex(&pdu_key), "No PDU for PDU ID");
return Err("No PDU for PDU ID".into());
};
let pdu = serde_json::from_slice::<PduEvent>(&pdu_bytes).inspect_err(
|error| {
t::error!(
%error,
key = utils::u8_slice_to_hex(&pdu_key),
"Failed to deserialize PDU",
);
},
)?;
Ok(Some(pdu))
}
/// Look up an outlier [`PduEvent`] by its [`EventId`] in the database
#[t::instrument(skip(db))]
fn get_outlier_pdu_by_event_id(
db: &KeyValueDatabase,
event_id: &EventId,
) -> Result<Option<PduEvent>, Box<dyn Error>> {
let event_id_key = event_id.as_bytes();
let Some(pdu_bytes) =
db.eventid_outlierpdu.get(event_id_key).inspect_err(|error| {
t::error!(
%error,
key = utils::u8_slice_to_hex(event_id_key),
"Failed to PDU for event ID",
);
})?
else {
// This is normal, perhaps we just don't have this event
t::debug!(
key = utils::u8_slice_to_hex(event_id_key),
"No PDU for event ID"
);
return Ok(None);
};
let pdu = serde_json::from_slice::<PduEvent>(&pdu_bytes).inspect_err(
|error| {
t::error!(
%error,
key = utils::u8_slice_to_hex(event_id_key),
"Failed to deserialize PDU",
);
},
)?;
Ok(Some(pdu))
}
/// Look up an accepted/outlier [`PduEvent`] by its [`EventId`] in the database
///
/// Tries searching accepted PDUs first, then outlier PDUs.
#[t::instrument(skip(db))]
fn get_either_pdu_by_event_id(
db: &KeyValueDatabase,
event_id: &EventId,
) -> Result<Option<PduEvent>, Box<dyn Error>> {
let Some(pdu) = get_accepted_pdu_by_event_id(db, event_id)? else {
return get_outlier_pdu_by_event_id(db, event_id);
};
Ok(Some(pdu))
}
/// Get the edges in the graph of state events for state resolution
#[t::instrument(skip(db))]
fn get_state_resolution_edges(
db: &KeyValueDatabase,
) -> impl Iterator<Item = StateResolutionEdges> + '_ {
let filter_map = |key: Vec<u8>, value: Vec<u8>, map: &str| {
let Ok(pdu) = serde_json::from_slice::<PduEvent>(&value) else {
t::error!(
%map,
key = utils::u8_slice_to_hex(&key),
"Failed to deserialize PDU, skipping it",
);
return None;
};
let Some(state_key) = pdu.state_key else {
// Filter out non-state events
return None;
};
// Ruma fails the entire state resolution if it sees an event whose
// `state_key` it needs to parse as an ID and the parsing fails, so
// we pre-emptively drop these events. Real-world cases where this
// has happened are detailed on the issue linked below.
//
// <https://github.com/ruma/ruma/issues/1944>
if pdu.kind == TimelineEventType::RoomMember
&& <&UserId>::try_from(&*state_key).is_err()
{
t::warn!(
%map,
key = utils::u8_slice_to_hex(&key),
event_id = %pdu.event_id,
"Dropping event that could cause state resolution to fail",
);
return None;
}
Some(StateResolutionEdges {
event_id: pdu.event_id,
room_id: <Arc<RoomId>>::from(pdu.room_id),
auth_events: pdu.auth_events,
})
};
std::iter::empty()
.chain(db.pduid_pdu.iter().filter_map(move |(key, value)| {
filter_map(key, value, "pduid_pdu")
}))
.chain(db.eventid_outlierpdu.iter().filter_map(move |(key, value)| {
filter_map(key, value, "eventid_outlierpdu")
}))
}
/// Get a directed graph of state events for all rooms
///
/// The edges in the directed graph go from an event's `auth_events` to that
/// event.
#[t::instrument(skip(db, state_resolution_edges))]
fn get_state_event_graphs<I>(
db: &KeyValueDatabase,
state_resolution_edges: I,
) -> HashMap<Arc<RoomId>, DiGraph<StateResolutionNode>>
where
I: Iterator<Item = StateResolutionEdges>,
{
// Avoid inserting the same event ID twice; instead, mutate its edges
let mut visited = HashMap::<(Arc<RoomId>, Arc<EventId>), NodeIndex>::new();
// Graph of event IDs
let mut graphs = HashMap::<Arc<RoomId>, DiGraph<Arc<EventId>>>::new();
for StateResolutionEdges {
event_id,
room_id,
auth_events,
} in state_resolution_edges
{
let graph = graphs.entry(room_id.clone()).or_default();
let this = *visited
.entry((room_id.clone(), event_id.clone()))
.or_insert_with(|| graph.add_node(event_id));
for event_id in auth_events {
let prev = *visited
.entry((room_id.clone(), event_id.clone()))
.or_insert_with(|| graph.add_node(event_id));
graph.add_edge(prev, this, WeightlessEdge);
}
}
// Convert the graph of event IDs into one of `StateNode`s
let graphs = graphs
.into_iter()
.map(|(room_id, graph)| {
let graph = graph.filter_map(
|index, event_id| {
to_state_resolution_node(db, index, event_id, &graph)
},
|_, &WeightlessEdge| Some(WeightlessEdge),
);
(room_id, graph)
})
.collect::<HashMap<_, _>>();
// Record some stats for fun
t::info!(
rooms = graphs.len(),
total_nodes = graphs.iter().fold(0, |mut acc, x| {
acc += x.1.node_indices().count();
acc
}),
total_edges = graphs.iter().fold(0, |mut acc, x| {
acc += x.1.edge_indices().count();
acc
}),
"Done",
);
graphs
}
/// `filter_map` a graph of [`EventId`] nodes to one of [`StateResolutionNode`]s
fn to_state_resolution_node(
db: &KeyValueDatabase,
index: NodeIndex,
event_id: &Arc<EventId>,
graph: &DiGraph<Arc<EventId>>,
) -> Option<StateResolutionNode> {
let pdu = match get_either_pdu_by_event_id(db, event_id) {
Ok(Some(x)) => x,
Ok(None) => {
let mut dependents = Vec::new();
for i in graph.neighbors_directed(index, Direction::Outgoing) {
dependents.push(&graph[i]);
}
t::warn!(
?dependents,
missing = %event_id,
"Missing `auth_event` depended on by one or more state events",
);
return None;
}
Err(error) => {
t::warn!(%error, %event_id, "Failed to get PDU by event ID");
return None;
}
};
let Some(state_key) = pdu.state_key else {
let mut dependents = Vec::new();
for i in graph.neighbors_directed(index, Direction::Outgoing) {
dependents.push(&graph[i]);
}
t::warn!(
?dependents,
not_state = %event_id,
"Event depended on by one or more state events isn't a state event",
);
return None;
};
Some(StateResolutionNode {
event_id: event_id.clone(),
kind: pdu.kind.to_string().into(),
state_key,
})
}