diff --git a/Cargo.lock b/Cargo.lock index 199a5971..e7421a60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata 0.4.9", @@ -921,6 +921,7 @@ dependencies = [ "axum-extra", "axum-server", "base64 0.22.1", + "bstr", "bytes", "clap", "futures-util", @@ -966,6 +967,7 @@ dependencies = [ "thread_local", "tikv-jemallocator", "tokio", + "tokio-util", "toml", "tower 0.5.2", "tower-http", @@ -1002,6 +1004,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -3410,6 +3418,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 65761015..02f7aee5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,6 +94,7 @@ axum = { version = "0.7.9", default-features = false, features = ["form", "http1 axum-extra = { version = "0.9.5", features = ["typed-header"] } axum-server = { version = "0.7.2", features = ["tls-rustls-no-provider"] } base64 = "0.22.1" +bstr = "1.12.0" bytes = "1.10.1" clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } futures-util = { version = "0.3.31", default-features = false } @@ -135,6 +136,7 @@ thiserror = "2.0.12" thread_local = "1.1.8" 
tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } tokio = { version = "1.44.1", features = ["fs", "macros", "signal", "sync"] } +tokio-util = { version = "0.7.12", features = ["rt"] } toml = "0.8.20" tower = { version = "0.5.2", features = ["util"] } tower-http = { version = "0.6.2", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] } diff --git a/src/integrity.rs b/src/integrity.rs new file mode 100644 index 00000000..76ae7c37 --- /dev/null +++ b/src/integrity.rs @@ -0,0 +1,286 @@ +use std::string::FromUtf8Error; + +use bstr::BString; +use serde::Deserialize; +use serde_json::value::RawValue; +use thiserror::Error; + +use crate::{ + database::{abstraction::KvTree, KeyValueDatabase}, + utils::error::Result, +}; + +pub(crate) trait CheckIntegrity: Sync { + fn check_integrity( + &'static self, + ) -> Box>>; +} + +#[derive(Debug, Error)] +pub(crate) enum IntegrityError { + #[error(transparent)] + Symmetry(#[from] SymmetryError), + + #[error("key {x_key:?} in {x_name} is not a key in {y_name}")] + BadForeignKey { + x_name: &'static str, + x_key: Vec, + + y_name: &'static str, + }, + + #[error(transparent)] + InvalidAccountDataEvent(#[from] InvalidAccountDataEvent), +} + +#[derive(Debug, Error)] +pub(crate) enum SymmetryError { + #[error( + "missing key {y_key:?} in {y_name} (referenced by key {x_key:?} in \ + {x_name})" + )] + MissingKey { + x_name: &'static str, + x_key: Vec, + + y_name: &'static str, + y_key: Vec, + }, + + #[error( + "key {x_key:?} in {x_name} points to {y_key:?} in {y_name}, but that \ + points to {y_value:?}" + )] + Incoherent { + x_name: &'static str, + x_key: Vec, + + y_name: &'static str, + y_key: Vec, + y_value: Vec, + }, +} + +#[derive(Clone, Copy)] +struct NamedTree<'db> { + name: &'static str, + tree: &'db dyn KvTree, +} + +macro_rules! 
tree { + ($db:expr, $tree:ident) => { + NamedTree { + name: stringify!($tree), + tree: &*$db.$tree, + } + }; +} + +fn check_symmetric<'db>( + x: NamedTree<'db>, + y: NamedTree<'db>, +) -> impl Iterator> + 'db { + x.tree.iter().filter_map(move |(x_key, y_key)| match y.tree.get(&y_key) { + Err(e) => Some(Err(e)), + Ok(Some(y_value)) if y_value == x_key => None, + Ok(Some(y_value)) => Some(Ok(SymmetryError::Incoherent { + x_name: x.name, + x_key, + y_name: y.name, + y_key, + y_value, + } + .into())), + Ok(None) => Some(Ok(SymmetryError::MissingKey { + x_name: x.name, + x_key, + y_name: y.name, + y_key, + } + .into())), + }) +} + +fn check_symmetric_both<'db>( + x: NamedTree<'db>, + y: NamedTree<'db>, +) -> impl Iterator> + 'db { + std::iter::empty().chain(check_symmetric(x, y)).chain(check_symmetric(y, x)) +} + +fn check_foreign_key<'db>( + primary: NamedTree<'db>, + subset: NamedTree<'db>, +) -> impl Iterator> + 'db { + subset.tree.iter().filter_map(move |(x_key, _)| { + match primary.tree.get(&x_key) { + Err(e) => Some(Err(e)), + Ok(Some(_)) => None, + Ok(None) => Some(Ok(IntegrityError::BadForeignKey { + x_name: subset.name, + x_key, + y_name: primary.name, + })), + } + }) +} + +fn short_ids( + database: &'static KeyValueDatabase, +) -> impl Iterator> { + let eventid_symmetry = check_symmetric_both( + tree!(database, shorteventid_eventid), + tree!(database, eventid_shorteventid), + ); + + let eventid_foreign = [ + tree!(database, shorteventid_shortstatehash), + tree!(database, shorteventid_authchain), + ] + .into_iter() + .flat_map(|subset| { + check_foreign_key(tree!(database, shorteventid_eventid), subset) + }); + + let statekey_symmetry = check_symmetric_both( + tree!(database, shortstatekey_statekey), + tree!(database, statekey_shortstatekey), + ); + + eventid_symmetry.chain(eventid_foreign).chain(statekey_symmetry) +} + +#[derive(Debug, Error)] +pub(crate) enum InvalidAccountDataEvent { + #[error("missing event type in key {key:?}")] + MissingType { + key: 
BString, + }, + #[error("invalid event type string in key {key:?}")] + InvalidType { + key: BString, + #[source] + err: FromUtf8Error, + }, + + #[error("missing count in key {key:?}")] + MissingCount { + key: BString, + }, + + #[error("missing user ID in key {key:?}")] + MissingUserId { + key: BString, + }, + #[error("invalid user ID string in key {key:?}")] + InvalidUserId { + key: BString, + #[source] + err: FromUtf8Error, + }, + + #[error("invalid event data for {key:?}: {value:?}")] + InvalidEventData { + key: BString, + value: BString, + #[source] + err: serde_json::Error, + }, + + #[error( + "mismatch between event type in column key ({key_type}) and \ + serialised field ({serialised_type}) for user {user_id}" + )] + MismatchedType { + user_id: String, + key_type: String, + serialised_type: String, + }, +} + +fn account_data( + database: &'static KeyValueDatabase, +) -> impl Iterator> { + let tree = tree!(database, roomuserdataid_accountdata); + tree.tree + .iter() + .map(|(key, value)| { + #[allow(dead_code)] + #[derive(Deserialize)] + struct ExtractEventFields<'a> { + #[serde(rename = "type")] + event_type: &'a str, + content: &'a RawValue, + } + + let key = BString::from(key); + let mut key_parts = key.rsplit(|&b| b == 0xFF); + + let event_type = String::from_utf8( + key_parts + .next() + .ok_or_else(|| InvalidAccountDataEvent::MissingType { + key: key.clone(), + })? + .to_vec(), + ) + .map_err(|err| { + InvalidAccountDataEvent::InvalidType { + key: key.clone(), + err, + } + })?; + + let Some(_) = key_parts.next() else { + return Err(InvalidAccountDataEvent::MissingCount { + key: key.clone(), + }); + }; + + let user_id = String::from_utf8( + key_parts + .next() + .ok_or_else(|| InvalidAccountDataEvent::MissingUserId { + key: key.clone(), + })? 
+ .to_vec(), + ) + .map_err(|err| { + InvalidAccountDataEvent::InvalidUserId { + key: key.clone(), + err, + } + })?; + + let extract: ExtractEventFields<'_> = + serde_json::from_slice(&value).map_err(|err| { + InvalidAccountDataEvent::InvalidEventData { + key: key.clone(), + value: value.clone().into(), + err, + } + })?; + + if extract.event_type != event_type { + return Err(InvalidAccountDataEvent::MismatchedType { + user_id, + key_type: event_type, + serialised_type: extract.event_type.to_owned(), + }); + } + + Ok(()) + }) + .filter_map(|r| match r { + Ok(()) => None, + Err(e) => Some(Ok(IntegrityError::InvalidAccountDataEvent(e))), + }) +} + +impl CheckIntegrity for KeyValueDatabase { + fn check_integrity( + &'static self, + ) -> Box>> { + Box::new(short_ids(self).chain(account_data(self))) + } +} diff --git a/src/main.rs b/src/main.rs index 8f4268a1..fa78c65c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ mod cli; mod config; mod database; mod error; +mod integrity; mod observability; mod service; mod utils; diff --git a/src/service.rs b/src/service.rs index 85e7a302..73a2658a 100644 --- a/src/service.rs +++ b/src/service.rs @@ -52,6 +52,7 @@ impl Services { + key_backups::Data + media::Data + sending::Data + + crate::integrity::CheckIntegrity + 'static, >( db: &'static D, @@ -115,7 +116,7 @@ impl Services { uiaa: uiaa::Service::new(db), users: users::Service::new(db), account_data: db, - admin: admin::Service::new(), + admin: admin::Service::new(db), key_backups: db, media: media::Service { db, diff --git a/src/service/admin.rs b/src/service/admin.rs index 5e3ee770..af161955 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -28,11 +28,14 @@ use ruma::{ }; use serde_json::value::to_raw_value; use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio_util::task::AbortOnDropHandle; use tracing::warn; use super::pdu::PduBuilder; use crate::{ api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, + error::DisplayWithSources, + 
+    integrity::CheckIntegrity,
     services,
     utils::{self, dbg_truncate_str, room_version::RoomVersion},
     Error, PduEvent, Result,
@@ -209,6 +212,9 @@ enum AdminCommand {
         #[command(subcommand)]
         cmd: TracingFilterCommand,
     },
+
+    /// Check database integrity
+    CheckIntegrity,
 }
 
 #[derive(Debug, Subcommand)]
@@ -249,6 +255,7 @@ pub(crate) enum AdminRoomEvent {
 }
 
 pub(crate) struct Service {
+    pub(crate) db: &'static dyn CheckIntegrity,
     pub(crate) sender: mpsc::UnboundedSender,
     receiver: Mutex>,
 }
@@ -261,9 +268,10 @@ enum TracingBackend {
 }
 
 impl Service {
-    pub(crate) fn new() -> Arc {
+    pub(crate) fn new(db: &'static dyn CheckIntegrity) -> Arc {
         let (sender, receiver) = mpsc::unbounded_channel();
         Arc::new(Self {
+            db,
             sender,
             receiver: Mutex::new(receiver),
         })
@@ -1247,6 +1255,80 @@ impl Service {
                     "Filter reloaded",
                 ));
             }
+            AdminCommand::CheckIntegrity => {
+                // Upper bound on the number of findings listed in one reply.
+                const MAX_REPORTED: usize = 50;
+
+                let (tx, mut rx) = mpsc::channel(16);
+
+                // Run the check on the blocking pool; findings are streamed
+                // back over the bounded channel.
+                let task = {
+                    let db = self.db;
+                    tokio::task::spawn_blocking(move || {
+                        for err in db.check_integrity() {
+                            let Ok(()) = tx.blocking_send(err?) else {
+                                // channel has been closed
+                                return Ok::<_, Error>(());
+                            };
+                        }
+                        Ok(())
+                    })
+                };
+                let task = AbortOnDropHandle::new(task);
+
+                // Keep at most MAX_REPORTED findings. `recv` returning `None`
+                // means the sender was dropped and the buffer is drained;
+                // one finding beyond the limit proves the list is truncated.
+                let mut errors = Vec::new();
+                let truncated = loop {
+                    match rx.recv().await {
+                        None => break false,
+                        Some(_) if errors.len() >= MAX_REPORTED => break true,
+                        Some(error) => errors.push(error),
+                    }
+                };
+
+                let mut message = String::new();
+                for error in &errors {
+                    writeln!(
+                        message,
+                        "- {}",
+                        DisplayWithSources {
+                            error,
+                            infix: "\n - caused by: "
+                        }
+                    )
+                    .unwrap();
+                }
+
+                if truncated {
+                    writeln!(message, "...more errors not shown").unwrap();
+                } else {
+                    // The channel is closed, so the closure has returned and
+                    // this await resolves promptly; `is_finished()` could
+                    // still report false at this point.
+                    match task.await {
+                        Ok(Ok(())) => {
+                            if errors.is_empty() {
+                                writeln!(message, "No errors were found")
+                                    .unwrap();
+                            }
+                        }
+                        Ok(Err(e)) => return Err(e),
+                        Err(e) => {
+                            writeln!(
+                                message,
+                                "An error occurred in the validity checking \
+                                 task, results may be incorrect: {e}"
+                            )
+                            .unwrap();
+                        }
+                    };
+                }
+
+                RoomMessageEventContent::text_plain(message)
+            }
         };
 
         Ok(reply_message_content)