From 33598a79b7e56b237a8f467212dc3370d482398f Mon Sep 17 00:00:00 2001 From: Charles Hall Date: Tue, 22 Oct 2024 12:25:53 -0700 Subject: [PATCH] add command to get the state of all rooms --- src/cli.rs | 20 +++++++++ src/cli/get_room_states.rs | 71 ++++++++++++++++++++++++++++++++ src/cli/get_room_states/cache.rs | 67 ++++++++++++++++++++++++++++++ src/error.rs | 25 +++++++++++ 4 files changed, 183 insertions(+) create mode 100644 src/cli/get_room_states.rs create mode 100644 src/cli/get_room_states/cache.rs diff --git a/src/cli.rs b/src/cli.rs index 99641790..5742986c 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -13,6 +13,7 @@ use crate::{ }; mod check_config; +mod get_room_states; mod serve; /// Command line arguments @@ -33,6 +34,11 @@ pub(crate) enum Command { /// Check the configuration file for syntax and semantic errors. CheckConfig(CheckConfigArgs), + + /// Write the state of all rooms as JSON to stdout. + /// + /// This is primarily useful for debugging. + GetRoomStates(GetRoomStatesArgs), } #[derive(clap::Args)] @@ -88,6 +94,15 @@ pub(crate) struct ServeArgs { pub(crate) config: ConfigArg, } +#[derive(clap::Args)] +pub(crate) struct GetRoomStatesArgs { + #[clap(flatten)] + pub(crate) config: ConfigArg, + + #[clap(flatten)] + observability: ObservabilityArgs, +} + impl Args { pub(crate) async fn run(self) -> Result<(), error::Main> { if let Some((format, filter)) = self.command.cli_observability_args() { @@ -99,6 +114,7 @@ impl Args { Command::CheckConfig(args) => { check_config::run(args.config).await?; } + Command::GetRoomStates(args) => get_room_states::run(args).await?, } Ok(()) } @@ -113,6 +129,10 @@ impl Command { args.observability.log_format, args.observability.log_filter.clone(), )), + Command::GetRoomStates(args) => Some(( + args.observability.log_format, + args.observability.log_filter.clone(), + )), Command::Serve(_) => None, } } diff --git a/src/cli/get_room_states.rs b/src/cli/get_room_states.rs new file mode 100644 index 00000000..b65d5673 
--- /dev/null +++ b/src/cli/get_room_states.rs @@ -0,0 +1,71 @@ +#![warn(clippy::missing_docs_in_private_items)] + +//! Implementation of the `get-room-states` command + +use std::{cmp::Ordering, sync::Arc}; + +use ruma::events::StateEventType; +use serde::Serialize; + +use super::GetRoomStatesArgs; +use crate::{ + config, database::KeyValueDatabase, error, services, PduEvent, Services, +}; + +mod cache; + +/// Serializable information about a state event +#[derive(Serialize, PartialEq, Eq)] +struct StateEvent { + /// The kind of state event + kind: StateEventType, + + /// The `state_key` of the event + key: String, + + /// The event itself + event: Arc<PduEvent>, +} + +impl Ord for StateEvent { + fn cmp(&self, other: &Self) -> Ordering { + Ordering::Equal + .then_with(|| self.event.room_id.cmp(&other.event.room_id)) + .then_with(|| self.kind.cmp(&other.kind)) + .then_with(|| self.key.cmp(&other.key)) + .then_with(|| self.event.event_id.cmp(&other.event.event_id)) + } +} + +impl PartialOrd for StateEvent { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +/// Subcommand entrypoint +pub(crate) async fn run( + args: GetRoomStatesArgs, +) -> Result<(), error::DumpStateCommand> { + use error::DumpStateCommand as Error; + + let config = config::load(args.config.config.as_ref()).await?; + + let db = Box::leak(Box::new( + KeyValueDatabase::load_or_create(&config).map_err(Error::Database)?, + )); + + Services::new(db, config, None) + .map_err(Error::InitializeServices)? + .install(); + + services().globals.err_if_server_name_changed()?; + + db.apply_migrations().await.map_err(Error::Database)?; + + let room_states = cache::get_room_states().await; + + serde_json::to_writer(std::io::stdout(), &room_states)?; + + Ok(()) +} diff --git a/src/cli/get_room_states/cache.rs b/src/cli/get_room_states/cache.rs new file mode 100644 index 00000000..ba2f8ccf --- /dev/null +++ b/src/cli/get_room_states/cache.rs @@ -0,0 +1,67 @@ +//! 
Get room states from the caches + +use std::sync::Arc; + +use ruma::{state_res::StateMap, OwnedRoomId}; +use tracing as t; + +use super::StateEvent; +use crate::{services, PduEvent}; + +/// Get the state of all rooms +#[t::instrument] +pub(crate) async fn get_room_states() -> Vec<StateEvent> { + let mut serializable_state = Vec::new(); + + for room_id in services().rooms.metadata.iter_ids().filter_map(Result::ok) { + let Some(state) = get_room_state(room_id).await else { + continue; + }; + + serializable_state.extend(state.into_iter().map( + |((kind, key), event)| StateEvent { + kind, + key, + event, + }, + )); + } + + serializable_state.sort_unstable(); + + serializable_state +} + +/// Get the state of the given room +#[t::instrument] +async fn get_room_state( + room_id: OwnedRoomId, +) -> Option<StateMap<Arc<PduEvent>>> { + let shortstatehash = + match services().rooms.state.get_room_shortstatehash(&room_id) { + Ok(Some(x)) => x, + Ok(None) => { + t::warn!("No shortstatehash for room"); + return None; + } + Err(error) => { + t::warn!(%error, "Failed to get shortstatehash for room"); + return None; + } + }; + + let state = match services() + .rooms + .state_accessor + .state_full(shortstatehash) + .await + { + Ok(x) => x, + Err(error) => { + t::warn!(%error, "Failed to get full state for room"); + return None; + } + }; + + Some(state) +} diff --git a/src/error.rs b/src/error.rs index ce8ebb37..1b30f24a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,6 +43,9 @@ pub(crate) enum Main { #[error(transparent)] ServeCommand(#[from] ServeCommand), + #[error(transparent)] + DumpStateCommand(#[from] DumpStateCommand), + #[error("failed to install global default tracing subscriber")] SetSubscriber(#[from] tracing::subscriber::SetGlobalDefaultError), @@ -85,6 +88,28 @@ pub(crate) enum CheckConfigCommand { Config(#[from] Config), } +/// Errors returned from the `get-room-states` CLI subcommand. 
+// Missing docs are allowed here since that kind of information should be +// encoded in the error messages themselves anyway. +#[allow(missing_docs)] +#[derive(Error, Debug)] +pub(crate) enum DumpStateCommand { + #[error("failed to load configuration")] + Config(#[from] Config), + + #[error("failed to initialize services")] + InitializeServices(#[source] crate::utils::error::Error), + + #[error("failed to load or create the database")] + Database(#[source] crate::utils::error::Error), + + #[error("`server_name` change check failed")] + ServerNameChanged(#[from] ServerNameChanged), + + #[error("failed to serialize state to stdout")] + Serialize(#[from] serde_json::Error), +} + /// Error generated if `server_name` has changed or if checking this failed // Missing docs are allowed here since that kind of information should be // encoded in the error messages themselves anyway.