Merge branch 'lambda/check-integrity' into 'main'

Draft: Add integrity check command

See merge request matrix/grapevine!110
This commit is contained in:
Lambda 2025-05-05 00:31:27 +00:00
commit 5952251e2f
6 changed files with 374 additions and 4 deletions

14
Cargo.lock generated
View file

@ -333,9 +333,9 @@ dependencies = [
[[package]] [[package]]
name = "bstr" name = "bstr"
version = "1.11.3" version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4"
dependencies = [ dependencies = [
"memchr", "memchr",
"regex-automata 0.4.9", "regex-automata 0.4.9",
@ -921,6 +921,7 @@ dependencies = [
"axum-extra", "axum-extra",
"axum-server", "axum-server",
"base64 0.22.1", "base64 0.22.1",
"bstr",
"bytes", "bytes",
"clap", "clap",
"futures-util", "futures-util",
@ -966,6 +967,7 @@ dependencies = [
"thread_local", "thread_local",
"tikv-jemallocator", "tikv-jemallocator",
"tokio", "tokio",
"tokio-util",
"toml", "toml",
"tower 0.5.2", "tower 0.5.2",
"tower-http", "tower-http",
@ -1002,6 +1004,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.2" version = "0.15.2"
@ -3410,6 +3418,8 @@ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]

View file

@ -94,6 +94,7 @@ axum = { version = "0.7.9", default-features = false, features = ["form", "http1
axum-extra = { version = "0.9.5", features = ["typed-header"] } axum-extra = { version = "0.9.5", features = ["typed-header"] }
axum-server = { version = "0.7.2", features = ["tls-rustls-no-provider"] } axum-server = { version = "0.7.2", features = ["tls-rustls-no-provider"] }
base64 = "0.22.1" base64 = "0.22.1"
bstr = "1.12.0"
bytes = "1.10.1" bytes = "1.10.1"
clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } clap = { version = "4.5.34", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] }
futures-util = { version = "0.3.31", default-features = false } futures-util = { version = "0.3.31", default-features = false }
@ -135,6 +136,7 @@ thiserror = "2.0.12"
thread_local = "1.1.8" thread_local = "1.1.8"
tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } tikv-jemallocator = { version = "0.6.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }
tokio = { version = "1.44.1", features = ["fs", "macros", "signal", "sync"] } tokio = { version = "1.44.1", features = ["fs", "macros", "signal", "sync"] }
tokio-util = { version = "0.7.12", features = ["rt"] }
toml = "0.8.20" toml = "0.8.20"
tower = { version = "0.5.2", features = ["util"] } tower = { version = "0.5.2", features = ["util"] }
tower-http = { version = "0.6.2", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] } tower-http = { version = "0.6.2", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }

286
src/integrity.rs Normal file
View file

@ -0,0 +1,286 @@
use std::string::FromUtf8Error;
use bstr::BString;
use serde::Deserialize;
use serde_json::value::RawValue;
use thiserror::Error;
use crate::{
database::{abstraction::KvTree, KeyValueDatabase},
utils::error::Result,
};
/// Implemented by database backends that can scan their own contents for
/// inconsistencies.
pub(crate) trait CheckIntegrity: Sync {
/// Returns an iterator over the results of the integrity checks.
///
/// An `Ok` item carries one [`IntegrityError`], i.e. one inconsistency
/// that was found in the stored data. An `Err` item means that a check
/// could not be carried out; in this file those originate from failed
/// tree lookups.
fn check_integrity(
&'static self,
) -> Box<dyn Iterator<Item = Result<IntegrityError>>>;
}
/// A single inconsistency reported by an integrity check.
#[derive(Debug, Error)]
pub(crate) enum IntegrityError {
/// An entry of one tree is not mirrored correctly by another tree; see
/// [`SymmetryError`] for the individual cases.
#[error(transparent)]
Symmetry(#[from] SymmetryError),
/// A key of one tree has no entry under the same key in another tree.
#[error("key {x_key:?} in {x_name} is not a key in {y_name}")]
BadForeignKey {
/// Name of the tree the key was taken from.
x_name: &'static str,
/// The key that could not be found in `y_name`.
x_key: Vec<u8>,
/// Name of the tree that was expected to contain `x_key`.
y_name: &'static str,
},
/// An account data entry failed validation; see
/// [`InvalidAccountDataEvent`] for the individual cases.
#[error(transparent)]
InvalidAccountDataEvent(#[from] InvalidAccountDataEvent),
}
/// Ways in which an entry of tree `x` can fail to be mirrored by tree `y`.
///
/// In both variants `y_key` is the value that `x` stores under `x_key`.
#[derive(Debug, Error)]
pub(crate) enum SymmetryError {
/// `y` has no entry under `y_key` at all.
#[error(
"missing key {y_key:?} in {y_name} (referenced by key {x_key:?} in \
{x_name})"
)]
MissingKey {
/// Name of the tree holding the referencing entry.
x_name: &'static str,
/// Key of the referencing entry in `x_name`.
x_key: Vec<u8>,
/// Name of the tree that was expected to contain `y_key`.
y_name: &'static str,
/// The key that is missing from `y_name`.
y_key: Vec<u8>,
},
/// `y` has an entry under `y_key`, but its value is not `x_key`.
#[error(
"key {x_key:?} in {x_name} points to {y_key:?} in {y_name}, but that \
points to {y_value:?}"
)]
Incoherent {
/// Name of the tree holding the referencing entry.
x_name: &'static str,
/// Key of the referencing entry in `x_name`.
x_key: Vec<u8>,
/// Name of the tree that was looked up.
y_name: &'static str,
/// The key that was looked up in `y_name`.
y_key: Vec<u8>,
/// The value actually stored under `y_key`, which differs from `x_key`.
y_value: Vec<u8>,
},
}
/// A borrowed tree together with the name that is used for it in the errors
/// reported by the checks in this file.
#[derive(Clone, Copy)]
struct NamedTree<'db> {
/// Name reported in error messages.
name: &'static str,
/// The tree that is iterated over or queried.
tree: &'db dyn KvTree,
}
/// Builds a [`NamedTree`] from the field `$tree` of `$db`.
///
/// The field's identifier is stringified to provide the name, and the field
/// itself is dereferenced and reborrowed to provide the tree.
macro_rules! tree {
($db:expr, $tree:ident) => {
NamedTree {
name: stringify!($tree),
tree: &*$db.$tree,
}
};
}
fn check_symmetric<'db>(
x: NamedTree<'db>,
y: NamedTree<'db>,
) -> impl Iterator<Item = Result<IntegrityError>> + 'db {
x.tree.iter().filter_map(move |(x_key, y_key)| match y.tree.get(&y_key) {
Err(e) => Some(Err(e)),
Ok(Some(y_value)) if y_value == x_key => None,
Ok(Some(y_value)) => Some(Ok(SymmetryError::Incoherent {
x_name: x.name,
x_key,
y_name: y.name,
y_key,
y_value,
}
.into())),
Ok(None) => Some(Ok(SymmetryError::MissingKey {
x_name: x.name,
x_key,
y_name: y.name,
y_key,
}
.into())),
})
}
/// Runs [`check_symmetric`] in both directions: first from `x` to `y`, then
/// from `y` to `x`. The results of the first direction are yielded before
/// those of the second.
fn check_symmetric_both<'db>(
    x: NamedTree<'db>,
    y: NamedTree<'db>,
) -> impl Iterator<Item = Result<IntegrityError>> + 'db {
    let forwards = check_symmetric(x, y);
    let backwards = check_symmetric(y, x);
    forwards.chain(backwards)
}
/// Checks that every key of `subset` is also a key of `primary`.
///
/// Values are ignored on both sides. Each key of `subset` that has no entry in
/// `primary` is yielded as an [`IntegrityError::BadForeignKey`]. A failed
/// lookup is passed through as an `Err` item.
fn check_foreign_key<'db>(
    primary: NamedTree<'db>,
    subset: NamedTree<'db>,
) -> impl Iterator<Item = Result<IntegrityError>> + 'db {
    subset.tree.iter().filter_map(move |(x_key, _)| {
        let present = match primary.tree.get(&x_key) {
            Ok(entry) => entry.is_some(),
            Err(e) => return Some(Err(e)),
        };
        if present {
            return None;
        }

        Some(Ok(IntegrityError::BadForeignKey {
            x_name: subset.name,
            x_key,
            y_name: primary.name,
        }))
    })
}
/// Checks the trees that map between short IDs and the values they stand for.
///
/// The checks run in this order:
///
/// 1. `shorteventid_eventid` and `eventid_shorteventid` must mirror each
///    other.
/// 2. Every key of `shorteventid_shortstatehash` and of
///    `shorteventid_authchain` must be a key of `shorteventid_eventid`.
/// 3. `shortstatekey_statekey` and `statekey_shortstatekey` must mirror each
///    other.
fn short_ids(
    database: &'static KeyValueDatabase,
) -> impl Iterator<Item = Result<IntegrityError>> {
    // Used both for the symmetry check and as the primary tree of the foreign
    // key checks.
    let event_ids = tree!(database, shorteventid_eventid);

    check_symmetric_both(event_ids, tree!(database, eventid_shorteventid))
        .chain(
            [
                tree!(database, shorteventid_shortstatehash),
                tree!(database, shorteventid_authchain),
            ]
            .into_iter()
            .flat_map(move |subset| check_foreign_key(event_ids, subset)),
        )
        .chain(check_symmetric_both(
            tree!(database, shortstatekey_statekey),
            tree!(database, statekey_shortstatekey),
        ))
}
/// Problems that `account_data` can find in a single entry of the
/// `roomuserdataid_accountdata` tree.
///
/// The key of an entry is read as a sequence of `0xFF`-separated segments,
/// counted from the end: event type, count, user ID.
#[derive(Debug, Error)]
pub(crate) enum InvalidAccountDataEvent {
/// No event type segment could be read from the key.
#[error("missing event type in key {key:?}")]
MissingType {
/// The complete key of the offending entry.
key: BString,
},
/// The event type segment of the key is not valid UTF-8.
#[error("invalid event type string in key {key:?}")]
InvalidType {
/// The complete key of the offending entry.
key: BString,
/// The UTF-8 decoding error.
#[source]
err: FromUtf8Error,
},
/// The key has no segment in front of the event type.
#[error("missing count in key {key:?}")]
MissingCount {
/// The complete key of the offending entry.
key: BString,
},
/// The key has no segment in front of the count.
#[error("missing user ID in key {key:?}")]
MissingUserId {
/// The complete key of the offending entry.
key: BString,
},
/// The user ID segment of the key is not valid UTF-8.
#[error("invalid user ID string in key {key:?}")]
InvalidUserId {
/// The complete key of the offending entry.
key: BString,
/// The UTF-8 decoding error.
#[source]
err: FromUtf8Error,
},
/// The value could not be deserialised as JSON with the expected `type`
/// and `content` fields.
#[error("invalid event data for {key:?}: {value:?}")]
InvalidEventData {
/// The complete key of the offending entry.
key: BString,
/// The raw value that failed to deserialise.
value: BString,
/// The deserialisation error.
#[source]
err: serde_json::Error,
},
/// The event type in the key differs from the `type` field of the
/// serialised event.
#[error(
"mismatch between event type in column key ({key_type}) and \
serialised field ({serialised_type}) for user {user_id}"
)]
MismatchedType {
/// User ID taken from the key.
user_id: String,
/// Event type taken from the key.
key_type: String,
/// Event type taken from the serialised event.
serialised_type: String,
},
}
fn account_data(
database: &'static KeyValueDatabase,
) -> impl Iterator<Item = Result<IntegrityError>> {
let tree = tree!(database, roomuserdataid_accountdata);
tree.tree
.iter()
.map(|(key, value)| {
#[allow(dead_code)]
#[derive(Deserialize)]
struct ExtractEventFields<'a> {
#[serde(rename = "type")]
event_type: &'a str,
content: &'a RawValue,
}
let key = BString::from(key);
let mut key_parts = key.rsplit(|&b| b == 0xFF);
let event_type = String::from_utf8(
key_parts
.next()
.ok_or_else(|| InvalidAccountDataEvent::MissingType {
key: key.clone(),
})?
.to_vec(),
)
.map_err(|err| {
InvalidAccountDataEvent::InvalidType {
key: key.clone(),
err,
}
})?;
let Some(_) = key_parts.next() else {
return Err(InvalidAccountDataEvent::MissingCount {
key: key.clone(),
});
};
let user_id = String::from_utf8(
key_parts
.next()
.ok_or_else(|| InvalidAccountDataEvent::MissingUserId {
key: key.clone(),
})?
.to_vec(),
)
.map_err(|err| {
InvalidAccountDataEvent::InvalidUserId {
key: key.clone(),
err,
}
})?;
let extract: ExtractEventFields<'_> =
serde_json::from_slice(&value).map_err(|err| {
InvalidAccountDataEvent::InvalidEventData {
key: key.clone(),
value: value.clone().into(),
err,
}
})?;
if extract.event_type != event_type {
return Err(InvalidAccountDataEvent::MismatchedType {
user_id,
key_type: event_type,
serialised_type: extract.event_type.to_owned(),
});
}
Ok(())
})
.filter_map(|r| match r {
Ok(()) => None,
Err(e) => Some(Ok(IntegrityError::InvalidAccountDataEvent(e))),
})
}
impl CheckIntegrity for KeyValueDatabase {
    /// Runs the short ID checks followed by the account data checks and
    /// returns their combined results as one boxed iterator.
    fn check_integrity(
        &'static self,
    ) -> Box<dyn Iterator<Item = Result<IntegrityError>>> {
        let short_id_results = short_ids(self);
        let account_data_results = account_data(self);
        Box::new(short_id_results.chain(account_data_results))
    }
}

