From 2ff415a5629f2cdacd3f1fc6e6b46c3e9e602071 Mon Sep 17 00:00:00 2001
From: Charles Hall
Date: Mon, 3 Mar 2025 22:09:22 -0800
Subject: [PATCH] wip: add export subcommand

Currently exports:

* Server name
* PDUs
* Public signing keys
---
 src/cli.rs        |  23 +++
 src/cli/export.rs | 414 ++++++++++++++++++++++++++++++++++++++++++++++
 src/error.rs      |  34 ++++
 3 files changed, 471 insertions(+)
 create mode 100644 src/cli/export.rs

diff --git a/src/cli.rs b/src/cli.rs
index 99641790..dd27ccd0 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -13,6 +13,7 @@ use crate::{
 };
 
 mod check_config;
+mod export;
 mod serve;
 
 /// Command line arguments
@@ -33,6 +34,9 @@ pub(crate) enum Command {
 
     /// Check the configuration file for syntax and semantic errors.
     CheckConfig(CheckConfigArgs),
+
+    /// Export all persistent data.
+    Export(ExportArgs),
 }
 
 #[derive(clap::Args)]
@@ -44,6 +48,18 @@ pub(crate) struct CheckConfigArgs {
     observability: ObservabilityArgs,
 }
 
+#[derive(clap::Args)]
+pub(crate) struct ExportArgs {
+    #[clap(flatten)]
+    config: ConfigArg,
+
+    #[clap(flatten)]
+    observability: ObservabilityArgs,
+
+    #[clap(short, long)]
+    out_dir: PathBuf,
+}
+
 /// Wrapper for the `--config` arg.
 ///
 /// This exists to centralize the `mut_arg` code that sets the help value based
@@ -99,6 +115,9 @@ impl Args {
             Command::CheckConfig(args) => {
                 check_config::run(args.config).await?;
             }
+            Command::Export(args) => {
+                export::run(args).await?;
+            }
         }
         Ok(())
     }
@@ -113,6 +132,10 @@ impl Command {
                 args.observability.log_format,
                 args.observability.log_filter.clone(),
             )),
+            Command::Export(args) => Some((
+                args.observability.log_format,
+                args.observability.log_filter.clone(),
+            )),
             Command::Serve(_) => None,
         }
     }
