diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 819ef98c..7934535f 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -6,7 +6,7 @@ use lru_cache::LruCache; use ruma::{ api::federation::discovery::{OldVerifyKey, ServerSigningKeys}, signatures::Ed25519KeyPair, - DeviceId, ServerName, UserId, + DeviceId, OwnedServerName, ServerName, UserId, }; use crate::{ @@ -361,4 +361,28 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" self.global.insert(b"version", &new_version.to_be_bytes())?; Ok(()) } + + fn set_server_name(&self, server_name: &ServerName) -> Result<()> { + self.global.insert(b"server_name", server_name.as_bytes()) + } + + fn server_name(&self) -> Result> { + let opt_bytes = self + .global + .get(b"server_name") + .map_err(|_| Error::bad_database("Failed to read from globals"))?; + + // `server_name` has not been set yet + let Some(bytes) = opt_bytes else { + return Ok(None); + }; + + let utf8 = String::from_utf8(bytes) + .map_err(|_| Error::bad_database("Invalid UTF-8 in server_name"))?; + + let server_name = OwnedServerName::try_from(utf8) + .map_err(|_| Error::bad_database("Invalid server_name"))?; + + Ok(Some(server_name)) + } } diff --git a/src/error.rs b/src/error.rs index 72bcd337..69d6b38f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -75,6 +75,9 @@ pub(crate) enum ServeCommand { #[allow(missing_docs)] #[derive(Error, Debug)] pub(crate) enum ServerNameChanged { + #[error("failed to read saved server_name")] + ReadSavedServerName(#[source] crate::utils::error::Error), + #[error("failed to check if there are any users")] NonZeroUsers(#[source] crate::utils::error::Error), @@ -83,6 +86,9 @@ pub(crate) enum ServerNameChanged { #[error("`server_name` in the database and config file differ")] Renamed, + + #[error("failed to save the configured server_name")] + SaveServerName(#[source] crate::utils::error::Error), } /// Observability initialization errors diff --git a/src/service/globals.rs b/src/service/globals.rs index 1189bb13..0de48be0 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -307,6 +307,10 @@ impl Service { /// Check if `server_name` in the DB and config differ, return error if so /// + /// This function will save the currently configured `server_name` if the + /// check passes, so that future calls to this function will continue to + /// check against the first-configured value. + /// /// Matrix resource ownership is based on the server name; changing it /// requires recreating the database from scratch. This check needs to be /// done before background tasks are started to avoid data races. @@ -317,6 +321,20 @@ impl Service { ) -> Result<(), crate::error::ServerNameChanged> { use crate::error::ServerNameChanged as Error; + let config = &*services().globals.config.server_name; + + let opt_saved = + self.saved_server_name().map_err(Error::ReadSavedServerName)?; + + // Check against saved server name + if let Some(saved) = opt_saved { + if saved == config { + return Ok(()); + } + + return Err(Error::Renamed); + } + let non_zero_users = services() .users .count() @@ -328,10 +346,18 @@ impl Service { .exists(&self.admin_bot_user_id) .map_err(Error::AdminBotExists)?; + // Fall back to checking against the admin bot user ID if non_zero_users && !admin_bot_exists { return Err(Error::Renamed); } + // If the server_name wasn't saved and the admin bot user ID check + // didn't fail, save the current server_name + services() + .globals + .save_server_name(config) + .map_err(Error::SaveServerName)?; + Ok(()) } @@ -618,6 +644,18 @@ impl Service { self.shutdown.store(true, atomic::Ordering::Relaxed); self.rotate.fire(); } + + pub(crate) fn save_server_name( + &self, + server_name: &ServerName, + ) -> Result<()> { + self.db.set_server_name(server_name) + } + + // Named this way to avoid conflicts with the existing `fn server_name` + pub(crate) fn saved_server_name(&self) -> Result> { + self.db.server_name() + } } fn reqwest_client_builder(config: &Config) -> Result { diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 28e7e512..a232f81c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -8,7 +8,7 @@ use ruma::{ api::federation::discovery::{OldVerifyKey, ServerSigningKeys, VerifyKey}, serde::Base64, signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, ServerName, UserId, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerName, ServerName, UserId, }; use serde::Deserialize; @@ -119,4 +119,6 @@ pub(crate) trait Data: Send + Sync { ) -> Result>; fn database_version(&self) -> Result; fn bump_database_version(&self, new_version: u64) -> Result<()>; + fn set_server_name(&self, server_name: &ServerName) -> Result<()>; + fn server_name(&self) -> Result>; }