View file

@ -14,6 +14,7 @@ mod cli;
mod config; mod config;
mod database; mod database;
mod error; mod error;
mod integrity;
mod observability; mod observability;
mod service; mod service;
mod utils; mod utils;

View file

@ -52,6 +52,7 @@ impl Services {
+ key_backups::Data + key_backups::Data
+ media::Data + media::Data
+ sending::Data + sending::Data
+ crate::integrity::CheckIntegrity
+ 'static, + 'static,
>( >(
db: &'static D, db: &'static D,
@ -115,7 +116,7 @@ impl Services {
uiaa: uiaa::Service::new(db), uiaa: uiaa::Service::new(db),
users: users::Service::new(db), users: users::Service::new(db),
account_data: db, account_data: db,
admin: admin::Service::new(), admin: admin::Service::new(db),
key_backups: db, key_backups: db,
media: media::Service { media: media::Service {
db, db,

View file

@ -28,11 +28,14 @@ use ruma::{
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::task::AbortOnDropHandle;
use tracing::warn; use tracing::warn;
use super::pdu::PduBuilder; use super::pdu::PduBuilder;
use crate::{ use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
error::DisplayWithSources,
integrity::CheckIntegrity,
services, services,
utils::{self, dbg_truncate_str, room_version::RoomVersion}, utils::{self, dbg_truncate_str, room_version::RoomVersion},
Error, PduEvent, Result, Error, PduEvent, Result,
@ -209,6 +212,9 @@ enum AdminCommand {
#[command(subcommand)] #[command(subcommand)]
cmd: TracingFilterCommand, cmd: TracingFilterCommand,
}, },
/// Check database integrity
CheckIntegrity,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -249,6 +255,7 @@ pub(crate) enum AdminRoomEvent {
} }
pub(crate) struct Service { pub(crate) struct Service {
pub(crate) db: &'static dyn CheckIntegrity,
pub(crate) sender: mpsc::UnboundedSender<AdminRoomEvent>, pub(crate) sender: mpsc::UnboundedSender<AdminRoomEvent>,
receiver: Mutex<mpsc::UnboundedReceiver<AdminRoomEvent>>, receiver: Mutex<mpsc::UnboundedReceiver<AdminRoomEvent>>,
} }
@ -261,9 +268,10 @@ enum TracingBackend {
} }
impl Service { impl Service {
pub(crate) fn new() -> Arc<Self> { pub(crate) fn new(db: &'static dyn CheckIntegrity) -> Arc<Self> {
let (sender, receiver) = mpsc::unbounded_channel(); let (sender, receiver) = mpsc::unbounded_channel();
Arc::new(Self { Arc::new(Self {
db,
sender, sender,
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
}) })
@ -1247,6 +1255,68 @@ impl Service {
"Filter reloaded", "Filter reloaded",
)); ));
} }
AdminCommand::CheckIntegrity => {
let (tx, mut rx) = mpsc::channel(16);
let task = {
let db = self.db;
tokio::task::spawn_blocking(move || {
for err in db.check_integrity() {
let Ok(()) = tx.blocking_send(err?) else {
// channel has been closed
return Ok::<_, Error>(());
};
}
Ok(())
})
};
let task = AbortOnDropHandle::new(task);
let mut errors = Vec::new();
for _ in 0..50 {
let Some(error) = rx.recv().await else {
break;
};
errors.push(error);
}
let mut message = String::new();
for error in &errors {
writeln!(
message,
"- {}",
DisplayWithSources {
error,
infix: "\n - caused by: "
}
)
.unwrap();
}
if task.is_finished() {
match task.await {
Ok(Ok(())) => {
if errors.is_empty() {
writeln!(message, "No errors were found")
.unwrap();
}
}
Ok(Err(e)) => return Err(e),
Err(e) => {
writeln!(
message,
"An error occured in the validity checking \
task, results may be incorrect: {e}"
)
.unwrap();
}
};
} else {
writeln!(message, "...more errors not shown").unwrap();
}
RoomMessageEventContent::text_plain(message)
}
}; };
Ok(reply_message_content) Ok(reply_message_content)