diff --git a/src/cli/export.rs b/src/cli/export.rs
new file mode 100644
index 00000000..7a87acf3
--- /dev/null
+++ b/src/cli/export.rs
@@ -0,0 +1,414 @@
+use core::str;
+use std::{
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+
+use ruma::{
+    serde::Base64, MilliSecondsSinceUnixEpoch, OwnedServerName, ServerName,
+    ServerSigningKeyId, ServerSigningKeyVersion, SigningKeyAlgorithm,
+};
+use serde::Serialize;
+use tokio::{
+    fs::{create_dir_all, OpenOptions},
+    io::AsyncWriteExt,
+    sync::mpsc,
+    task::JoinSet,
+};
+use tracing as t;
+
+use super::ExportArgs;
+use crate::{
+    config,
+    database::{abstraction::KvTree, KeyValueDatabase},
+    error,
+    service::globals::SigningKeys,
+    services, Services,
+};
+
+/// Capacity of each channel connecting reader tasks to a writer task.
+const CHANNEL_SIZE: usize = 16;
+
+/// Contents of the `main.json` file.
+#[derive(Serialize)]
+struct Main {
+    server_name: OwnedServerName,
+}
+
+/// One line of the `server-signing-keys.jsonl` file.
+#[derive(Serialize)]
+struct SigningKey<'a> {
+    server_name: &'a ServerName,
+    algorithm: SigningKeyAlgorithm,
+    name: &'a ServerSigningKeyVersion,
+    public: Base64,
+    old: bool,
+    expires_at: MilliSecondsSinceUnixEpoch,
+}
+
+/// Run the `export` subcommand.
+///
+/// Fails early if the out directory already exists. Otherwise spawns the
+/// reader and writer tasks, waits for all of them, and reports any errors.
+pub(crate) async fn run(args: ExportArgs) -> Result<(), error::ExportCommand> {
+    use error::ExportCommand as Error;
+
+    let config = config::load(args.config.config.as_ref()).await?;
+
+    let db = Box::leak(Box::new(
+        KeyValueDatabase::load_or_create(&config).map_err(Error::Database)?,
+    ));
+
+    Services::new(db, config, None)
+        .map_err(Error::InitializeServices)?
+        .install();
+
+    services().globals.err_if_server_name_changed()?;
+
+    db.apply_migrations().await.map_err(Error::Database)?;
+
+    if args.out_dir.try_exists().map_err(Error::CheckExists)? {
+        return Err(Error::Exists);
+    }
+
+    create_dir_all(&args.out_dir).await.map_err(Error::CreateDir)?;
+
+    let mut set = JoinSet::new();
+
+    {
+        let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
+        set.spawn_blocking({
+            let tx = tx.clone();
+            || read_pdus(tx, db.pduid_pdu.clone())
+        });
+        set.spawn_blocking(|| read_pdus(tx, db.eventid_outlierpdu.clone()));
+        set.spawn(write_lines(rx, args.out_dir.clone(), "pdus.jsonl"));
+    }
+
+    {
+        let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
+        set.spawn_blocking(|| read_server_signing_keys(tx, db));
+        set.spawn(write_lines(
+            rx,
+            args.out_dir.clone(),
+            "server-signing-keys.jsonl",
+        ));
+    }
+
+    set.spawn(make_main(args.out_dir.clone()));
+
+    let mut errors = false;
+    while let Some(join) = set.join_next().await {
+        match join {
+            Ok(had_errors) => {
+                if had_errors {
+                    errors = true;
+                }
+            }
+            Err(error) => {
+                t::error!(%error, "failed to join task");
+                errors = true;
+            }
+        }
+    }
+    if errors {
+        return Err(Error::Export);
+    }
+
+    Ok(())
+}
+
+/// Write each received string into a file, separated by a newline.
+///
+/// Returns `true` if any error occurred.
+async fn write_lines<P>(
+    mut rx: mpsc::Receiver<String>,
+    out_dir: PathBuf,
+    out_file: P,
+) -> bool
+where
+    P: AsRef<Path>,
+{
+    let path = out_dir.join(out_file);
+
+    let mut file = match OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&path)
+        .await
+    {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, path = %path.display(), "failed to open file");
+            return true;
+        }
+    };
+
+    let mut errors = false;
+
+    while let Some(next) = rx.recv().await {
+        let res = file.write_all(next.as_bytes()).await;
+        match res {
+            Ok(()) => (),
+            Err(error) => {
+                t::error!(%error, "failed to write line");
+                errors = true;
+            }
+        }
+
+        let res = file.write_all(b"\n").await;
+        match res {
+            Ok(()) => (),
+            Err(error) => {
+                t::error!(%error, "failed to write newline");
+                errors = true;
+            }
+        }
+    }
+
+    // Writes to a tokio file may still be in flight after `write_all` returns;
+    // flush so the data is written out and any late error is reported.
+    if let Err(error) = file.flush().await {
+        t::error!(%error, "failed to flush file");
+        errors = true;
+    }
+
+    errors
+}
+
+/// Read PDUs into a channel.
+///
+/// Returns `true` if any error occurred.
+#[allow(clippy::needless_pass_by_value)]
+fn read_pdus(tx: mpsc::Sender<String>, x_pdu: Arc<dyn KvTree>) -> bool {
+    let mut errors = false;
+
+    for (_, v) in x_pdu.iter() {
+        // Deserialize to ensure the data is valid.
+        let s = match serde_json::from_slice::<serde_json::Value>(&v) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize PDU");
+                errors = true;
+                continue;
+            }
+        };
+
+        // Serialize to ensure it's formatted on a single line.
+        let d = match serde_json::to_string(&s) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to serialize PDU");
+                errors = true;
+                continue;
+            }
+        };
+
+        match tx.blocking_send(d) {
+            Ok(()) => (),
+            Err(error) => {
+                t::error!(%error, "failed to send PDU over channel");
+                errors = true;
+                continue;
+            }
+        }
+    }
+
+    errors
+}
+
+/// Make the `main.json` file.
+///
+/// Returns `true` if any error occurred.
+async fn make_main(out_dir: PathBuf) -> bool {
+    let main = Main {
+        server_name: services().globals.server_name().to_owned(),
+    };
+
+    let s = match serde_json::to_string_pretty(&main) {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, "failed to serialize JSON");
+            return true;
+        }
+    };
+
+    let path = out_dir.join("main.json");
+
+    let mut file = match OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&path)
+        .await
+    {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, path = %path.display(), "failed to open file");
+            return true;
+        }
+    };
+
+    if let Err(error) = file.write_all(s.as_bytes()).await {
+        t::error!(%error, "failed to write file");
+        return true;
+    }
+
+    // Flush so that buffered data is written and late errors are reported.
+    match file.flush().await {
+        Ok(()) => false,
+        Err(error) => {
+            t::error!(%error, "failed to flush file");
+            true
+        }
+    }
+}
+
+/// Read server signing keys into a channel.
+///
+/// Returns `true` if any error occurred.
+#[allow(clippy::needless_pass_by_value)]
+fn read_server_signing_keys(
+    tx: mpsc::Sender<String>,
+    db: &KeyValueDatabase,
+) -> bool {
+    // Handle our own signing keys first.
+    let mut errors = handle_server_signing_keys(
+        &tx,
+        services().globals.server_name(),
+        SigningKeys::load_own_keys(),
+    );
+
+    // Handle other servers' signing keys.
+    for (k, v) in db.server_signingkeys.iter() {
+        let k = match str::from_utf8(&k) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "server name contained invalid UTF-8");
+                errors = true;
+                continue;
+            }
+        };
+
+        let k = match ServerName::parse(k) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize server name");
+                errors = true;
+                continue;
+            }
+        };
+
+        // Not sure what causes this to happen, but ignoring them is *probably*
+        // the right thing to do.
+        if k == services().globals.server_name() {
+            t::debug!("found own signing keys in server_signingkeys, ignoring");
+            continue;
+        }
+
+        let v = match serde_json::from_slice::<SigningKeys>(&v) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize signing key data");
+                errors = true;
+                continue;
+            }
+        };
+
+        errors |= handle_server_signing_keys(&tx, &k, v);
+    }
+
+    errors
+}
+
+/// Shuffle the server signing key data into the desired format and send it.
+///
+/// Returns `true` if any error occurred.
+fn handle_server_signing_keys(
+    tx: &mpsc::Sender<String>,
+    server_name: &ServerName,
+    signing_keys: SigningKeys,
+) -> bool {
+    let mut errors = false;
+
+    for (key_id, key) in signing_keys.verify_keys {
+        let key_id = match ServerSigningKeyId::parse(&key_id) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to parse signing key ID");
+                errors = true;
+                continue;
+            }
+        };
+
+        // There are currently bugs that cause keys to be miscategorized, so
+        // claim all keys are old to be on the safe side.
+        let key = SigningKey {
+            server_name,
+            algorithm: key_id.algorithm(),
+            name: key_id.key_name(),
+            public: key.key,
+            old: true,
+            expires_at: signing_keys.valid_until_ts,
+        };
+
+        let d = match serde_json::to_string(&key) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to serialize key");
+                errors = true;
+                continue;
+            }
+        };
+
+        match tx.blocking_send(d) {
+            Ok(()) => (),
+            Err(error) => {
+                t::error!(%error, "failed to send key over channel");
+                errors = true;
+                continue;
+            }
+        }
+    }
+
+    for (key_id, key) in signing_keys.old_verify_keys {
+        let key_id = match ServerSigningKeyId::parse(&key_id) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to parse signing key ID");
+                errors = true;
+                continue;
+            }
+        };
+
+        // There are currently bugs that cause keys to be miscategorized, so
+        // claim all keys are old to be on the safe side.
+        let key = SigningKey {
+            server_name,
+            algorithm: key_id.algorithm(),
+            name: key_id.key_name(),
+            public: key.key,
+            old: true,
+            expires_at: key.expired_ts,
+        };
+
+        let d = match serde_json::to_string(&key) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to serialize key");
+                errors = true;
+                continue;
+            }
+        };
+
+        match tx.blocking_send(d) {
+            Ok(()) => (),
+            Err(error) => {
+                t::error!(%error, "failed to send key over channel");
+                errors = true;
+                continue;
+            }
+        }
+    }
+
+    errors
+}
diff --git a/src/error.rs b/src/error.rs
index ce8ebb37..be5e0244 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -48,6 +48,9 @@ pub(crate) enum Main {
 
     #[error(transparent)]
     CheckConfigCommand(#[from] CheckConfigCommand),
+
+    #[error(transparent)]
+    ExportCommand(#[from] ExportCommand),
 }
 
 /// Errors returned from the `serve` CLI subcommand.
@@ -75,6 +78,37 @@ pub(crate) enum ServeCommand {
     ServerNameChanged(#[from] ServerNameChanged),
 }
 
+/// Errors returned from the `export` CLI subcommand.
+// Missing docs are allowed here since that kind of information should be
+// encoded in the error messages themselves anyway.
+#[allow(missing_docs)]
+#[derive(Error, Debug)]
+pub(crate) enum ExportCommand {
+    #[error("failed to load configuration")]
+    Config(#[from] Config),
+
+    #[error("failed to initialize services")]
+    InitializeServices(#[source] crate::utils::error::Error),
+
+    #[error("failed to load or create the database")]
+    Database(#[source] crate::utils::error::Error),
+
+    #[error("`server_name` change check failed")]
+    ServerNameChanged(#[from] ServerNameChanged),
+
+    #[error("failed to check out directory existence")]
+    CheckExists(#[source] std::io::Error),
+
+    #[error("out directory already exists")]
+    Exists,
+
+    #[error("failed to create out directory or leading components")]
+    CreateDir(#[source] std::io::Error),
+
+    #[error("one or more errors occurred during the export")]
+    Export,
+}
+
 /// Errors returned from the `check-config` CLI subcommand.
 // Missing docs are allowed here since that kind of information should be
 // encoded in the error messages themselves anyway.