From 279c6472c5c32103ea132243e8b7f128f5472717 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Sun, 8 Sep 2024 16:31:51 -0700 Subject: [PATCH] split some logic out of KeyValueDatabase::load_or_create This method did _a lot_ of things at the same time. In order to use `KeyValueDatabase` for the migrate-db command, we need to be able to open a db without attempting to apply all the migrations and without spawning a bunch of unrelated background tasks. The state after this refactor is still not great, but it's enough to do a migration tool. --- src/cli/serve.rs | 10 ++- src/database.rs | 142 +++++++++++++---------------------------- src/service/globals.rs | 78 ++++++++++++++++++++-- 3 files changed, 127 insertions(+), 103 deletions(-) diff --git a/src/cli/serve.rs b/src/cli/serve.rs index a71ebf1f..2dcdb765 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -70,9 +70,15 @@ pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> { .expect("should be able to increase the soft limit to the hard limit"); info!("Loading database"); - KeyValueDatabase::load_or_create(config, reload_handles) - .await + let db = KeyValueDatabase::load_or_create(config, reload_handles) .map_err(Error::DatabaseError)?; + db.apply_migrations().await.map_err(Error::DatabaseError)?; + + info!("Starting background tasks"); + services().admin.start_handler(); + services().sending.start_handler(); + KeyValueDatabase::start_cleanup_task(); + services().globals.set_emergency_access(); info!("Starting server"); run_server().await?; diff --git a/src/database.rs b/src/database.rs index 5e145acd..bd433dd7 100644 --- a/src/database.rs +++ b/src/database.rs @@ -14,9 +14,7 @@ use abstraction::{KeyValueDatabaseEngine, KvTree}; use lru_cache::LruCache; use ruma::{ events::{ - push_rules::{PushRulesEvent, PushRulesEventContent}, - room::message::RoomMessageEventContent, - GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, + push_rules::PushRulesEvent, GlobalAccountDataEventType, StateEventType, }, push::Ruleset, CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, @@ -314,16 +312,17 @@ impl KeyValueDatabase { Ok(()) } - /// Load an existing database or create a new one. + /// Load an existing database or create a new one, and initialize all + /// services with the loaded database. #[cfg_attr( not(any(feature = "rocksdb", feature = "sqlite")), allow(unreachable_code) )] #[allow(clippy::too_many_lines)] - pub(crate) async fn load_or_create( + pub(crate) fn load_or_create( config: Config, reload_handles: FilterReloadHandles, - ) -> Result<()> { + ) -> Result<&'static KeyValueDatabase> { Self::check_db_setup(&config)?; if !Path::new(&config.database.path).exists() { @@ -565,13 +564,23 @@ impl KeyValueDatabase { } } + Ok(db) + } + + /// Ensure that the database is at the current version, applying migrations + /// if necessary. + /// + /// If it is not possible to migrate the database to the current version, + /// returns an error. + #[allow(clippy::too_many_lines)] + pub(crate) async fn apply_migrations(&self) -> Result<()> { // If the database has any data, perform data migrations before starting let latest_database_version = 13; if services().users.count()? > 0 { // MIGRATIONS migration(1, || { - for (roomserverid, _) in db.roomserverids.iter() { + for (roomserverid, _) in self.roomserverids.iter() { let mut parts = roomserverid.split(|&b| b == 0xFF); let room_id = parts.next().expect("split always returns one element"); @@ -583,7 +592,7 @@ impl KeyValueDatabase { serverroomid.push(0xFF); serverroomid.extend_from_slice(room_id); - db.serverroomids.insert(&serverroomid, &[])?; + self.serverroomids.insert(&serverroomid, &[])?; } Ok(()) })?; @@ -591,7 +600,7 @@ impl KeyValueDatabase { migration(2, || { // We accidentally inserted hashed versions of "" into the db // instead of just "" - for (userid, password) in db.userid_password.iter() { + for (userid, password) in self.userid_password.iter() { let password = utils::string_from_bytes(&password); let empty_hashed_password = password @@ -600,7 +609,7 @@ impl KeyValueDatabase { }); if empty_hashed_password { - db.userid_password.insert(&userid, b"")?; + self.userid_password.insert(&userid, b"")?; } } Ok(()) @@ -608,7 +617,7 @@ impl KeyValueDatabase { migration(3, || { // Move media to filesystem - for (key, content) in db.mediaid_file.iter() { + for (key, content) in self.mediaid_file.iter() { let key = MediaFileKey::new(key); if content.is_empty() { continue; @@ -617,7 +626,7 @@ impl KeyValueDatabase { let path = services().globals.get_media_file(&key); let mut file = fs::File::create(path)?; file.write_all(&content)?; - db.mediaid_file.insert(key.as_bytes(), &[])?; + self.mediaid_file.insert(key.as_bytes(), &[])?; } Ok(()) })?; @@ -650,7 +659,8 @@ impl KeyValueDatabase { migration(5, || { // Upgrade user data store - for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() + for (roomuserdataid, _) in + self.roomuserdataid_accountdata.iter() { let mut parts = roomuserdataid.split(|&b| b == 0xFF); let room_id = parts.next().unwrap(); @@ -664,7 +674,7 @@ impl KeyValueDatabase { key.push(0xFF); key.extend_from_slice(event_type); - db.roomusertype_roomuserdataid + self.roomusertype_roomuserdataid .insert(&key, &roomuserdataid)?; } Ok(()) @@ -672,7 +682,7 @@ impl KeyValueDatabase { migration(6, || { // Set room member count - for (roomid, _) in db.roomid_shortstatehash.iter() { + for (roomid, _) in self.roomid_shortstatehash.iter() { let string = utils::string_from_bytes(&roomid).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); services() @@ -750,7 +760,7 @@ impl KeyValueDatabase { }; for (k, seventid) in - db.db.open_tree("stateid_shorteventid")?.iter() + self.db.open_tree("stateid_shorteventid")?.iter() { let sstatehash = ShortStateHash::new( utils::u64_from_bytes(&k[0..size_of::()]) @@ -776,7 +786,7 @@ impl KeyValueDatabase { current_state = HashSet::new(); current_sstatehash = Some(sstatehash); - let event_id = db + let event_id = self .shorteventid_eventid .get(&seventid) .unwrap() @@ -820,14 +830,14 @@ impl KeyValueDatabase { migration(8, || { // Generate short room ids for all rooms - for (room_id, _) in db.roomid_shortstatehash.iter() { + for (room_id, _) in self.roomid_shortstatehash.iter() { let shortroomid = services().globals.next_count()?.to_be_bytes(); - db.roomid_shortroomid.insert(&room_id, &shortroomid)?; + self.roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } // Update pduids db layout - let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { + let mut batch = self.pduid_pdu.iter().filter_map(|(key, v)| { if !key.starts_with(b"!") { return None; } @@ -835,7 +845,7 @@ impl KeyValueDatabase { let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); - let short_room_id = db + let short_room_id = self .roomid_shortroomid .get(room_id) .unwrap() @@ -847,10 +857,10 @@ impl KeyValueDatabase { Some((new_key, v)) }); - db.pduid_pdu.insert_batch(&mut batch)?; + self.pduid_pdu.insert_batch(&mut batch)?; let mut batch2 = - db.eventid_pduid.iter().filter_map(|(k, value)| { + self.eventid_pduid.iter().filter_map(|(k, value)| { if !value.starts_with(b"!") { return None; } @@ -858,7 +868,7 @@ impl KeyValueDatabase { let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); - let short_room_id = db + let short_room_id = self .roomid_shortroomid .get(room_id) .unwrap() @@ -870,13 +880,13 @@ impl KeyValueDatabase { Some((k, new_value)) }); - db.eventid_pduid.insert_batch(&mut batch2)?; + self.eventid_pduid.insert_batch(&mut batch2)?; Ok(()) })?; migration(9, || { // Update tokenids db layout - let mut iter = db + let mut iter = self .tokenids .iter() .filter_map(|(key, _)| { @@ -889,7 +899,7 @@ impl KeyValueDatabase { let _pdu_id_room = parts.next().unwrap(); let pdu_id_count = parts.next().unwrap(); - let short_room_id = db + let short_room_id = self .roomid_shortroomid .get(room_id) .unwrap() @@ -903,20 +913,21 @@ impl KeyValueDatabase { .peekable(); while iter.peek().is_some() { - db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; + self.tokenids + .insert_batch(&mut iter.by_ref().take(1000))?; debug!("Inserted smaller batch"); } info!("Deleting starts"); - let batch2: Vec<_> = db + let batch2: Vec<_> = self .tokenids .iter() .filter_map(|(key, _)| key.starts_with(b"!").then_some(key)) .collect(); for key in batch2 { - db.tokenids.remove(&key)?; + self.tokenids.remove(&key)?; } Ok(()) })?; @@ -924,9 +935,9 @@ impl KeyValueDatabase { migration(10, || { // Add other direction for shortstatekeys for (statekey, shortstatekey) in - db.statekey_shortstatekey.iter() + self.statekey_shortstatekey.iter() { - db.shortstatekey_statekey + self.shortstatekey_statekey .insert(&shortstatekey, &statekey)?; } @@ -939,7 +950,9 @@ impl KeyValueDatabase { })?; migration(11, || { - db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; + self.db + .open_tree("userdevicesessionid_uiaarequest")? + .clear()?; Ok(()) })?; @@ -1125,39 +1138,6 @@ impl KeyValueDatabase { ); } - services().admin.start_handler(); - - // Set emergency access for the grapevine user - match set_emergency_access() { - Ok(pwd_set) => { - if pwd_set { - warn!( - "The Grapevine account emergency password is set! \ - Please unset it as soon as you finish admin account \ - recovery!" - ); - services().admin.send_message( - RoomMessageEventContent::text_plain( - "The Grapevine account emergency password is set! \ - Please unset it as soon as you finish admin \ - account recovery!", - ), - ); - } - } - Err(error) => { - error!( - %error, - "Could not set the configured emergency password for the \ - Grapevine user", - ); - } - }; - - services().sending.start_handler(); - - Self::start_cleanup_task(); - Ok(()) } @@ -1210,36 +1190,6 @@ impl KeyValueDatabase { } } -/// Sets the emergency password and push rules for the @grapevine account in -/// case emergency password is set -fn set_emergency_access() -> Result { - let admin_bot = services().globals.admin_bot_user_id.as_ref(); - - services().users.set_password( - admin_bot, - services().globals.emergency_password().as_deref(), - )?; - - let (ruleset, res) = match services().globals.emergency_password() { - Some(_) => (Ruleset::server_default(admin_bot), Ok(true)), - None => (Ruleset::new(), Ok(false)), - }; - - services().account_data.update( - None, - admin_bot, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; - - res -} - /// If the current version is older than `new_version`, execute a migration /// function. fn migration(new_version: u64, migration: F) -> Result<(), Error> diff --git a/src/service/globals.rs b/src/service/globals.rs index b4ccba9d..9a280a23 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -23,19 +23,27 @@ use hyper_util::{ }; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::{ - api::federation::discovery::ServerSigningKeys, serde::Base64, DeviceId, - MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, - OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, - UserId, + api::federation::discovery::ServerSigningKeys, + events::{ + push_rules::PushRulesEventContent, + room::message::RoomMessageEventContent, GlobalAccountDataEvent, + GlobalAccountDataEventType, + }, + push::Ruleset, + serde::Base64, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomAliasId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, + ServerName, UserId, }; use tokio::sync::{broadcast, Mutex, RwLock, Semaphore}; -use tracing::{error, Instrument}; +use tracing::{error, warn, Instrument}; use trust_dns_resolver::TokioAsyncResolver; use crate::{ api::server_server::FedDest, observability::FilterReloadHandles, service::media::MediaFileKey, + services, utils::on_demand_hashmap::{OnDemandHashMap, TokenSet}, Config, Error, Result, }; @@ -406,6 +414,66 @@ impl Service { &self.config.emergency_password } + /// If the emergency password option is set, attempts to set the emergency + /// password and push rules for the @grapevine account. + /// + /// If an error occurs, it is logged. + pub(crate) fn set_emergency_access(&self) { + let inner = || -> Result { + let admin_bot = self.admin_bot_user_id.as_ref(); + + services().users.set_password( + admin_bot, + self.emergency_password().as_deref(), + )?; + + let (ruleset, res) = match self.emergency_password() { + Some(_) => (Ruleset::server_default(admin_bot), Ok(true)), + None => (Ruleset::new(), Ok(false)), + }; + + services().account_data.update( + None, + admin_bot, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + )?; + + res + }; + + match inner() { + Ok(pwd_set) => { + if pwd_set { + warn!( + "The Grapevine account emergency password is set! \ + Please unset it as soon as you finish admin account \ + recovery!" + ); + services().admin.send_message( + RoomMessageEventContent::text_plain( + "The Grapevine account emergency password is set! \ + Please unset it as soon as you finish admin \ + account recovery!", + ), + ); + } + } + Err(error) => { + error!( + %error, + "Could not set the configured emergency password for the \ + Grapevine user", + ); + } + }; + } + pub(crate) fn supported_room_versions(&self) -> Vec { self.stable_room_versions.clone() }