diff --git a/src/cli.rs b/src/cli.rs
index 99641790..dd27ccd0 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -13,6 +13,7 @@ use crate::{
};
mod check_config;
+mod export;
mod serve;
/// Command line arguments
@@ -33,6 +34,9 @@ pub(crate) enum Command {
/// Check the configuration file for syntax and semantic errors.
CheckConfig(CheckConfigArgs),
+
+ /// Export all persistent data.
+ Export(ExportArgs),
}
#[derive(clap::Args)]
@@ -44,6 +48,18 @@ pub(crate) struct CheckConfigArgs {
observability: ObservabilityArgs,
}
+/// Arguments for the `export` subcommand.
+#[derive(clap::Args)]
+pub(crate) struct ExportArgs {
+    // Selects the configuration file to load.
+    #[clap(flatten)]
+    config: ConfigArg,
+
+    // Log format and filter used while the export runs.
+    #[clap(flatten)]
+    observability: ObservabilityArgs,
+
+    /// Directory to write the exported data into. It must not exist yet.
+    #[clap(short, long)]
+    out_dir: PathBuf,
+}
+
/// Wrapper for the `--config` arg.
///
/// This exists to centralize the `mut_arg` code that sets the help value based
@@ -99,6 +115,9 @@ impl Args {
Command::CheckConfig(args) => {
check_config::run(args.config).await?;
}
+ Command::Export(args) => {
+ export::run(args).await?;
+ }
}
Ok(())
}
@@ -113,6 +132,10 @@ impl Command {
args.observability.log_format,
args.observability.log_filter.clone(),
)),
+ Command::Export(args) => Some((
+ args.observability.log_format,
+ args.observability.log_filter.clone(),
+ )),
Command::Serve(_) => None,
}
}
diff --git a/src/cli/export.rs b/src/cli/export.rs
new file mode 100644
index 00000000..7a87acf3
--- /dev/null
+++ b/src/cli/export.rs
@@ -0,0 +1,384 @@
+use core::str;
+use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+
+use ruma::{
+ serde::Base64, MilliSecondsSinceUnixEpoch, OwnedServerName, ServerName,
+ ServerSigningKeyId, ServerSigningKeyVersion, SigningKeyAlgorithm,
+};
+use serde::Serialize;
+use tokio::{
+ fs::{create_dir_all, OpenOptions},
+ io::AsyncWriteExt,
+ sync::mpsc,
+ task::JoinSet,
+};
+use tracing as t;
+
+use super::ExportArgs;
+use crate::{
+ config,
+ database::{abstraction::KvTree, KeyValueDatabase},
+ error,
+ service::globals::SigningKeys,
+ services, Services,
+};
+
+/// Capacity passed to `mpsc::channel` for each channel that connects reader
+/// tasks to their writer task in `run`.
+const CHANNEL_SIZE: usize = 16;
+
+/// JSON document written to `main.json` by `make_main`.
+#[derive(Serialize)]
+struct Main {
+ /// Name of this server, taken from `services().globals.server_name()`.
+ server_name: OwnedServerName,
+}
+
+/// One line of `server-signing-keys.jsonl`, built by
+/// `handle_server_signing_keys`.
+#[derive(Serialize)]
+struct SigningKey<'a> {
+ /// Server the key belongs to.
+ server_name: &'a ServerName,
+ /// Algorithm part of the parsed signing key ID.
+ algorithm: SigningKeyAlgorithm,
+ /// Key name part of the parsed signing key ID.
+ name: &'a ServerSigningKeyVersion,
+ /// Taken from the stored key's `key` field.
+ public: Base64,
+ /// Currently always written as `true` by `handle_server_signing_keys`.
+ old: bool,
+ /// `valid_until_ts` of the key set for entries from `verify_keys`, or the
+ /// key's own `expired_ts` for entries from `old_verify_keys`.
+ expires_at: MilliSecondsSinceUnixEpoch,
+}
+
+/// Run the `export` subcommand.
+///
+/// Loads the configuration and the database, installs the global services and
+/// applies database migrations. It then creates `args.out_dir`, which must not
+/// already exist, and fills it with `pdus.jsonl`, `server-signing-keys.jsonl`
+/// and `main.json`.
+///
+/// # Errors
+///
+/// Returns the matching [`error::ExportCommand`] variant when a setup step
+/// fails or the out directory already exists, and the `Export` variant when
+/// any export task reported an error or could not be joined.
+pub(crate) async fn run(args: ExportArgs) -> Result<(), error::ExportCommand> {
+    use error::ExportCommand as Error;
+
+    let config = config::load(args.config.config.as_ref()).await?;
+
+    // The database is leaked on purpose so that it can be handed to the
+    // global services and to the spawned tasks below.
+    let db = Box::leak(Box::new(
+        KeyValueDatabase::load_or_create(&config).map_err(Error::Database)?,
+    ));
+
+    Services::new(db, config, None)
+        .map_err(Error::InitializeServices)?
+        .install();
+
+    services().globals.err_if_server_name_changed()?;
+
+    db.apply_migrations().await.map_err(Error::Database)?;
+
+    // Refuse to write into anything that is already there.
+    let already_exists =
+        args.out_dir.try_exists().map_err(Error::CheckExists)?;
+    if already_exists {
+        return Err(Error::Exists);
+    }
+
+    create_dir_all(&args.out_dir).await.map_err(Error::CreateDir)?;
+
+    let mut set = JoinSet::new();
+
+    // PDUs: two blocking readers feed a single writer. Both senders are moved
+    // into the reader closures, so the writer finishes once both are done.
+    let (pdu_tx, pdu_rx) = mpsc::channel(CHANNEL_SIZE);
+    let outlier_tx = pdu_tx.clone();
+    set.spawn_blocking(|| read_pdus(outlier_tx, db.pduid_pdu.clone()));
+    set.spawn_blocking(|| read_pdus(pdu_tx, db.eventid_outlierpdu.clone()));
+    set.spawn(write_lines(pdu_rx, args.out_dir.clone(), "pdus.jsonl"));
+
+    // Server signing keys: one blocking reader feeds one writer.
+    let (key_tx, key_rx) = mpsc::channel(CHANNEL_SIZE);
+    set.spawn_blocking(|| read_server_signing_keys(key_tx, db));
+    set.spawn(write_lines(
+        key_rx,
+        args.out_dir.clone(),
+        "server-signing-keys.jsonl",
+    ));
+
+    set.spawn(make_main(args.out_dir.clone()));
+
+    // Every task returns `true` if it hit an error; a task that cannot be
+    // joined counts as an error as well.
+    let mut errors = false;
+    while let Some(joined) = set.join_next().await {
+        errors |= match joined {
+            Ok(had_errors) => had_errors,
+            Err(error) => {
+                t::error!(%error, "failed to join task");
+                true
+            }
+        };
+    }
+
+    if errors {
+        Err(Error::Export)
+    } else {
+        Ok(())
+    }
+}
+
+/// Write each received string into a file, separated by a newline.
+///
+/// The file is created at `out_dir.join(out_file)` and must not already
+/// exist. The function runs until `rx` yields `None`, then flushes the file.
+///
+/// Returns `true` if any error occurred (errors are logged, not returned) and
+/// `false` otherwise.
+async fn write_lines<P>(
+    mut rx: mpsc::Receiver<String>,
+    out_dir: PathBuf,
+    out_file: P,
+) -> bool
+where
+    P: AsRef<Path>,
+{
+    let path = out_dir.join(out_file);
+
+    let mut file = match OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&path)
+        .await
+    {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, path = %path.display(), "failed to open file");
+            return true;
+        }
+    };
+
+    let mut errors = false;
+
+    while let Some(next) = rx.recv().await {
+        if let Err(error) = file.write_all(next.as_bytes()).await {
+            t::error!(%error, "failed to write line");
+            errors = true;
+        }
+
+        if let Err(error) = file.write_all(b"\n").await {
+            t::error!(%error, "failed to write newline");
+            errors = true;
+        }
+    }
+
+    // Flush before the file is dropped so that the last buffered write has
+    // completed and any error it produced is reported instead of being lost.
+    if let Err(error) = file.flush().await {
+        t::error!(%error, path = %path.display(), "failed to flush file");
+        errors = true;
+    }
+
+    errors
+}
+
+/// Read PDUs into a channel.
+///
+/// Every value in `x_pdu` is parsed as JSON and re-serialized onto a single
+/// line before it is sent over `tx`. Entries that fail to parse or serialize
+/// are logged and skipped. Reading stops as soon as the channel is closed,
+/// because no later send can succeed.
+///
+/// Returns `true` if any error occurred and `false` otherwise.
+// The arguments are taken by value because this runs as a blocking task that
+// has to own them.
+#[allow(clippy::needless_pass_by_value)]
+fn read_pdus(tx: mpsc::Sender<String>, x_pdu: Arc<dyn KvTree>) -> bool {
+    let mut errors = false;
+
+    for (_, v) in x_pdu.iter() {
+        // Deserialize to ensure the data is valid.
+        let s = match serde_json::from_slice::<serde_json::Value>(&v) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize PDU");
+                errors = true;
+                continue;
+            }
+        };
+
+        // Serialize to ensure it's formatted on a single line.
+        let d = match serde_json::to_string(&s) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to serialize PDU");
+                errors = true;
+                continue;
+            }
+        };
+
+        // A send only fails once the receiver is gone, which is permanent, so
+        // stop here instead of decoding and logging every remaining PDU.
+        if let Err(error) = tx.blocking_send(d) {
+            t::error!(%error, "failed to send PDU over channel");
+            errors = true;
+            break;
+        }
+    }
+
+    errors
+}
+
+/// Make the `main.json` file.
+///
+/// Serializes a [`Main`] holding this server's name as pretty-printed JSON
+/// and writes it to `out_dir.join("main.json")`, which must not already
+/// exist.
+///
+/// Returns `true` if any error occurred (errors are logged, not returned) and
+/// `false` otherwise.
+async fn make_main(out_dir: PathBuf) -> bool {
+    let main = Main {
+        server_name: services().globals.server_name().to_owned(),
+    };
+
+    let s = match serde_json::to_string_pretty(&main) {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, "failed to serialize JSON");
+            return true;
+        }
+    };
+
+    let path = out_dir.join("main.json");
+
+    let mut file = match OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&path)
+        .await
+    {
+        Ok(x) => x,
+        Err(error) => {
+            t::error!(%error, path = %path.display(), "failed to open file");
+            return true;
+        }
+    };
+
+    if let Err(error) = file.write_all(s.as_bytes()).await {
+        t::error!(%error, "failed to write file");
+        return true;
+    }
+
+    // Flush before the file is dropped so that the buffered write has
+    // completed and any error it produced is reported instead of being lost.
+    match file.flush().await {
+        Ok(()) => false,
+        Err(error) => {
+            t::error!(%error, path = %path.display(), "failed to flush file");
+            true
+        }
+    }
+}
+
+/// Read server signing keys into a channel.
+///
+/// Sends this server's own keys first, then the keys of every other server
+/// found in `db.server_signingkeys`. Entries whose server name or key data
+/// cannot be decoded are logged and skipped.
+///
+/// Returns `true` if any error occurred, including errors reported while
+/// sending an individual server's keys, and `false` otherwise.
+// `tx` is taken by value because this runs as a blocking task that has to own
+// it.
+#[allow(clippy::needless_pass_by_value)]
+fn read_server_signing_keys(
+    tx: mpsc::Sender<String>,
+    db: &KeyValueDatabase,
+) -> bool {
+    // Handle our own signing keys first.
+    let mut errors = handle_server_signing_keys(
+        &tx,
+        services().globals.server_name(),
+        SigningKeys::load_own_keys(),
+    );
+
+    // Handle other servers' signing keys.
+    for (k, v) in db.server_signingkeys.iter() {
+        let k = match str::from_utf8(&k) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "server name contained invalid UTF-8");
+                errors = true;
+                continue;
+            }
+        };
+
+        let k = match ServerName::parse(k) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize server name");
+                errors = true;
+                continue;
+            }
+        };
+
+        // Not sure what causes this to happen, but ignoring them is *probably*
+        // the right thing to do.
+        if k == services().globals.server_name() {
+            t::debug!("found own signing keys in server_signingkeys, ignoring");
+            continue;
+        }
+
+        let v = match serde_json::from_slice::<SigningKeys>(&v) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to deserialize signing key data");
+                errors = true;
+                continue;
+            }
+        };
+
+        // Keep the result: a failure for another server's keys has to be
+        // reflected in the return value just like one for our own keys.
+        errors |= handle_server_signing_keys(&tx, &k, v);
+    }
+
+    errors
+}
+
+/// Shuffle the server signing key data into the desired format and send it.
+///
+/// Each entry of `signing_keys.verify_keys` and then of
+/// `signing_keys.old_verify_keys` becomes one serialized [`SigningKey`] sent
+/// over `tx`. Entries from `verify_keys` expire at the set's
+/// `valid_until_ts`; entries from `old_verify_keys` expire at their own
+/// `expired_ts`. Keys whose ID cannot be parsed or that fail to serialize are
+/// logged and skipped. Sending stops as soon as the channel is closed.
+///
+/// Returns `true` if any error occurred and `false` otherwise.
+fn handle_server_signing_keys(
+    tx: &mpsc::Sender<String>,
+    server_name: &ServerName,
+    signing_keys: SigningKeys,
+) -> bool {
+    let mut errors = false;
+
+    // Bring both key maps into the same `(id, public key, expiry)` shape so a
+    // single loop handles them, current keys first and old keys after.
+    let valid_until_ts = signing_keys.valid_until_ts;
+    let current_keys = signing_keys
+        .verify_keys
+        .into_iter()
+        .map(|(key_id, key)| (key_id, key.key, valid_until_ts));
+    let old_keys = signing_keys
+        .old_verify_keys
+        .into_iter()
+        .map(|(key_id, key)| (key_id, key.key, key.expired_ts));
+
+    for (key_id, public, expires_at) in current_keys.chain(old_keys) {
+        let key_id = match ServerSigningKeyId::parse(&key_id) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to parse signing key ID");
+                errors = true;
+                continue;
+            }
+        };
+
+        // There are currently bugs that cause keys to be miscategorized, so
+        // claim all keys are old to be on the safe side.
+        let key = SigningKey {
+            server_name,
+            algorithm: key_id.algorithm(),
+            name: key_id.key_name(),
+            public,
+            old: true,
+            expires_at,
+        };
+
+        let d = match serde_json::to_string(&key) {
+            Ok(x) => x,
+            Err(error) => {
+                t::error!(%error, "failed to serialize key");
+                errors = true;
+                continue;
+            }
+        };
+
+        // A send only fails once the receiver is gone, which is permanent, so
+        // there is no point in preparing the remaining keys.
+        if let Err(error) = tx.blocking_send(d) {
+            t::error!(%error, "failed to send key over channel");
+            errors = true;
+            break;
+        }
+    }
+
+    errors
+}
diff --git a/src/error.rs b/src/error.rs
index ce8ebb37..be5e0244 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -48,6 +48,9 @@ pub(crate) enum Main {
#[error(transparent)]
CheckConfigCommand(#[from] CheckConfigCommand),
+
+ #[error(transparent)]
+ ExportCommand(#[from] ExportCommand),
}
/// Errors returned from the `serve` CLI subcommand.
@@ -75,6 +78,37 @@ pub(crate) enum ServeCommand {
ServerNameChanged(#[from] ServerNameChanged),
}
+/// Errors returned from the `export` CLI subcommand.
+// Missing docs are allowed here since that kind of information should be
+// encoded in the error messages themselves anyway.
+#[allow(missing_docs)]
+#[derive(Error, Debug)]
+pub(crate) enum ExportCommand {
+    #[error("failed to load configuration")]
+    Config(#[from] Config),
+
+    #[error("failed to initialize services")]
+    InitializeServices(#[source] crate::utils::error::Error),
+
+    #[error("failed to load or create the database")]
+    Database(#[source] crate::utils::error::Error),
+
+    #[error("`server_name` change check failed")]
+    ServerNameChanged(#[from] ServerNameChanged),
+
+    #[error("failed to check out directory existence")]
+    CheckExists(#[source] std::io::Error),
+
+    #[error("out directory already exists")]
+    Exists,
+
+    #[error("failed to create out directory or leading components")]
+    CreateDir(#[source] std::io::Error),
+
+    #[error("one or more errors occurred during the export")]
+    Export,
+}
+
/// Errors returned from the `check-config` CLI subcommand.
// Missing docs are allowed here since that kind of information should be
// encoded in the error messages themselves anyway.