split some logic out of KeyValueDatabase::load_or_create

This method did _a lot_ of things at the same time. In order to use
`KeyValueDatabase` for the migrate-db command, we need to be able to
open a db without attempting to apply all the migrations and without
spawning a bunch of unrelated background tasks.

The state after this refactor is still not great, but it's enough to do
a migration tool.
This commit is contained in:
Benjamin Lee 2024-09-08 16:31:51 -07:00 committed by Charles Hall
parent 059dfe54e3
commit 279c6472c5
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
3 changed files with 127 additions and 103 deletions

View file

@ -70,9 +70,15 @@ pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> {
.expect("should be able to increase the soft limit to the hard limit"); .expect("should be able to increase the soft limit to the hard limit");
info!("Loading database"); info!("Loading database");
KeyValueDatabase::load_or_create(config, reload_handles) let db = KeyValueDatabase::load_or_create(config, reload_handles)
.await
.map_err(Error::DatabaseError)?; .map_err(Error::DatabaseError)?;
db.apply_migrations().await.map_err(Error::DatabaseError)?;
info!("Starting background tasks");
services().admin.start_handler();
services().sending.start_handler();
KeyValueDatabase::start_cleanup_task();
services().globals.set_emergency_access();
info!("Starting server"); info!("Starting server");
run_server().await?; run_server().await?;

View file

@ -14,9 +14,7 @@ use abstraction::{KeyValueDatabaseEngine, KvTree};
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
events::{ events::{
push_rules::{PushRulesEvent, PushRulesEventContent}, push_rules::PushRulesEvent, GlobalAccountDataEventType, StateEventType,
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
}, },
push::Ruleset, push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId,
@ -314,16 +312,17 @@ impl KeyValueDatabase {
Ok(()) Ok(())
} }
/// Load an existing database or create a new one. /// Load an existing database or create a new one, and initialize all
/// services with the loaded database.
#[cfg_attr( #[cfg_attr(
not(any(feature = "rocksdb", feature = "sqlite")), not(any(feature = "rocksdb", feature = "sqlite")),
allow(unreachable_code) allow(unreachable_code)
)] )]
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn load_or_create( pub(crate) fn load_or_create(
config: Config, config: Config,
reload_handles: FilterReloadHandles, reload_handles: FilterReloadHandles,
) -> Result<()> { ) -> Result<&'static KeyValueDatabase> {
Self::check_db_setup(&config)?; Self::check_db_setup(&config)?;
if !Path::new(&config.database.path).exists() { if !Path::new(&config.database.path).exists() {
@ -565,13 +564,23 @@ impl KeyValueDatabase {
} }
} }
Ok(db)
}
/// Ensure that the database is at the current version, applying migrations
/// if necessary.
///
/// If it is not possible to migrate the database to the current version,
/// returns an error.
#[allow(clippy::too_many_lines)]
pub(crate) async fn apply_migrations(&self) -> Result<()> {
// If the database has any data, perform data migrations before starting // If the database has any data, perform data migrations before starting
let latest_database_version = 13; let latest_database_version = 13;
if services().users.count()? > 0 { if services().users.count()? > 0 {
// MIGRATIONS // MIGRATIONS
migration(1, || { migration(1, || {
for (roomserverid, _) in db.roomserverids.iter() { for (roomserverid, _) in self.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xFF); let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = let room_id =
parts.next().expect("split always returns one element"); parts.next().expect("split always returns one element");
@ -583,7 +592,7 @@ impl KeyValueDatabase {
serverroomid.push(0xFF); serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id); serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?; self.serverroomids.insert(&serverroomid, &[])?;
} }
Ok(()) Ok(())
})?; })?;
@ -591,7 +600,7 @@ impl KeyValueDatabase {
migration(2, || { migration(2, || {
// We accidentally inserted hashed versions of "" into the db // We accidentally inserted hashed versions of "" into the db
// instead of just "" // instead of just ""
for (userid, password) in db.userid_password.iter() { for (userid, password) in self.userid_password.iter() {
let password = utils::string_from_bytes(&password); let password = utils::string_from_bytes(&password);
let empty_hashed_password = password let empty_hashed_password = password
@ -600,7 +609,7 @@ impl KeyValueDatabase {
}); });
if empty_hashed_password { if empty_hashed_password {
db.userid_password.insert(&userid, b"")?; self.userid_password.insert(&userid, b"")?;
} }
} }
Ok(()) Ok(())
@ -608,7 +617,7 @@ impl KeyValueDatabase {
migration(3, || { migration(3, || {
// Move media to filesystem // Move media to filesystem
for (key, content) in db.mediaid_file.iter() { for (key, content) in self.mediaid_file.iter() {
let key = MediaFileKey::new(key); let key = MediaFileKey::new(key);
if content.is_empty() { if content.is_empty() {
continue; continue;
@ -617,7 +626,7 @@ impl KeyValueDatabase {
let path = services().globals.get_media_file(&key); let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?; let mut file = fs::File::create(path)?;
file.write_all(&content)?; file.write_all(&content)?;
db.mediaid_file.insert(key.as_bytes(), &[])?; self.mediaid_file.insert(key.as_bytes(), &[])?;
} }
Ok(()) Ok(())
})?; })?;
@ -650,7 +659,8 @@ impl KeyValueDatabase {
migration(5, || { migration(5, || {
// Upgrade user data store // Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() for (roomuserdataid, _) in
self.roomuserdataid_accountdata.iter()
{ {
let mut parts = roomuserdataid.split(|&b| b == 0xFF); let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
@ -664,7 +674,7 @@ impl KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(event_type); key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?; .insert(&key, &roomuserdataid)?;
} }
Ok(()) Ok(())
@ -672,7 +682,7 @@ impl KeyValueDatabase {
migration(6, || { migration(6, || {
// Set room member count // Set room member count
for (roomid, _) in db.roomid_shortstatehash.iter() { for (roomid, _) in self.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap(); let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services() services()
@ -750,7 +760,7 @@ impl KeyValueDatabase {
}; };
for (k, seventid) in for (k, seventid) in
db.db.open_tree("stateid_shorteventid")?.iter() self.db.open_tree("stateid_shorteventid")?.iter()
{ {
let sstatehash = ShortStateHash::new( let sstatehash = ShortStateHash::new(
utils::u64_from_bytes(&k[0..size_of::<u64>()]) utils::u64_from_bytes(&k[0..size_of::<u64>()])
@ -776,7 +786,7 @@ impl KeyValueDatabase {
current_state = HashSet::new(); current_state = HashSet::new();
current_sstatehash = Some(sstatehash); current_sstatehash = Some(sstatehash);
let event_id = db let event_id = self
.shorteventid_eventid .shorteventid_eventid
.get(&seventid) .get(&seventid)
.unwrap() .unwrap()
@ -820,14 +830,14 @@ impl KeyValueDatabase {
migration(8, || { migration(8, || {
// Generate short room ids for all rooms // Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() { for (room_id, _) in self.roomid_shortstatehash.iter() {
let shortroomid = let shortroomid =
services().globals.next_count()?.to_be_bytes(); services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?; self.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8"); info!("Migration: 8");
} }
// Update pduids db layout // Update pduids db layout
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { let mut batch = self.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") { if !key.starts_with(b"!") {
return None; return None;
} }
@ -835,7 +845,7 @@ impl KeyValueDatabase {
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap(); let count = parts.next().unwrap();
let short_room_id = db let short_room_id = self
.roomid_shortroomid .roomid_shortroomid
.get(room_id) .get(room_id)
.unwrap() .unwrap()
@ -847,10 +857,10 @@ impl KeyValueDatabase {
Some((new_key, v)) Some((new_key, v))
}); });
db.pduid_pdu.insert_batch(&mut batch)?; self.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = let mut batch2 =
db.eventid_pduid.iter().filter_map(|(k, value)| { self.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") { if !value.starts_with(b"!") {
return None; return None;
} }
@ -858,7 +868,7 @@ impl KeyValueDatabase {
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap(); let count = parts.next().unwrap();
let short_room_id = db let short_room_id = self
.roomid_shortroomid .roomid_shortroomid
.get(room_id) .get(room_id)
.unwrap() .unwrap()
@ -870,13 +880,13 @@ impl KeyValueDatabase {
Some((k, new_value)) Some((k, new_value))
}); });
db.eventid_pduid.insert_batch(&mut batch2)?; self.eventid_pduid.insert_batch(&mut batch2)?;
Ok(()) Ok(())
})?; })?;
migration(9, || { migration(9, || {
// Update tokenids db layout // Update tokenids db layout
let mut iter = db let mut iter = self
.tokenids .tokenids
.iter() .iter()
.filter_map(|(key, _)| { .filter_map(|(key, _)| {
@ -889,7 +899,7 @@ impl KeyValueDatabase {
let _pdu_id_room = parts.next().unwrap(); let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap(); let pdu_id_count = parts.next().unwrap();
let short_room_id = db let short_room_id = self
.roomid_shortroomid .roomid_shortroomid
.get(room_id) .get(room_id)
.unwrap() .unwrap()
@ -903,20 +913,21 @@ impl KeyValueDatabase {
.peekable(); .peekable();
while iter.peek().is_some() { while iter.peek().is_some() {
db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; self.tokenids
.insert_batch(&mut iter.by_ref().take(1000))?;
debug!("Inserted smaller batch"); debug!("Inserted smaller batch");
} }
info!("Deleting starts"); info!("Deleting starts");
let batch2: Vec<_> = db let batch2: Vec<_> = self
.tokenids .tokenids
.iter() .iter()
.filter_map(|(key, _)| key.starts_with(b"!").then_some(key)) .filter_map(|(key, _)| key.starts_with(b"!").then_some(key))
.collect(); .collect();
for key in batch2 { for key in batch2 {
db.tokenids.remove(&key)?; self.tokenids.remove(&key)?;
} }
Ok(()) Ok(())
})?; })?;
@ -924,9 +935,9 @@ impl KeyValueDatabase {
migration(10, || { migration(10, || {
// Add other direction for shortstatekeys // Add other direction for shortstatekeys
for (statekey, shortstatekey) in for (statekey, shortstatekey) in
db.statekey_shortstatekey.iter() self.statekey_shortstatekey.iter()
{ {
db.shortstatekey_statekey self.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?; .insert(&shortstatekey, &statekey)?;
} }
@ -939,7 +950,9 @@ impl KeyValueDatabase {
})?; })?;
migration(11, || { migration(11, || {
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; self.db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
Ok(()) Ok(())
})?; })?;
@ -1125,39 +1138,6 @@ impl KeyValueDatabase {
); );
} }
services().admin.start_handler();
// Set emergency access for the grapevine user
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin account \
recovery!"
);
services().admin.send_message(
RoomMessageEventContent::text_plain(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin \
account recovery!",
),
);
}
}
Err(error) => {
error!(
%error,
"Could not set the configured emergency password for the \
Grapevine user",
);
}
};
services().sending.start_handler();
Self::start_cleanup_task();
Ok(()) Ok(())
} }
@ -1210,36 +1190,6 @@ impl KeyValueDatabase {
} }
} }
/// Sets the emergency password and push rules for the @grapevine account in
/// case emergency password is set
fn set_emergency_access() -> Result<bool> {
let admin_bot = services().globals.admin_bot_user_id.as_ref();
services().users.set_password(
admin_bot,
services().globals.emergency_password().as_deref(),
)?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(admin_bot), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
admin_bot,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
}
/// If the current version is older than `new_version`, execute a migration /// If the current version is older than `new_version`, execute a migration
/// function. /// function.
fn migration<F>(new_version: u64, migration: F) -> Result<(), Error> fn migration<F>(new_version: u64, migration: F) -> Result<(), Error>

View file

@ -23,19 +23,27 @@ use hyper_util::{
}; };
use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::{ use ruma::{
api::federation::discovery::ServerSigningKeys, serde::Base64, DeviceId, api::federation::discovery::ServerSigningKeys,
MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, events::{
OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, push_rules::PushRulesEventContent,
UserId, room::message::RoomMessageEventContent, GlobalAccountDataEvent,
GlobalAccountDataEventType,
},
push::Ruleset,
serde::Base64,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomAliasId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId,
ServerName, UserId,
}; };
use tokio::sync::{broadcast, Mutex, RwLock, Semaphore}; use tokio::sync::{broadcast, Mutex, RwLock, Semaphore};
use tracing::{error, Instrument}; use tracing::{error, warn, Instrument};
use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::TokioAsyncResolver;
use crate::{ use crate::{
api::server_server::FedDest, api::server_server::FedDest,
observability::FilterReloadHandles, observability::FilterReloadHandles,
service::media::MediaFileKey, service::media::MediaFileKey,
services,
utils::on_demand_hashmap::{OnDemandHashMap, TokenSet}, utils::on_demand_hashmap::{OnDemandHashMap, TokenSet},
Config, Error, Result, Config, Error, Result,
}; };
@ -406,6 +414,66 @@ impl Service {
&self.config.emergency_password &self.config.emergency_password
} }
/// If the emergency password option is set, attempts to set the emergency
/// password and push rules for the @grapevine account.
///
/// If an error occurs, it is logged.
pub(crate) fn set_emergency_access(&self) {
let inner = || -> Result<bool> {
let admin_bot = self.admin_bot_user_id.as_ref();
services().users.set_password(
admin_bot,
self.emergency_password().as_deref(),
)?;
let (ruleset, res) = match self.emergency_password() {
Some(_) => (Ruleset::server_default(admin_bot), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
admin_bot,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
};
match inner() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin account \
recovery!"
);
services().admin.send_message(
RoomMessageEventContent::text_plain(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin \
account recovery!",
),
);
}
}
Err(error) => {
error!(
%error,
"Could not set the configured emergency password for the \
Grapevine user",
);
}
};
}
pub(crate) fn supported_room_versions(&self) -> Vec<RoomVersionId> { pub(crate) fn supported_room_versions(&self) -> Vec<RoomVersionId> {
self.stable_room_versions.clone() self.stable_room_versions.clone()
} }