diff --git a/rustfmt.toml b/rustfmt.toml index 739b454f..58355ad5 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,17 @@ -unstable_features = true -imports_granularity="Crate" +edition = "2021" + +condense_wildcard_suffixes = true +format_code_in_doc_comments = true +format_macro_bodies = true +format_macro_matchers = true +format_strings = true +group_imports = "StdExternalCrate" +hex_literal_case = "Upper" +imports_granularity = "Crate" +max_width = 80 +newline_style = "Unix" +reorder_impl_items = true +use_field_init_shorthand = true +use_small_heuristics = "Off" +use_try_shorthand = true +wrap_comments = true diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index 98097d22..fbf03179 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,14 +1,18 @@ -use crate::{services, utils, Error, Result}; +use std::{fmt::Debug, mem, time::Duration}; + use bytes::BytesMut; use ruma::api::{ - appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, + SendAccessToken, }; -use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; +use crate::{services, utils, Error, Result}; + /// Sends a request to an appservice /// -/// Only returns None if there is no url specified in the appservice registration file +/// Only returns None if there is no url specified in the appservice +/// registration file #[tracing::instrument(skip(request))] pub(crate) async fn send_request( registration: Registration, @@ -45,7 +49,8 @@ where .parse() .unwrap(), ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); + *http_request.uri_mut() = + parts.try_into().expect("our manipulation is always valid"); let mut reqwest_request = reqwest::Request::try_from(http_request)?; @@ -70,9 +75,8 @@ where // reqwest::Response -> http::Response conversion let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); + let mut http_response_builder = + http::Response::builder().status(status).version(response.version()); mem::swap( response.headers_mut(), http_response_builder diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 91b5dd4d..4d529b68 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -1,22 +1,25 @@ -use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{api::client_server, services, utils, Error, Result, Ruma}; +use register::RegistrationKind; use ruma::{ api::client::{ account::{ change_password, deactivate, get_3pids, get_username_availability, register::{self, LoginType}, - request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, - whoami, ThirdPartyIdRemovalStatus, + request_3pid_management_token_via_email, + request_3pid_management_token_via_msisdn, whoami, + ThirdPartyIdRemovalStatus, }, error::ErrorKind, uiaa::{AuthFlow, AuthType, UiaaInfo}, }, - events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, + events::{ + room::message::RoomMessageEventContent, GlobalAccountDataEventType, + }, push, UserId, }; use tracing::{info, warn}; -use register::RegistrationKind; +use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; +use crate::{api::client_server, services, utils, Error, Result, Ruma}; const RANDOM_USER_ID_LENGTH: usize = 10; @@ -29,7 +32,8 @@ const RANDOM_USER_ID_LENGTH: usize = 10; /// - The server name of the user id matches this server /// - No user or appservice on this server already claimed this username /// -/// Note: This will not reserve the username, so the username might become invalid when trying to register +/// Note: This will not reserve the username, so the username might become +/// invalid when trying to register pub(crate) async fn get_register_available_route( body: Ruma, ) -> Result { @@ -40,7 +44,8 @@ pub(crate) async fn get_register_available_route( ) .ok() .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == services().globals.server_name() + !user_id.is_historical() + && user_id.server_name() == services().globals.server_name() }) .ok_or(Error::BadRequest( ErrorKind::InvalidUsername, @@ -58,27 +63,35 @@ pub(crate) async fn get_register_available_route( // TODO add check for appservice namespaces // If no if check is true we have an username that's available to be used. - Ok(get_username_availability::v3::Response { available: true }) + Ok(get_username_availability::v3::Response { + available: true, + }) } /// # `POST /_matrix/client/r0/register` /// /// Register an account on this homeserver. /// -/// You can use [`GET /_matrix/client/r0/register/available`](get_register_available_route) +/// You can use [`GET +/// /_matrix/client/r0/register/available`](get_register_available_route) /// to check if the user id is valid and available. /// /// - Only works if registration is enabled -/// - If type is guest: ignores all parameters except `initial_device_display_name` +/// - If type is guest: ignores all parameters except +/// `initial_device_display_name` /// - If sender is not appservice: Requires UIAA (but we only use a dummy stage) -/// - If type is not guest and no username is given: Always fails after UIAA check +/// - If type is not guest and no username is given: Always fails after UIAA +/// check /// - Creates a new account and populates it with default account data -/// - If `inhibit_login` is false: Creates a device and returns `device_id` and `access_token` +/// - If `inhibit_login` is false: Creates a device and returns `device_id` and +/// `access_token` #[allow(clippy::too_many_lines)] pub(crate) async fn register_route( body: Ruma, ) -> Result { - if !services().globals.allow_registration() && body.appservice_info.is_none() { + if !services().globals.allow_registration() + && body.appservice_info.is_none() + { return Err(Error::BadRequest( ErrorKind::Forbidden, "Registration has been disabled.", @@ -158,7 +171,8 @@ pub(crate) async fn register_route( }; body.appservice_info.is_some() } else { - // No registration token necessary, but clients must still go through the flow + // No registration token necessary, but clients must still go through + // the flow uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Dummy], @@ -174,8 +188,11 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { let (worked, uiaainfo) = services().uiaa.try_auth( - &UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid"), + &UserId::parse_with_server_name( + "", + services().globals.server_name(), + ) + .expect("we know this is valid"), "".into(), auth, &uiaainfo, @@ -187,8 +204,11 @@ pub(crate) async fn register_route( } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services().uiaa.create( - &UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid"), + &UserId::parse_with_server_name( + "", + services().globals.server_name(), + ) + .expect("we know this is valid"), "".into(), &uiaainfo, &json, @@ -211,9 +231,7 @@ pub(crate) async fn register_route( // Default to pretty displayname let displayname = user_id.localpart().to_owned(); - services() - .users - .set_displayname(&user_id, Some(displayname.clone()))?; + services().users.set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data services().account_data.update( @@ -260,29 +278,24 @@ pub(crate) async fn register_route( info!("New user {} registered on this server.", user_id); if body.appservice_info.is_none() && !is_guest { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "New user {user_id} registered on this server." - ))); + services().admin.send_message(RoomMessageEventContent::notice_plain( + format!("New user {user_id} registered on this server."), + )); } // If this is the first real user, grant them admin privileges // Note: the server user, @grapevine:servername, is generated first if !is_guest { if let Some(admin_room) = services().admin.get_admin_room()? { - if services() - .rooms - .state_cache - .room_joined_count(&admin_room)? + if services().rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { - services() - .admin - .make_user_admin(&user_id, displayname) - .await?; + services().admin.make_user_admin(&user_id, displayname).await?; - warn!("Granting {} admin privileges as the first user", user_id); + warn!( + "Granting {} admin privileges as the first user", + user_id + ); } } } @@ -302,19 +315,23 @@ pub(crate) async fn register_route( /// /// - Requires UIAA to verify user password /// - Changes the password of the sender user -/// - The password hash is calculated using argon2 with 32 character salt, the plain password is +/// - The password hash is calculated using argon2 with 32 character salt, the +/// plain password is /// not saved /// -/// If `logout_devices` is true it does the following for each device except the sender device: +/// If `logout_devices` is true it does the following for each device except the +/// sender device: /// - Invalidates access token -/// - Deletes device metadata (device ID, device display name, last seen IP, last seen timestamp) +/// - Deletes device metadata (device ID, device display name, last seen IP, +/// last seen timestamp) /// - Forgets to-device events /// - Triggers device list updates pub(crate) async fn change_password_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -327,27 +344,25 @@ pub(crate) async fn change_password_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + let (worked, uiaainfo) = services().uiaa.try_auth( + sender_user, + sender_device, + auth, + &uiaainfo, + )?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - services() - .users - .set_password(sender_user, Some(&body.new_password))?; + services().users.set_password(sender_user, Some(&body.new_password))?; if body.logout_devices { // Logout all devices except the current one @@ -362,11 +377,9 @@ pub(crate) async fn change_password_route( } info!("User {} changed their password.", sender_user); - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "User {sender_user} changed their password." - ))); + services().admin.send_message(RoomMessageEventContent::notice_plain( + format!("User {sender_user} changed their password."), + )); Ok(change_password::v3::Response {}) } @@ -376,14 +389,17 @@ pub(crate) async fn change_password_route( /// Get `user_id` of the sender user. /// /// Note: Also works for Application Services -pub(crate) async fn whoami_route(body: Ruma) -> Result { +pub(crate) async fn whoami_route( + body: Ruma, +) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let device_id = body.sender_device.as_ref().cloned(); Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services().users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services().users.is_deactivated(sender_user)? + && body.appservice_info.is_none(), }) } @@ -393,7 +409,8 @@ pub(crate) async fn whoami_route(body: Ruma) -> Result, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -414,19 +432,19 @@ pub(crate) async fn deactivate_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + let (worked, uiaainfo) = services().uiaa.try_auth( + sender_user, + sender_device, + auth, + &uiaainfo, + )?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -439,11 +457,9 @@ pub(crate) async fn deactivate_route( services().users.deactivate_account(sender_user)?; info!("User {} deactivated their account.", sender_user); - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "User {sender_user} deactivated their account." - ))); + services().admin.send_message(RoomMessageEventContent::notice_plain( + format!("User {sender_user} deactivated their account."), + )); Ok(deactivate::v3::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, @@ -458,16 +474,19 @@ pub(crate) async fn deactivate_route( pub(crate) async fn third_party_route( body: Ruma, ) -> Result { - let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let _sender_user = + body.sender_user.as_ref().expect("user is authenticated"); Ok(get_3pids::v3::Response::new(Vec::new())) } /// # `POST /_matrix/client/v3/account/3pid/email/requestToken` /// -/// "This API should be used to request validation tokens when adding an email address to an account" +/// "This API should be used to request validation tokens when adding an email +/// address to an account" /// -/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. +/// - 403 signals that The homeserver does not allow the third party identifier +/// as a contact option. pub(crate) async fn request_3pid_management_token_via_email_route( _body: Ruma, ) -> Result { @@ -479,9 +498,11 @@ pub(crate) async fn request_3pid_management_token_via_email_route( /// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken` /// -/// "This API should be used to request validation tokens when adding an phone number to an account" +/// "This API should be used to request validation tokens when adding an phone +/// number to an account" /// -/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. +/// - 403 signals that The homeserver does not allow the third party identifier +/// as a contact option. pub(crate) async fn request_3pid_management_token_via_msisdn_route( _body: Ruma, ) -> Result { diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index c880bd5d..5d1a0c30 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,4 +1,3 @@ -use crate::{services, Error, Result, Ruma}; use rand::seq::SliceRandom; use ruma::{ api::{ @@ -12,6 +11,8 @@ use ruma::{ OwnedRoomAliasId, }; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/directory/room/{roomAlias}` /// /// Creates a new room alias on this server. @@ -32,30 +33,18 @@ pub(crate) async fn create_alias_route( "Room alias is not in namespace.", )); } - } else if services() - .appservice - .is_exclusive_alias(&body.room_alias) - .await - { + } else if services().appservice.is_exclusive_alias(&body.room_alias).await { return Err(Error::BadRequest( ErrorKind::Exclusive, "Room alias reserved by appservice.", )); } - if services() - .rooms - .alias - .resolve_local_alias(&body.room_alias)? - .is_some() - { + if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() { return Err(Error::Conflict("Alias already exists.")); } - services() - .rooms - .alias - .set_alias(&body.room_alias, &body.room_id)?; + services().rooms.alias.set_alias(&body.room_alias, &body.room_id)?; Ok(create_alias::v3::Response::new()) } @@ -83,11 +72,7 @@ pub(crate) async fn delete_alias_route( "Room alias is not in namespace.", )); } - } else if services() - .appservice - .is_exclusive_alias(&body.room_alias) - .await - { + } else if services().appservice.is_exclusive_alias(&body.room_alias).await { return Err(Error::BadRequest( ErrorKind::Exclusive, "Room alias reserved by appservice.", @@ -157,7 +142,10 @@ pub(crate) async fn get_alias_helper( .alias .resolve_local_alias(&room_alias)? .ok_or_else(|| { - Error::bad_config("Appservice lied to us. Room does not exist.") + Error::bad_config( + "Appservice lied to us. Room does not \ + exist.", + ) })?, ); break; diff --git a/src/api/client_server/backup.rs b/src/api/client_server/backup.rs index daab262d..4cfffcc1 100644 --- a/src/api/client_server/backup.rs +++ b/src/api/client_server/backup.rs @@ -1,15 +1,16 @@ -use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, delete_backup_keys, delete_backup_keys_for_room, - delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys, - get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info, - update_backup_version, + delete_backup_keys_for_session, delete_backup_version, get_backup_info, + get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, }, error::ErrorKind, }; +use crate::{services, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/room_keys/version` /// /// Creates a new backup. @@ -17,23 +18,27 @@ pub(crate) async fn create_backup_version_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = services() - .key_backups - .create_backup(sender_user, &body.algorithm)?; + let version = + services().key_backups.create_backup(sender_user, &body.algorithm)?; - Ok(create_backup_version::v3::Response { version }) + Ok(create_backup_version::v3::Response { + version, + }) } /// # `PUT /_matrix/client/r0/room_keys/version/{version}` /// -/// Update information about an existing backup. Only `auth_data` can be modified. +/// Update information about an existing backup. Only `auth_data` can be +/// modified. pub(crate) async fn update_backup_version_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + services().key_backups.update_backup( + sender_user, + &body.version, + &body.algorithm, + )?; Ok(update_backup_version::v3::Response {}) } @@ -88,9 +93,7 @@ pub(crate) async fn get_backup_info_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, version: body.version.clone(), }) } @@ -99,15 +102,14 @@ pub(crate) async fn get_backup_info_route( /// /// Delete an existing key backup. /// -/// - Deletes both information about the backup, as well as all key data related to the backup +/// - Deletes both information about the backup, as well as all key data related +/// to the backup pub(crate) async fn delete_backup_version_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_backup(sender_user, &body.version)?; + services().key_backups.delete_backup(sender_user, &body.version)?; Ok(delete_backup_version::v3::Response {}) } @@ -116,7 +118,8 @@ pub(crate) async fn delete_backup_version_route( /// /// Add the received backup keys to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_route( @@ -132,7 +135,8 @@ pub(crate) async fn add_backup_keys_route( { return Err(Error::BadRequest( ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", + "You may only manipulate the most recently created version of the \ + backup.", )); } @@ -154,9 +158,7 @@ pub(crate) async fn add_backup_keys_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -164,7 +166,8 @@ pub(crate) async fn add_backup_keys_route( /// /// Add the received backup keys to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_for_room_route( @@ -180,7 +183,8 @@ pub(crate) async fn add_backup_keys_for_room_route( { return Err(Error::BadRequest( ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", + "You may only manipulate the most recently created version of the \ + backup.", )); } @@ -200,9 +204,7 @@ pub(crate) async fn add_backup_keys_for_room_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -210,7 +212,8 @@ pub(crate) async fn add_backup_keys_for_room_route( /// /// Add the received backup key to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_for_session_route( @@ -226,7 +229,8 @@ pub(crate) async fn add_backup_keys_for_session_route( { return Err(Error::BadRequest( ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", + "You may only manipulate the most recently created version of the \ + backup.", )); } @@ -244,9 +248,7 @@ pub(crate) async fn add_backup_keys_for_session_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -260,7 +262,9 @@ pub(crate) async fn get_backup_keys_route( let rooms = services().key_backups.get_all(sender_user, &body.version)?; - Ok(get_backup_keys::v3::Response { rooms }) + Ok(get_backup_keys::v3::Response { + rooms, + }) } /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}` @@ -271,11 +275,15 @@ pub(crate) async fn get_backup_keys_for_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = services() - .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + let sessions = services().key_backups.get_room( + sender_user, + &body.version, + &body.room_id, + )?; - Ok(get_backup_keys_for_room::v3::Response { sessions }) + Ok(get_backup_keys_for_room::v3::Response { + sessions, + }) } /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` @@ -288,13 +296,20 @@ pub(crate) async fn get_backup_keys_for_session_route( let key_data = services() .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? + .get_session( + sender_user, + &body.version, + &body.room_id, + &body.session_id, + )? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Backup key not found for this user's session.", ))?; - Ok(get_backup_keys_for_session::v3::Response { key_data }) + Ok(get_backup_keys_for_session::v3::Response { + key_data, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys` @@ -305,9 +320,7 @@ pub(crate) async fn delete_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_all_keys(sender_user, &body.version)?; + services().key_backups.delete_all_keys(sender_user, &body.version)?; Ok(delete_backup_keys::v3::Response { count: services() @@ -315,9 +328,7 @@ pub(crate) async fn delete_backup_keys_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -329,9 +340,11 @@ pub(crate) async fn delete_backup_keys_for_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + services().key_backups.delete_room_keys( + sender_user, + &body.version, + &body.room_id, + )?; Ok(delete_backup_keys_for_room::v3::Response { count: services() @@ -339,9 +352,7 @@ pub(crate) async fn delete_backup_keys_for_room_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -366,8 +377,6 @@ pub(crate) async fn delete_backup_keys_for_session_route( .count_keys(sender_user, &body.version)? .try_into() .expect("count should fit in UInt"), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } diff --git a/src/api/client_server/capabilities.rs b/src/api/client_server/capabilities.rs index f248d1b9..61152bf0 100644 --- a/src/api/client_server/capabilities.rs +++ b/src/api/client_server/capabilities.rs @@ -1,12 +1,15 @@ -use crate::{services, Result, Ruma}; +use std::collections::BTreeMap; + use ruma::api::client::discovery::get_capabilities::{ self, Capabilities, RoomVersionStability, RoomVersionsCapability, }; -use std::collections::BTreeMap; + +use crate::{services, Result, Ruma}; /// # `GET /_matrix/client/r0/capabilities` /// -/// Get information on the supported feature set and other relevent capabilities of this server. +/// Get information on the supported feature set and other relevent capabilities +/// of this server. pub(crate) async fn get_capabilities_route( _body: Ruma, ) -> Result { @@ -24,5 +27,7 @@ pub(crate) async fn get_capabilities_route( available, }; - Ok(get_capabilities::v3::Response { capabilities }) + Ok(get_capabilities::v3::Response { + capabilities, + }) } diff --git a/src/api/client_server/config.rs b/src/api/client_server/config.rs index c650de1f..d4265078 100644 --- a/src/api/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -1,18 +1,21 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{ config::{ - get_global_account_data, get_room_account_data, set_global_account_data, - set_room_account_data, + get_global_account_data, get_room_account_data, + set_global_account_data, set_room_account_data, }, error::ErrorKind, }, - events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent}, + events::{ + AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent, + }, serde::Raw, }; use serde::Deserialize; use serde_json::{json, value::RawValue as RawJsonValue}; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// /// Sets some account data for the sender user. @@ -22,7 +25,9 @@ pub(crate) async fn set_global_account_data_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::BadJson, "Data is invalid.") + })?; let event_type = body.event_type.to_string(); @@ -48,7 +53,9 @@ pub(crate) async fn set_room_account_data_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::BadJson, "Data is invalid.") + })?; let event_type = body.event_type.to_string(); @@ -78,11 +85,16 @@ pub(crate) async fn get_global_account_data_route( .get(None, sender_user, body.event_type.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = + serde_json::from_str::(event.get()) + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })? + .content; - Ok(get_global_account_data::v3::Response { account_data }) + Ok(get_global_account_data::v3::Response { + account_data, + }) } /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` @@ -98,11 +110,16 @@ pub(crate) async fn get_room_account_data_route( .get(Some(&body.room_id), sender_user, body.event_type.clone())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = + serde_json::from_str::(event.get()) + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })? + .content; - Ok(get_room_account_data::v3::Response { account_data }) + Ok(get_room_account_data::v3::Response { + account_data, + }) } #[derive(Deserialize)] diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index 7642ebe3..0ada0bbc 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,60 +1,57 @@ -use crate::{services, Error, Result, Ruma}; +use std::collections::HashSet; + use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + api::client::{ + context::get_context, error::ErrorKind, filter::LazyLoadOptions, + }, events::StateEventType, uint, }; -use std::collections::HashSet; use tracing::error; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// /// Allows loading room history around an event. /// -/// - Only works if the user is joined (TODO: always allow, but only show events if the user was +/// - Only works if the user is joined (TODO: always allow, but only show events +/// if the user was /// joined, depending on `history_visibility`) #[allow(clippy::too_many_lines)] pub(crate) async fn get_context_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); - let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members, - } => (true, *include_redundant_members), - LazyLoadOptions::Disabled => (false, false), - }; + let (lazy_load_enabled, lazy_load_send_redundant) = + match &body.filter.lazy_load_options { + LazyLoadOptions::Enabled { + include_redundant_members, + } => (true, *include_redundant_members), + LazyLoadOptions::Disabled => (false, false), + }; let mut lazy_loaded = HashSet::new(); - let base_token = services() - .rooms - .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Base event id not found.", - ))?; + let base_token = + services().rooms.timeline.get_pdu_count(&body.event_id)?.ok_or( + Error::BadRequest(ErrorKind::NotFound, "Base event id not found."), + )?; - let base_event = - services() - .rooms - .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Base event not found.", - ))?; + let base_event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or( + Error::BadRequest(ErrorKind::NotFound, "Base event not found."), + )?; let room_id = base_event.room_id.clone(); - if !services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? - { + if !services().rooms.state_accessor.user_can_see_event( + sender_user, + &room_id, + &body.event_id, + )? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this event.", @@ -72,8 +69,8 @@ pub(crate) async fn get_context_route( } // Use limit with maximum 100 - let half_limit = - usize::try_from(body.limit.min(uint!(100)) / uint!(2)).expect("0-50 should fit in usize"); + let half_limit = usize::try_from(body.limit.min(uint!(100)) / uint!(2)) + .expect("0-50 should fit in usize"); let base_event = base_event.to_room_event(); @@ -108,10 +105,8 @@ pub(crate) async fn get_context_route( .last() .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_before: Vec<_> = + events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let events_after: Vec<_> = services() .rooms @@ -140,41 +135,33 @@ pub(crate) async fn get_context_route( } } - let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( - events_after - .last() - .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? { - Some(s) => s, - None => services() - .rooms - .state - .get_room_shortstatehash(&room_id)? - .expect("All rooms have state"), - }; + let shortstatehash = + match services().rooms.state_accessor.pdu_shortstatehash( + events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id), + )? { + Some(s) => s, + None => services() + .rooms + .state + .get_room_shortstatehash(&room_id)? + .expect("All rooms have state"), + }; - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; + let state_ids = + services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let end_token = events_after .last() .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - let events_after: Vec<_> = events_after - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_after: Vec<_> = + events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let mut state = Vec::new(); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services() - .rooms - .short - .get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = + services().rooms.short.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { diff --git a/src/api/client_server/device.rs b/src/api/client_server/device.rs index 0db1e833..59ae4710 100644 --- a/src/api/client_server/device.rs +++ b/src/api/client_server/device.rs @@ -1,11 +1,14 @@ -use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{ - device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, + device::{ + self, delete_device, delete_devices, get_device, get_devices, + update_device, + }, error::ErrorKind, uiaa::{AuthFlow, AuthType, UiaaInfo}, }; use super::SESSION_ID_LENGTH; +use crate::{services, utils, Error, Result, Ruma}; /// # `GET /_matrix/client/r0/devices` /// @@ -21,7 +24,9 @@ pub(crate) async fn get_devices_route( .filter_map(Result::ok) .collect(); - Ok(get_devices::v3::Response { devices }) + Ok(get_devices::v3::Response { + devices, + }) } /// # `GET /_matrix/client/r0/devices/{deviceId}` @@ -37,7 +42,9 @@ pub(crate) async fn get_device_route( .get_device_metadata(sender_user, &body.body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; - Ok(get_device::v3::Response { device }) + Ok(get_device::v3::Response { + device, + }) } /// # `PUT /_matrix/client/r0/devices/{deviceId}` @@ -55,9 +62,11 @@ pub(crate) async fn update_device_route( device.display_name = body.display_name.clone(); - services() - .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + services().users.update_device_metadata( + sender_user, + &body.device_id, + &device, + )?; Ok(update_device::v3::Response {}) } @@ -68,14 +77,16 @@ pub(crate) async fn update_device_route( /// /// - Requires UIAA to verify user password /// - Invalidates access token -/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) +/// - Deletes device metadata (device id, device display name, last seen ip, +/// last seen ts) /// - Forgets to-device events /// - Triggers device list updates pub(crate) async fn delete_device_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); // UIAA let mut uiaainfo = UiaaInfo { @@ -89,27 +100,25 @@ pub(crate) async fn delete_device_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + let (worked, uiaainfo) = services().uiaa.try_auth( + sender_user, + sender_device, + auth, + &uiaainfo, + )?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - services() - .users - .remove_device(sender_user, &body.device_id)?; + services().users.remove_device(sender_user, &body.device_id)?; Ok(delete_device::v3::Response {}) } @@ -122,14 +131,16 @@ pub(crate) async fn delete_device_route( /// /// For each device: /// - Invalidates access token -/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) +/// - Deletes device metadata (device id, device display name, last seen ip, +/// last seen ts) /// - Forgets to-device events /// - Triggers device list updates pub(crate) async fn delete_devices_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); // UIAA let mut uiaainfo = UiaaInfo { @@ -143,19 +154,19 @@ pub(crate) async fn delete_devices_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + let (worked, uiaainfo) = services().uiaa.try_auth( + sender_user, + sender_device, + auth, + &uiaainfo, + )?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 1e0e145f..6776818d 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -1,10 +1,9 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ api::{ client::{ directory::{ - get_public_rooms, get_public_rooms_filtered, get_room_visibility, - set_room_visibility, + get_public_rooms, get_public_rooms_filtered, + get_room_visibility, set_room_visibility, }, error::ErrorKind, room, @@ -18,7 +17,9 @@ use ruma::{ canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + history_visibility::{ + HistoryVisibility, RoomHistoryVisibilityEventContent, + }, join_rules::{JoinRule, RoomJoinRulesEventContent}, topic::RoomTopicEventContent, }, @@ -28,6 +29,8 @@ use ruma::{ }; use tracing::{error, info, warn}; +use crate::{services, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/publicRooms` /// /// Lists the public rooms on this server. @@ -91,7 +94,9 @@ pub(crate) async fn set_room_visibility_route( services().rooms.directory.set_public(&body.room_id)?; info!("{} made {} public", sender_user, body.room_id); } - room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => { + services().rooms.directory.set_not_public(&body.room_id)?; + } _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -115,7 +120,11 @@ pub(crate) async fn get_room_visibility_route( } Ok(get_room_visibility::v3::Response { - visibility: if services().rooms.directory.is_public_room(&body.room_id)? { + visibility: if services() + .rooms + .directory + .is_public_room(&body.room_id)? + { room::Visibility::Public } else { room::Visibility::Private @@ -131,8 +140,8 @@ pub(crate) async fn get_public_rooms_filtered_helper( filter: &Filter, _network: &RoomNetwork, ) -> Result { - if let Some(other_server) = - server.filter(|server| *server != services().globals.server_name().as_str()) + if let Some(other_server) = server + .filter(|server| *server != services().globals.server_name().as_str()) { let response = services() .sending @@ -174,10 +183,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( } }; - num_since = characters - .collect::() - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?; + num_since = characters.collect::().parse().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token.") + })?; if backwards { num_since = num_since.saturating_sub(limit); @@ -195,12 +203,19 @@ pub(crate) async fn get_public_rooms_filtered_helper( canonical_alias: services() .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? + .room_state_get( + &room_id, + &StateEventType::RoomCanonicalAlias, + "", + )? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomCanonicalAliasEventContent| c.alias) .map_err(|_| { - Error::bad_database("Invalid canonical alias event in database.") + Error::bad_database( + "Invalid canonical alias event in \ + database.", + ) }) })?, name: services().rooms.state_accessor.get_name(&room_id)?, @@ -222,36 +237,55 @@ pub(crate) async fn get_public_rooms_filtered_helper( serde_json::from_str(s.content.get()) .map(|c: RoomTopicEventContent| Some(c.topic)) .map_err(|_| { - error!("Invalid room topic event in database for room {}", room_id); - Error::bad_database("Invalid room topic event in database.") + error!( + "Invalid room topic event in database for \ + room {}", + room_id + ); + Error::bad_database( + "Invalid room topic event in database.", + ) }) })?, world_readable: services() .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? + .room_state_get( + &room_id, + &StateEventType::RoomHistoryVisibility, + "", + )? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable + c.history_visibility + == HistoryVisibility::WorldReadable }) .map_err(|_| { Error::bad_database( - "Invalid room history visibility event in database.", + "Invalid room history visibility event in \ + database.", ) }) })?, guest_can_join: services() .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? + .room_state_get( + &room_id, + &StateEventType::RoomGuestAccess, + "", + )? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomGuestAccessEventContent| { c.guest_access == GuestAccess::CanJoin }) .map_err(|_| { - Error::bad_database("Invalid room guest access event in database.") + Error::bad_database( + "Invalid room guest access event in \ + database.", + ) }) })?, avatar_url: services() @@ -262,7 +296,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( serde_json::from_str(s.content.get()) .map(|c: RoomAvatarEventContent| c.url) .map_err(|_| { - Error::bad_database("Invalid room avatar event in database.") + Error::bad_database( + "Invalid room avatar event in database.", + ) }) }) .transpose()? @@ -270,33 +306,59 @@ pub(crate) async fn get_public_rooms_filtered_helper( join_rule: services() .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? + .room_state_get( + &room_id, + &StateEventType::RoomJoinRules, + "", + )? .map(|s| { serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, + .map(|c: RoomJoinRulesEventContent| { + match c.join_rule { + JoinRule::Public => { + Some(PublicRoomJoinRule::Public) + } + JoinRule::Knock => { + Some(PublicRoomJoinRule::Knock) + } + _ => None, + } }) .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") + error!( + "Invalid room join rule event in \ + database: {}", + e + ); + Error::BadDatabase( + "Invalid room join rule event in database.", + ) }) }) .transpose()? .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, + .ok_or_else(|| { + Error::bad_database( + "Missing room join rule event for room.", + ) + })?, room_type: services() .rooms .state_accessor .room_state_get(&room_id, &StateEventType::RoomCreate, "")? .map(|s| { - serde_json::from_str::(s.content.get()).map_err( - |e| { - error!("Invalid room create event in database: {}", e); - Error::BadDatabase("Invalid room create event in database.") - }, + serde_json::from_str::( + s.content.get(), ) + .map_err(|e| { + error!( + "Invalid room create event in database: {}", + e + ); + Error::BadDatabase( + "Invalid room create event in database.", + ) + }) }) .transpose()? .and_then(|e| e.room_type), @@ -306,10 +368,8 @@ pub(crate) async fn get_public_rooms_filtered_helper( }) .filter_map(Result::<_>::ok) .filter(|chunk| { - if let Some(query) = filter - .generic_search_term - .as_ref() - .map(|q| q.to_lowercase()) + if let Some(query) = + filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(name) = &chunk.name { if name.as_str().to_lowercase().contains(&query) { @@ -324,7 +384,8 @@ pub(crate) async fn get_public_rooms_filtered_helper( } if let Some(canonical_alias) = &chunk.canonical_alias { - if canonical_alias.as_str().to_lowercase().contains(&query) { + if canonical_alias.as_str().to_lowercase().contains(&query) + { return true; } } @@ -339,7 +400,8 @@ pub(crate) async fn get_public_rooms_filtered_helper( all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); - let total_room_count_estimate = all_rooms.len().try_into().unwrap_or(UInt::MAX); + let total_room_count_estimate = + all_rooms.len().try_into().unwrap_or(UInt::MAX); let chunk: Vec<_> = all_rooms .into_iter() @@ -353,11 +415,12 @@ pub(crate) async fn get_public_rooms_filtered_helper( Some(format!("p{num_since}")) }; - let next_batch = if chunk.len() < limit.try_into().expect("UInt should fit in usize") { - None - } else { - Some(format!("n{}", num_since + limit)) - }; + let next_batch = + if chunk.len() < limit.try_into().expect("UInt should fit in usize") { + None + } else { + Some(format!("n{}", num_since + limit)) + }; Ok(get_public_rooms_filtered::v3::Response { chunk, diff --git a/src/api/client_server/filter.rs b/src/api/client_server/filter.rs index 0ceefe40..fc2f2a1c 100644 --- a/src/api/client_server/filter.rs +++ b/src/api/client_server/filter.rs @@ -1,9 +1,10 @@ -use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ error::ErrorKind, filter::{create_filter, get_filter}, }; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// /// Loads a filter that was previously created. @@ -13,8 +14,13 @@ pub(crate) async fn get_filter_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services().users.get_filter(sender_user, &body.filter_id)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); + let Some(filter) = + services().users.get_filter(sender_user, &body.filter_id)? + else { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Filter not found.", + )); }; Ok(get_filter::v3::Response::new(filter)) diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 06c06dcb..4e68e81a 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,13 +1,16 @@ -use super::SESSION_ID_LENGTH; -use crate::{services, utils, Error, Result, Ruma}; +use std::{ + collections::{hash_map, BTreeMap, HashMap, HashSet}, + time::{Duration, Instant}, +}; + use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::{ error::ErrorKind, keys::{ - claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, - upload_signing_keys, + claim_keys, get_key_changes, get_keys, upload_keys, + upload_signatures, upload_signing_keys, }, uiaa::{AuthFlow, AuthType, UiaaInfo}, }, @@ -17,28 +20,32 @@ use ruma::{ DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; -use std::{ - collections::{hash_map, BTreeMap, HashMap, HashSet}, - time::{Duration, Instant}, -}; use tracing::debug; +use super::SESSION_ID_LENGTH; +use crate::{services, utils, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/keys/upload` /// /// Publish end-to-end encryption keys for the sender device. /// /// - Adds one time keys -/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) +/// - If there are no device keys yet: Adds device keys (TODO: merge with +/// existing keys?) pub(crate) async fn upload_keys_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); for (key_key, key_value) in &body.one_time_keys { - services() - .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; + services().users.add_one_time_key( + sender_user, + sender_device, + key_key, + key_value, + )?; } if let Some(device_keys) = &body.device_keys { @@ -49,9 +56,11 @@ pub(crate) async fn upload_keys_route( .get_device_keys(sender_user, sender_device)? .is_none() { - services() - .users - .add_device_keys(sender_user, sender_device, device_keys)?; + services().users.add_device_keys( + sender_user, + sender_device, + device_keys, + )?; } } @@ -68,14 +77,17 @@ pub(crate) async fn upload_keys_route( /// /// - Always fetches users from other servers over federation /// - Gets master keys, self-signing keys, user signing keys and device keys. -/// - The master and self-signing keys contain signatures that the user is allowed to see +/// - The master and self-signing keys contain signatures that the user is +/// allowed to see pub(crate) async fn get_keys_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let response = - get_keys_helper(Some(sender_user), &body.device_keys, |u| u == sender_user).await?; + let response = get_keys_helper(Some(sender_user), &body.device_keys, |u| { + u == sender_user + }) + .await?; Ok(response) } @@ -100,7 +112,8 @@ pub(crate) async fn upload_signing_keys_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); // UIAA let mut uiaainfo = UiaaInfo { @@ -114,19 +127,19 @@ pub(crate) async fn upload_signing_keys_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + let (worked, uiaainfo) = services().uiaa.try_auth( + sender_user, + sender_device, + auth, + &uiaainfo, + )?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -156,8 +169,9 @@ pub(crate) async fn upload_signatures_route( for (user_id, keys) in &body.signed_keys { for (key_id, key) in keys { - let key = serde_json::to_value(key) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; + let key = serde_json::to_value(key).map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON") + })?; for signature in key .get("signatures") @@ -189,9 +203,12 @@ pub(crate) async fn upload_signatures_route( ))? .to_owned(), ); - services() - .users - .sign_key(user_id, key_id, signature, sender_user)?; + services().users.sign_key( + user_id, + key_id, + signature, + sender_user, + )?; } } } @@ -204,7 +221,8 @@ pub(crate) async fn upload_signatures_route( /// # `POST /_matrix/client/r0/keys/changes` /// -/// Gets a list of users who have updated their device identity keys since the previous sync token. +/// Gets a list of users who have updated their device identity keys since the +/// previous sync token. /// /// - TODO: left users pub(crate) async fn get_key_changes_route( @@ -219,14 +237,15 @@ pub(crate) async fn get_key_changes_route( .users .keys_changed( sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), + body.from.parse().map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid `from`.", + ) + })?, + Some(body.to.parse().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.") + })?), ) .filter_map(Result::ok), ); @@ -243,10 +262,16 @@ pub(crate) async fn get_key_changes_route( .keys_changed( room_id.as_ref(), body.from.parse().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.") + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid `from`.", + ) })?, Some(body.to.parse().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.") + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid `to`.", + ) })?), ) .filter_map(Result::ok), @@ -287,16 +312,24 @@ pub(crate) async fn get_keys_helper bool>( let mut container = BTreeMap::new(); for device_id in services().users.all_device_ids(user_id) { let device_id = device_id?; - if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { + if let Some(mut keys) = + services().users.get_device_keys(user_id, &device_id)? + { let metadata = services() .users .get_device_metadata(user_id, &device_id)? .ok_or_else(|| { - Error::bad_database("all_device_keys contained nonexistent device.") + Error::bad_database( + "all_device_keys contained nonexistent device.", + ) })?; add_unsigned_device_display_name(&mut keys, metadata) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| { + Error::bad_database( + "invalid device keys in database", + ) + })?; container.insert(device_id, keys); } } @@ -304,7 +337,9 @@ pub(crate) async fn get_keys_helper bool>( } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { + if let Some(mut keys) = + services().users.get_device_keys(user_id, device_id)? + { let metadata = services() .users .get_device_metadata(user_id, device_id)? @@ -314,29 +349,35 @@ pub(crate) async fn get_keys_helper bool>( ))?; add_unsigned_device_display_name(&mut keys, metadata) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| { + Error::bad_database( + "invalid device keys in database", + ) + })?; container.insert(device_id.to_owned(), keys); } device_keys.insert(user_id.to_owned(), container); } } - if let Some(master_key) = - services() - .users - .get_master_key(sender_user, user_id, &allowed_signatures)? - { + if let Some(master_key) = services().users.get_master_key( + sender_user, + user_id, + &allowed_signatures, + )? { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = - services() - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? - { + if let Some(self_signing_key) = services().users.get_self_signing_key( + sender_user, + user_id, + &allowed_signatures, + )? { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { + if let Some(user_signing_key) = + services().users.get_user_signing_key(user_id)? + { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -345,17 +386,13 @@ pub(crate) async fn get_keys_helper bool>( let mut failures = BTreeMap::new(); let back_off = |id| async { - match services() - .globals - .bad_query_ratelimiter - .write() - .await - .entry(id) - { + match services().globals.bad_query_ratelimiter.write().await.entry(id) { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + } } }; @@ -370,7 +407,8 @@ pub(crate) async fn get_keys_helper bool>( .get(server) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } @@ -379,7 +417,9 @@ pub(crate) async fn get_keys_helper bool>( debug!("Backing off query from {:?}", server); return ( server, - Err(Error::BadServerResponse("bad query, still backing off")), + Err(Error::BadServerResponse( + "bad query, still backing off", + )), ); } } @@ -417,15 +457,19 @@ pub(crate) async fn get_keys_helper bool>( &user, &allowed_signatures, )? { - let (_, our_master_key) = - services().users.parse_master_key(&user, &our_master_key)?; + let (_, our_master_key) = services() + .users + .parse_master_key(&user, &our_master_key)?; master_key.signatures.extend(our_master_key.signatures); } - let json = serde_json::to_value(master_key).expect("to_value always works"); - let raw = serde_json::from_value(json).expect("Raw::from_value always works"); + let json = serde_json::to_value(master_key) + .expect("to_value always works"); + let raw = serde_json::from_value(json) + .expect("Raw::from_value always works"); services().users.add_cross_signing_keys( &user, &raw, &None, &None, - // Dont notify. A notification would trigger another key request resulting in an endless loop + // Dont notify. A notification would trigger another key + // request resulting in an endless loop false, )?; master_keys.insert(user, raw); @@ -454,11 +498,13 @@ fn add_unsigned_device_display_name( metadata: ruma::api::client::device::Device, ) -> serde_json::Result<()> { if let Some(display_name) = metadata.display_name { - let mut object = keys.deserialize_as::>()?; + let mut object = keys + .deserialize_as::>()?; let unsigned = object.entry("unsigned").or_insert_with(|| json!({})); if let serde_json::Value::Object(unsigned_object) = unsigned { - unsigned_object.insert("device_display_name".to_owned(), display_name.into()); + unsigned_object + .insert("device_display_name".to_owned(), display_name.into()); } *keys = Raw::from_json(serde_json::value::to_raw_value(&object)?); @@ -468,7 +514,10 @@ fn add_unsigned_device_display_name( } pub(crate) async fn claim_keys_helper( - one_time_keys_input: &BTreeMap>, + one_time_keys_input: &BTreeMap< + OwnedUserId, + BTreeMap, + >, ) -> Result { let mut one_time_keys = BTreeMap::new(); @@ -484,11 +533,11 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = - services() - .users - .take_one_time_key(user_id, device_id, key_algorithm)? - { + if let Some(one_time_keys) = services().users.take_one_time_key( + user_id, + device_id, + key_algorithm, + )? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); container.insert(device_id.clone(), c); diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 2a0f05e7..e4b90ed3 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,14 +1,15 @@ use std::time::Duration; -use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; use ruma::api::client::{ error::ErrorKind, media::{ - create_content, get_content, get_content_as_filename, get_content_thumbnail, - get_media_config, + create_content, get_content, get_content_as_filename, + get_content_thumbnail, get_media_config, }, }; +use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; + const MXC_LENGTH: usize = 32; /// # `GET /_matrix/media/r0/config` @@ -110,9 +111,12 @@ pub(crate) async fn get_content_route( content_disposition, cross_origin_resource_policy: Some("cross-origin".to_owned()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() + && body.allow_remote + { let remote_content_response = - get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; + get_remote_content(&mxc, &body.server_name, body.media_id.clone()) + .await?; Ok(remote_content_response) } else { Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) @@ -130,21 +134,32 @@ pub(crate) async fn get_content_as_filename_route( let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { - content_type, file, .. + content_type, + file, + .. }) = services().media.get(mxc.clone()).await? { Ok(get_content_as_filename::v3::Response { file, content_type, - content_disposition: Some(format!("inline; filename={}", body.filename)), + content_disposition: Some(format!( + "inline; filename={}", + body.filename + )), cross_origin_resource_policy: Some("cross-origin".to_owned()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() + && body.allow_remote + { let remote_content_response = - get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; + get_remote_content(&mxc, &body.server_name, body.media_id.clone()) + .await?; Ok(get_content_as_filename::v3::Response { - content_disposition: Some(format!("inline: filename={}", body.filename)), + content_disposition: Some(format!( + "inline: filename={}", + body.filename + )), content_type: remote_content_response.content_type, file: remote_content_response.file, cross_origin_resource_policy: Some("cross-origin".to_owned()), @@ -165,17 +180,19 @@ pub(crate) async fn get_content_thumbnail_route( let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { - content_type, file, .. + content_type, + file, + .. }) = services() .media .get_thumbnail( mxc.clone(), - body.width - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - body.height - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, + body.width.try_into().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.") + })?, + body.height.try_into().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.") + })?, ) .await? { @@ -184,7 +201,9 @@ pub(crate) async fn get_content_thumbnail_route( content_type, cross_origin_resource_policy: Some("cross-origin".to_owned()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() + && body.allow_remote + { let get_thumbnail_response = services() .sending .send_federation_request( diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index be3d59fa..fa044f27 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1,11 +1,18 @@ +use std::{ + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + sync::Arc, + time::{Duration, Instant}, +}; + use ruma::{ api::{ client::{ error::ErrorKind, membership::{ - ban_user, forget_room, get_member_events, invite_user, join_room_by_id, - join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, - unban_user, ThirdPartySigned, + ban_user, forget_room, get_member_events, invite_user, + join_room_by_id, join_room_by_id_or_alias, joined_members, + joined_rooms, kick_user, leave_room, unban_user, + ThirdPartySigned, }, }, federation::{self, membership::create_invite}, @@ -19,31 +26,27 @@ use ruma::{ StateEventType, TimelineEventType, }, serde::Base64, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, - OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, - sync::Arc, - time::{Duration, Instant}, -}; use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; +use super::get_alias_helper; use crate::{ service::pdu::{gen_event_id_canonical_json, PduBuilder}, services, utils, Error, PduEvent, Result, Ruma, }; -use super::get_alias_helper; - /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. /// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation +/// - If the server knowns about this room: creates the join event and does auth +/// rules locally +/// - If the server does not know about the room: asks other servers over +/// federation pub(crate) async fn join_room_by_id_route( body: Ruma, ) -> Result { @@ -86,15 +89,19 @@ pub(crate) async fn join_room_by_id_route( /// /// Tries to join the sender user into a room. /// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation +/// - If the server knowns about this room: creates the join event and does auth +/// rules locally +/// - If the server does not know about the room: asks other servers over +/// federation pub(crate) async fn join_room_by_id_or_alias_route( body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + let sender_user = + body.sender_user.as_deref().expect("user is authenticated"); let body = body.body; - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) + { Ok(room_id) => { let mut servers = body.server_name.clone(); servers.extend( @@ -104,8 +111,12 @@ pub(crate) async fn join_room_by_id_or_alias_route( .invite_state(sender_user, &room_id)? .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|event| { + serde_json::from_str(event.json().get()).ok() + }) + .filter_map(|event: serde_json::Value| { + event.get("sender").cloned() + }) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), @@ -164,7 +175,10 @@ pub(crate) async fn invite_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if let invite_user::v3::InvitationRecipient::UserId { user_id } = &body.recipient { + if let invite_user::v3::InvitationRecipient::UserId { + user_id, + } = &body.recipient + { invite_helper( sender_user, user_id, @@ -187,10 +201,8 @@ pub(crate) async fn kick_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if let Ok(true) = services() - .rooms - .state_cache - .is_left(sender_user, &body.room_id) + if let Ok(true) = + services().rooms.state_cache.is_left(sender_user, &body.room_id) { return Ok(kick_user::v3::Response {}); } @@ -233,7 +245,8 @@ pub(crate) async fn kick_user_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), + content: to_raw_value(&event) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(body.user_id.to_string()), redacts: None, @@ -257,10 +270,8 @@ pub(crate) async fn ban_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if let Ok(Some(membership_event)) = services() - .rooms - .state_accessor - .get_member(&body.room_id, sender_user) + if let Ok(Some(membership_event)) = + services().rooms.state_accessor.get_member(&body.room_id, sender_user) { if membership_event.membership == MembershipState::Ban { return Ok(ban_user::v3::Response {}); @@ -288,12 +299,16 @@ pub(crate) async fn ban_user_route( }), |event| { serde_json::from_str(event.content.get()) - .map(|event: RoomMemberEventContent| RoomMemberEventContent { - membership: MembershipState::Ban, - join_authorized_via_users_server: None, - ..event + .map(|event: RoomMemberEventContent| { + RoomMemberEventContent { + membership: MembershipState::Ban, + join_authorized_via_users_server: None, + ..event + } + }) + .map_err(|_| { + Error::bad_database("Invalid member event in database.") }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) }, )?; @@ -314,7 +329,8 @@ pub(crate) async fn ban_user_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), + content: to_raw_value(&event) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(body.user_id.to_string()), redacts: None, @@ -338,10 +354,8 @@ pub(crate) async fn unban_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if let Ok(Some(membership_event)) = services() - .rooms - .state_accessor - .get_member(&body.room_id, sender_user) + if let Ok(Some(membership_event)) = + services().rooms.state_accessor.get_member(&body.room_id, sender_user) { if membership_event.membership != MembershipState::Ban { return Ok(unban_user::v3::Response {}); @@ -386,7 +400,8 @@ pub(crate) async fn unban_user_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), + content: to_raw_value(&event) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(body.user_id.to_string()), redacts: None, @@ -406,19 +421,17 @@ pub(crate) async fn unban_user_route( /// /// Forgets about a room. /// -/// - If the sender user currently left the room: Stops sender user from receiving information about the room +/// - If the sender user currently left the room: Stops sender user from +/// receiving information about the room /// -/// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to -/// be called from every device +/// Note: Other devices of the user have no way of knowing the room was +/// forgotten, so this has to be called from every device pub(crate) async fn forget_room_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .rooms - .state_cache - .forget(&body.room_id, sender_user)?; + services().rooms.state_cache.forget(&body.room_id, sender_user)?; Ok(forget_room::v3::Response::new()) } @@ -443,7 +456,8 @@ pub(crate) async fn joined_rooms_route( /// # `POST /_matrix/client/r0/rooms/{roomId}/members` /// -/// Lists all joined users in a room (TODO: at a specific point in time, with a specific membership). +/// Lists all joined users in a room (TODO: at a specific point in time, with a +/// specific membership). /// /// - Only works if the user is currently joined pub(crate) async fn get_member_events_route( @@ -516,7 +530,9 @@ pub(crate) async fn joined_members_route( ); } - Ok(joined_members::v3::Response { joined }) + Ok(joined_members::v3::Response { + joined, + }) } #[allow(clippy::too_many_lines)] @@ -529,7 +545,9 @@ async fn join_room_by_id_helper( ) -> Result { let sender_user = sender_user.expect("user is authenticated"); - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { + if let Ok(true) = + services().rooms.state_cache.is_joined(sender_user, room_id) + { return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), }); @@ -560,19 +578,25 @@ async fn join_room_by_id_helper( "", )?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") + let join_rules_event_content: Option = + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()) + .map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database( + "Invalid join rules event in db.", + ) + }) }) - }) - .transpose()?; + .transpose()?; let restriction_rooms = match join_rules_event_content { Some(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), + join_rule: + JoinRule::Restricted(restricted) + | JoinRule::KnockRestricted(restricted), }) => restricted .allow .into_iter() @@ -584,37 +608,38 @@ async fn join_room_by_id_helper( _ => Vec::new(), }; - let authorized_user = if restriction_rooms.iter().any(|restriction_room_id| { - services() - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { - let mut auth_user = None; - for user in services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .collect::>() - { - if user.server_name() == services().globals.server_name() - && services().rooms.state_accessor.user_can_invite( - room_id, - &user, - sender_user, - &state_lock, - ) + let authorized_user = + if restriction_rooms.iter().any(|restriction_room_id| { + services() + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + .unwrap_or(false) + }) { + let mut auth_user = None; + for user in services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .collect::>() { - auth_user = Some(user); - break; + if user.server_name() == services().globals.server_name() + && services().rooms.state_accessor.user_can_invite( + room_id, + &user, + sender_user, + &state_lock, + ) + { + auth_user = Some(user); + break; + } } - } - auth_user - } else { - None - }; + auth_user + } else { + None + }; let event = RoomMemberEventContent { membership: MembershipState::Join, @@ -634,7 +659,8 @@ async fn join_room_by_id_helper( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), + content: to_raw_value(&event) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(sender_user.to_string()), redacts: None, @@ -645,21 +671,24 @@ async fn join_room_by_id_helper( ) .await { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Ok(_event_id) => { + return Ok(join_room_by_id::v3::Response::new( + room_id.to_owned(), + )) + } Err(e) => e, }; if restriction_rooms.is_empty() - && servers - .iter() - .any(|s| *s != services().globals.server_name()) + && servers.iter().any(|s| *s != services().globals.server_name()) { return Err(error); } info!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" - ); + "We couldn't do the join locally, maybe federation can help to \ + satisfy the restricted join requirements" + ); let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; @@ -672,24 +701,32 @@ async fn join_room_by_id_helper( { room_version_id } - _ => return Err(Error::BadServerResponse("Room version is not supported")), + _ => { + return Err(Error::BadServerResponse( + "Room version is not supported", + )) + } }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str( + make_join_response.event.get(), + ) + .map_err(|_| { + Error::BadServerResponse( + "Invalid make_join event json received from server.", + ) + })?; let join_authorized_via_users_server = join_event_stub .get("content") .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() + s.as_object()?.get("join_authorised_via_users_server")?.as_str() }) .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String( + services().globals.server_name().as_str().to_owned(), + ), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -714,10 +751,12 @@ async fn join_room_by_id_helper( .expect("event is valid, we just created it"), ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + // We don't leave the event id in the pdu because that's only allowed in + // v1 or v2 rooms join_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + // In order to create a compatible ref hash (EventID) the `hashes` field + // needs to be present ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), services().globals.keypair(), @@ -729,8 +768,11 @@ async fn join_room_by_id_helper( // Generate event id let event_id = format!( "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") + ruma::signatures::reference_hash( + &join_event_stub, + &room_version_id + ) + .expect("ruma can calculate reference hashes") ); let event_id = <&EventId>::try_from(event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); @@ -751,7 +793,9 @@ async fn join_room_by_id_helper( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + pdu: PduEvent::convert_to_outgoing_federation_event( + join_event.clone(), + ), omit_members: false, }, ) @@ -809,27 +853,35 @@ async fn join_room_by_id_helper( { room_version } - _ => return Err(Error::BadServerResponse("Room version is not supported")), + _ => { + return Err(Error::BadServerResponse( + "Room version is not supported", + )) + } }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str( + make_join_response.event.get(), + ) + .map_err(|_| { + Error::BadServerResponse( + "Invalid make_join event json received from server.", + ) + })?; let join_authorized_via_users_server = join_event_stub .get("content") .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() + s.as_object()?.get("join_authorised_via_users_server")?.as_str() }) .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String( + services().globals.server_name().as_str().to_owned(), + ), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -854,10 +906,12 @@ async fn join_room_by_id_helper( .expect("event is valid, we just created it"), ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + // We don't leave the event id in the pdu because that's only allowed in + // v1 or v2 rooms join_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + // In order to create a compatible ref hash (EventID) the `hashes` field + // needs to be present ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), services().globals.keypair(), @@ -869,8 +923,11 @@ async fn join_room_by_id_helper( // Generate event id let event_id = format!( "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") + ruma::signatures::reference_hash( + &join_event_stub, + &room_version_id + ) + .expect("ruma can calculate reference hashes") ); let event_id = <&EventId>::try_from(event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); @@ -892,7 +949,9 @@ async fn join_room_by_id_helper( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + pdu: PduEvent::convert_to_outgoing_federation_event( + join_event.clone(), + ), omit_members: false, }, ) @@ -901,7 +960,10 @@ async fn join_room_by_id_helper( info!("send_join finished"); if let Some(signed_raw) = &send_join_response.room_state.event { - info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); + info!( + "There is a signed event. This room is probably using \ + restricted joins. Adding signature to our event" + ); let Ok((signed_event_id, signed_value)) = gen_event_id_canonical_json(signed_raw, &room_version_id) else { @@ -941,7 +1003,8 @@ async fn join_room_by_id_helper( } Err(e) => { warn!( - "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", + "Server {remote_server} sent invalid signature in \ + sendjoin signatures for event {signed_value:?}: {e:?}", ); } } @@ -950,8 +1013,10 @@ async fn join_room_by_id_helper( services().rooms.short.get_or_create_shortroomid(room_id)?; info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + let parsed_join_pdu = + PduEvent::from_id_val(event_id, join_event.clone()).map_err( + |_| Error::BadServerResponse("Invalid join event PDU."), + )?; let mut state = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); @@ -960,58 +1025,61 @@ async fn join_room_by_id_helper( services() .rooms .event_handler - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .fetch_join_signing_keys( + &send_join_response, + &room_version_id, + &pub_key_map, + ) .await?; info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { + for result in send_join_response.room_state.state.iter().map(|pdu| { + validate_and_add_event_id(pdu, &room_version_id, &pub_key_map) + }) { let Ok((event_id, value)) = result.await else { continue; }; - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("Invalid PDU in send_join response: {} {:?}", e, value); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err( + |e| { + warn!( + "Invalid PDU in send_join response: {} {:?}", + e, value + ); + Error::BadServerResponse( + "Invalid PDU in send_join response.", + ) + }, + )?; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; + services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &pdu.kind.to_string().into(), + state_key, + )?; state.insert(shortstatekey, pdu.event_id.clone()); } } info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + for result in + send_join_response.room_state.auth_chain.iter().map(|pdu| { + validate_and_add_event_id(pdu, &room_version_id, &pub_key_map) + }) { let Ok((event_id, value)) = result.await else { continue; }; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; + services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; } info!("Running send_join auth check"); let authenticated = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &state_res::RoomVersion::new(&room_version_id) + .expect("room version is supported"), &parsed_join_pdu, // TODO: third party invite None::, @@ -1024,7 +1092,10 @@ async fn join_room_by_id_helper( &services() .rooms .short - .get_or_create_shortstatekey(&k.to_string().into(), s) + .get_or_create_shortstatekey( + &k.to_string().into(), + s, + ) .ok()?, )?, ) @@ -1044,33 +1115,42 @@ async fn join_room_by_id_helper( } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; + let (statehash_before_join, new, removed) = + services().rooms.state_compressor.save_state( + room_id, + Arc::new( + state + .into_iter() + .map(|(k, id)| { + services() + .rooms + .state_compressor + .compress_state_event(k, &id) + }) + .collect::>()?, + ), + )?; services() .rooms .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .force_state( + room_id, + statehash_before_join, + new, + removed, + &state_lock, + ) .await?; info!("Updating joined counts for new room"); services().rooms.state_cache.update_joined_count(room_id)?; - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + // We append to state before appending the pdu, so we don't have a + // moment in time with the pdu without it's state. This is okay + // because append_pdu can't fail. + let statehash_after_join = + services().rooms.state.append_to_state(&parsed_join_pdu)?; info!("Appending new room join event"); services() @@ -1085,12 +1165,14 @@ async fn join_room_by_id_helper( .await?; info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; + // We set the room state after inserting the pdu, so that we never have + // a moment in time where events in the current room state do + // not exist + services().rooms.state.set_room_state( + room_id, + statehash_after_join, + &state_lock, + )?; } Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) @@ -1125,7 +1207,8 @@ async fn make_join_request( ) .await; - make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); + make_join_response_and_server = + make_join_response.map(|r| (r, remote_server.clone())); if make_join_response_and_server.is_ok() { break; @@ -1140,10 +1223,11 @@ async fn validate_and_add_event_id( room_version: &RoomVersionId, pub_key_map: &RwLock>>, ) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()) + .map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; let event_id = EventId::parse(format!( "${}", ruma::signatures::reference_hash(&value, room_version) @@ -1152,41 +1236,39 @@ async fn validate_and_add_event_id( .expect("ruma's reference hashes are valid event ids"); let back_off = |id| async { - match services() - .globals - .bad_event_ratelimiter - .write() - .await - .entry(id) - { + match services().globals.bad_event_ratelimiter.write().await.entry(id) { Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + } } }; - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(&event_id) + if let Some((time, tries)) = + services().globals.bad_event_ratelimiter.read().await.get(&event_id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); + return Err(Error::BadServerResponse( + "bad event, still backing off", + )); } } - if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) - { + if let Err(e) = ruma::signatures::verify_event( + &*pub_key_map.read().await, + &value, + room_version, + ) { warn!("Event {} failed verification {:?} {}", event_id, pdu, e); back_off(event_id).await; return Err(Error::BadServerResponse("Event failed verification.")); @@ -1233,27 +1315,30 @@ pub(crate) async fn invite_helper( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = + services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + )?; - let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = + services().rooms.state.calculate_invite_state(&pdu)?; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = services().rooms.state.get_room_version(room_id)?; + let room_version_id = + services().rooms.state.get_room_version(room_id)?; let response = services() .sending @@ -1263,7 +1348,9 @@ pub(crate) async fn invite_helper( room_id: room_id.to_owned(), event_id: (*pdu.event_id).to_owned(), room_version: room_version_id.clone(), - event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), + event: PduEvent::convert_to_outgoing_federation_event( + pdu_json.clone(), + ), invite_room_state, }, ) @@ -1271,8 +1358,10 @@ pub(crate) async fn invite_helper( let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and hashes checks - let Ok((event_id, value)) = gen_event_id_canonical_json(&response.event, &room_version_id) + // We do not add the event_id field to the pdu here because of signature + // and hashes checks + let Ok((event_id, value)) = + gen_event_id_canonical_json(&response.event, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -1282,22 +1371,42 @@ pub(crate) async fn invite_helper( }; if *pdu.event_id != *event_id { - warn!("Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", user_id.server_name(), pdu_json, value); + warn!( + "Server {} changed invite event, that's not allowed in the \ + spec: ours: {:?}, theirs: {:?}", + user_id.server_name(), + pdu_json, + value + ); } let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event needs an origin field.", - ))?) + serde_json::to_value(value.get("origin").ok_or( + Error::BadRequest( + ErrorKind::InvalidParam, + "Event needs an origin field.", + ), + )?) .expect("CanonicalJson is valid json value"), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Origin field is invalid.", + ) + })?; let pdu_id: Vec = services() .rooms .event_handler - .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .handle_incoming_pdu( + &origin, + &event_id, + room_id, + value, + true, + &pub_key_map, + ) .await? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, @@ -1317,11 +1426,7 @@ pub(crate) async fn invite_helper( return Ok(()); } - if !services() - .rooms - .state_cache - .is_joined(sender_user, room_id)? - { + if !services().rooms.state_cache.is_joined(sender_user, room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -1387,7 +1492,9 @@ pub(crate) async fn leave_all_rooms(user_id: &UserId) -> Result<()> { .collect::>(); for room_id in all_rooms { - let Ok(room_id) = room_id else { continue }; + let Ok(room_id) = room_id else { + continue; + }; if let Err(error) = leave_room(user_id, &room_id, None).await { warn!(%user_id, %room_id, %error, "failed to leave room"); @@ -1443,8 +1550,10 @@ pub(crate) async fn leave_room( Some(e) => e, }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let mut event: RoomMemberEventContent = + serde_json::from_str(member_event.content.get()).map_err(|_| { + Error::bad_database("Invalid member event in database.") + })?; event.membership = MembershipState::Leave; event.reason = reason; @@ -1456,7 +1565,8 @@ pub(crate) async fn leave_room( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), + content: to_raw_value(&event) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(user_id.to_string()), redacts: None, @@ -1495,19 +1605,16 @@ pub(crate) async fn leave_room( Ok(()) } +#[allow(clippy::too_many_lines)] async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut make_leave_response_and_server = Err(Error::BadServerResponse( "No server available to assist in leaving.", )); - let invite_state = services() - .rooms - .state_cache - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "User is not invited.", - ))?; + let invite_state = + services().rooms.state_cache.invite_state(user_id, room_id)?.ok_or( + Error::BadRequest(ErrorKind::BadState, "User is not invited."), + )?; let servers: HashSet<_> = invite_state .iter() @@ -1530,7 +1637,8 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { ) .await; - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); + make_leave_response_and_server = + make_leave_response.map(|r| (r, remote_server)); if make_leave_response_and_server.is_ok() { break; @@ -1548,18 +1656,28 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { { version } - _ => return Err(Error::BadServerResponse("Room version is not supported")), + _ => { + return Err(Error::BadServerResponse( + "Room version is not supported", + )) + } }; let mut leave_event_stub = serde_json::from_str::( make_leave_response.event.get(), ) - .map_err(|_| Error::BadServerResponse("Invalid make_leave event json received from server."))?; + .map_err(|_| { + Error::BadServerResponse( + "Invalid make_leave event json received from server.", + ) + })?; // TODO: Is origin needed? leave_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String( + services().globals.server_name().as_str().to_owned(), + ), ); leave_event_stub.insert( "origin_server_ts".to_owned(), @@ -1569,10 +1687,12 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { .expect("Timestamp is valid js_int value"), ), ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + // We don't leave the event id in the pdu because that's only allowed in v1 + // or v2 rooms leave_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + // In order to create a compatible ref hash (EventID) the `hashes` field + // needs to be present ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), services().globals.keypair(), @@ -1605,7 +1725,9 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { federation::membership::create_leave_event::v2::Request { room_id: room_id.to_owned(), event_id, - pdu: PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + pdu: PduEvent::convert_to_outgoing_federation_event( + leave_event.clone(), + ), }, ) .await?; diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 2b920295..202b6c05 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,7 +1,8 @@ -use crate::{ - service::{pdu::PduBuilder, rooms::timeline::PduCount}, - services, utils, Error, Result, Ruma, +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, }; + use ruma::{ api::client::{ error::ErrorKind, @@ -10,18 +11,21 @@ use ruma::{ events::{StateEventType, TimelineEventType}, uint, }; -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, + +use crate::{ + service::{pdu::PduBuilder, rooms::timeline::PduCount}, + services, utils, Error, Result, Ruma, }; /// # `PUT /_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}` /// /// Send a message event into the room. /// -/// - Is a NOOP if the txn id was already used before and returns the same event id again +/// - Is a NOOP if the txn id was already used before and returns the same event +/// id again /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed pub(crate) async fn send_message_event_route( body: Ruma, ) -> Result { @@ -50,29 +54,37 @@ pub(crate) async fn send_message_event_route( } // Check if this is a new transaction id - if let Some(response) = - services() - .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? - { + if let Some(response) = services().transaction_ids.existing_txnid( + sender_user, + sender_device, + &body.txn_id, + )? { // The client might have sent a txnid of the /sendToDevice endpoint // This txnid has no response associated with it if response.is_empty() { return Err(Error::BadRequest( ErrorKind::InvalidParam, - "Tried to use txn id already used for an incompatible endpoint.", + "Tried to use txn id already used for an incompatible \ + endpoint.", )); } let event_id = utils::string_from_bytes(&response) - .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? + .map_err(|_| { + Error::bad_database("Invalid txnid bytes in database.") + })? .try_into() - .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; - return Ok(send_message_event::v3::Response { event_id }); + .map_err(|_| { + Error::bad_database("Invalid event id in txnid data.") + })?; + return Ok(send_message_event::v3::Response { + event_id, + }); } let mut unsigned = BTreeMap::new(); - unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + unsigned + .insert("transaction_id".to_owned(), body.txn_id.to_string().into()); let event_id = services() .rooms @@ -81,7 +93,12 @@ pub(crate) async fn send_message_event_route( PduBuilder { event_type: body.event_type.to_string().into(), content: serde_json::from_str(body.body.body.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, + .map_err(|_| { + Error::BadRequest( + ErrorKind::BadJson, + "Invalid JSON body.", + ) + })?, unsigned: Some(unsigned), state_key: None, redacts: None, @@ -101,23 +118,23 @@ pub(crate) async fn send_message_event_route( drop(state_lock); - Ok(send_message_event::v3::Response::new( - (*event_id).to_owned(), - )) + Ok(send_message_event::v3::Response::new((*event_id).to_owned())) } /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. /// -/// - Only works if the user is joined (TODO: always allow, but only show events where the user was +/// - Only works if the user is joined (TODO: always allow, but only show events +/// where the user was /// joined, depending on `history_visibility`) #[allow(clippy::too_many_lines)] pub(crate) async fn get_message_events_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); let from = match body.from.clone() { Some(from) => PduCount::try_from_string(&from)?, @@ -127,15 +144,17 @@ pub(crate) async fn get_message_events_route( }, }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); services() .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) + .lazy_load_confirm_delivery( + sender_user, + sender_device, + &body.room_id, + from, + ) .await?; let limit = body @@ -162,7 +181,11 @@ pub(crate) async fn get_message_events_route( services() .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) + .user_can_see_event( + sender_user, + &body.room_id, + &pdu.event_id, + ) .unwrap_or(false) }) .take_while(|&(k, _)| Some(k) != to) @@ -214,7 +237,11 @@ pub(crate) async fn get_message_events_route( services() .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) + .user_can_see_event( + sender_user, + &body.room_id, + &pdu.event_id, + ) .unwrap_or(false) }) .take_while(|&(k, _)| Some(k) != to) @@ -254,11 +281,13 @@ pub(crate) async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( - &body.room_id, - &StateEventType::RoomMember, - ll_id.as_str(), - )? { + if let Some(member_event) = + services().rooms.state_accessor.room_state_get( + &body.room_id, + &StateEventType::RoomMember, + ll_id.as_str(), + )? + { resp.state.push(member_event.to_state_event()); } } diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index f9a96904..4bcfc5af 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,20 +1,25 @@ -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; +use std::sync::Arc; + use ruma::{ api::{ client::{ error::ErrorKind, profile::{ - get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, + get_avatar_url, get_display_name, get_profile, set_avatar_url, + set_display_name, }, }, federation::{self, query::get_profile_information::v1::ProfileField}, }, - events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType}, + events::{ + room::member::RoomMemberEventContent, StateEventType, TimelineEventType, + }, }; use serde_json::value::to_raw_value; -use std::sync::Arc; use tracing::warn; +use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// /// Updates the displayname. @@ -25,9 +30,7 @@ pub(crate) async fn set_displayname_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .users - .set_displayname(sender_user, body.displayname.clone())?; + services().users.set_displayname(sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms let all_rooms_joined: Vec<_> = services() @@ -53,14 +56,18 @@ pub(crate) async fn set_displayname_route( )? .ok_or_else(|| { Error::bad_database( - "Tried to send displayname update for user not in the \ - room.", + "Tried to send displayname update for \ + user not in the room.", ) })? .content .get(), ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? + .map_err(|_| { + Error::bad_database( + "Database contains invalid PDU.", + ) + })? }) .expect("event is valid, we just created it"), unsigned: None, @@ -88,7 +95,12 @@ pub(crate) async fn set_displayname_route( if let Err(error) = services() .rooms .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + ) .await { warn!(%error, "failed to add PDU"); @@ -138,13 +150,9 @@ pub(crate) async fn set_avatar_url_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .users - .set_avatar_url(sender_user, body.avatar_url.clone())?; + services().users.set_avatar_url(sender_user, body.avatar_url.clone())?; - services() - .users - .set_blurhash(sender_user, body.blurhash.clone())?; + services().users.set_blurhash(sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms let all_joined_rooms: Vec<_> = services() @@ -170,14 +178,18 @@ pub(crate) async fn set_avatar_url_route( )? .ok_or_else(|| { Error::bad_database( - "Tried to send displayname update for user not in the \ - room.", + "Tried to send displayname update for \ + user not in the room.", ) })? .content .get(), ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? + .map_err(|_| { + Error::bad_database( + "Database contains invalid PDU.", + ) + })? }) .expect("event is valid, we just created it"), unsigned: None, @@ -205,7 +217,12 @@ pub(crate) async fn set_avatar_url_route( if let Err(error) = services() .rooms .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + ) .await { warn!(%error, "failed to add PDU"); @@ -219,7 +236,8 @@ pub(crate) async fn set_avatar_url_route( /// /// Returns the `avatar_url` and `blurhash` of the user. /// -/// - If user is on another server: Fetches `avatar_url` and `blurhash` over federation +/// - If user is on another server: Fetches `avatar_url` and `blurhash` over +/// federation pub(crate) async fn get_avatar_url_route( body: Ruma, ) -> Result { diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index 63a12715..9e101707 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -1,17 +1,18 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, push::{ - delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, - get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions, - set_pushrule_enabled, RuleScope, + delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, + get_pushrule_enabled, get_pushrules_all, set_pusher, set_pushrule, + set_pushrule_actions, set_pushrule_enabled, RuleScope, }, }, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, push::{AnyPushRuleRef, InsertPushRuleError, RemovePushRuleError}, }; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/pushrules` /// /// Retrieves the push rules event for this user. @@ -71,12 +72,11 @@ pub(crate) async fn get_pushrule_route( .map(Into::into); if let Some(rule) = rule { - Ok(get_pushrule::v3::Response { rule }) + Ok(get_pushrule::v3::Response { + rule, + }) } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - )) + Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")) } } @@ -109,7 +109,9 @@ pub(crate) async fn set_pushrule_route( ))?; let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })?; if let Err(error) = account_data.content.global.insert( body.rule.clone(), @@ -119,16 +121,20 @@ pub(crate) async fn set_pushrule_route( let err = match error { InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( ErrorKind::InvalidParam, - "Rule IDs starting with a dot are reserved for server-default rules.", + "Rule IDs starting with a dot are reserved for server-default \ + rules.", ), InsertPushRuleError::InvalidRuleId => Error::BadRequest( ErrorKind::InvalidParam, "Rule ID containing invalid characters.", ), - InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest( - ErrorKind::InvalidParam, - "Can't place a push rule relatively to a server-default rule.", - ), + InsertPushRuleError::RelativeToServerDefaultRule => { + Error::BadRequest( + ErrorKind::InvalidParam, + "Can't place a push rule relatively to a server-default \ + rule.", + ) + } InsertPushRuleError::UnknownRuleId => Error::BadRequest( ErrorKind::NotFound, "The before or after rule could not be found.", @@ -147,7 +153,8 @@ pub(crate) async fn set_pushrule_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; Ok(set_pushrule::v3::Response {}) @@ -193,7 +200,9 @@ pub(crate) async fn get_pushrule_actions_route( "Push rule not found.", ))?; - Ok(get_pushrule_actions::v3::Response { actions }) + Ok(get_pushrule_actions::v3::Response { + actions, + }) } /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` @@ -224,7 +233,9 @@ pub(crate) async fn set_pushrule_actions_route( ))?; let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })?; if account_data .content @@ -242,7 +253,8 @@ pub(crate) async fn set_pushrule_actions_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; Ok(set_pushrule_actions::v3::Response {}) @@ -276,7 +288,9 @@ pub(crate) async fn get_pushrule_enabled_route( ))?; let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })?; let global = account_data.content.global; let enabled = global @@ -287,7 +301,9 @@ pub(crate) async fn get_pushrule_enabled_route( "Push rule not found.", ))?; - Ok(get_pushrule_enabled::v3::Response { enabled }) + Ok(get_pushrule_enabled::v3::Response { + enabled, + }) } /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` @@ -318,7 +334,9 @@ pub(crate) async fn set_pushrule_enabled_route( ))?; let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })?; if account_data .content @@ -336,7 +354,8 @@ pub(crate) async fn set_pushrule_enabled_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; Ok(set_pushrule_enabled::v3::Response {}) @@ -370,12 +389,12 @@ pub(crate) async fn delete_pushrule_route( ))?; let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| { + Error::bad_database("Invalid account data event in db.") + })?; - if let Err(error) = account_data - .content - .global - .remove(body.kind.clone(), &body.rule_id) + if let Err(error) = + account_data.content.global.remove(body.kind.clone(), &body.rule_id) { let err = match error { RemovePushRuleError::ServerDefault => Error::BadRequest( @@ -395,7 +414,8 @@ pub(crate) async fn delete_pushrule_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; Ok(delete_pushrule::v3::Response {}) @@ -424,9 +444,7 @@ pub(crate) async fn set_pushers_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .pusher - .set_pusher(sender_user, body.action.clone())?; + services().pusher.set_pusher(sender_user, body.action.clone())?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index a1a7bbbf..5fdfc7de 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,20 +1,27 @@ -use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma}; +use std::collections::BTreeMap; + use ruma::{ - api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, + api::client::{ + error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt, + }, events::{ receipt::{ReceiptThread, ReceiptType}, RoomAccountDataEventType, }, MilliSecondsSinceUnixEpoch, }; -use std::collections::BTreeMap; + +use crate::{ + service::rooms::timeline::PduCount, services, Error, Result, Ruma, +}; /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// /// Sets different types of read markers. /// /// - Updates fully-read account data event to `fully_read` -/// - If `read_receipt` is set: Update private marker and public read receipt EDU +/// - If `read_receipt` is set: Update private marker and public read receipt +/// EDU pub(crate) async fn set_read_marker_route( body: Ruma, ) -> Result { @@ -30,7 +37,8 @@ pub(crate) async fn set_read_marker_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), + &serde_json::to_value(fully_read_event) + .expect("to json value always works"), )?; } @@ -42,14 +50,9 @@ pub(crate) async fn set_read_marker_route( } if let Some(event) = &body.private_read_receipt { - let count = services() - .rooms - .timeline - .get_pdu_count(event)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?; + let count = services().rooms.timeline.get_pdu_count(event)?.ok_or( + Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."), + )?; let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -59,11 +62,11 @@ pub(crate) async fn set_read_marker_route( } PduCount::Normal(c) => c, }; - services() - .rooms - .edus - .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + services().rooms.edus.read_receipt.private_read_set( + &body.room_id, + sender_user, + count, + )?; } if let Some(event) = &body.read_receipt { @@ -86,7 +89,9 @@ pub(crate) async fn set_read_marker_route( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), + content: ruma::events::receipt::ReceiptEventContent( + receipt_content, + ), room_id: body.room_id.clone(), }, )?; @@ -105,7 +110,8 @@ pub(crate) async fn create_receipt_route( if matches!( &body.receipt_type, - create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate + create_receipt::v3::ReceiptType::Read + | create_receipt::v3::ReceiptType::ReadPrivate ) { services() .rooms @@ -124,7 +130,8 @@ pub(crate) async fn create_receipt_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), + &serde_json::to_value(fully_read_event) + .expect("to json value always works"), )?; } create_receipt::v3::ReceiptType::Read => { @@ -146,7 +153,9 @@ pub(crate) async fn create_receipt_route( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), + content: ruma::events::receipt::ReceiptEventContent( + receipt_content, + ), room_id: body.room_id.clone(), }, )?; diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 0f774dd9..65683d81 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use crate::{service::pdu::PduBuilder, services, Result, Ruma}; use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, }; - use serde_json::value::to_raw_value; +use crate::{service::pdu::PduBuilder, services, Result, Ruma}; + /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// /// Tries to send a redaction event into the room. @@ -54,5 +54,7 @@ pub(crate) async fn redact_event_route( drop(state_lock); let event_id = (*event_id).to_owned(); - Ok(redact_event::v3::Response { event_id }) + Ok(redact_event::v3::Response { + event_id, + }) } diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index 353e4cc7..475adb9a 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -23,10 +23,7 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( }, }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = body @@ -36,27 +33,22 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( .try_into() .expect("0-100 should fit in usize"); - let res = services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - Some(&body.event_type), - Some(&body.rel_type), - from, - to, - limit, - )?; + let res = services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + Some(&body.event_type), + Some(&body.rel_type), + from, + to, + limit, + )?; - Ok( - get_relating_events_with_rel_type_and_event_type::v1::Response { - chunk: res.chunk, - next_batch: res.next_batch, - prev_batch: res.prev_batch, - }, - ) + Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { + chunk: res.chunk, + next_batch: res.next_batch, + prev_batch: res.prev_batch, + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` @@ -74,10 +66,7 @@ pub(crate) async fn get_relating_events_with_rel_type_route( }, }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = body @@ -87,19 +76,16 @@ pub(crate) async fn get_relating_events_with_rel_type_route( .try_into() .expect("0-100 should fit in usize"); - let res = services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - Some(&body.rel_type), - from, - to, - limit, - )?; + let res = services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + Some(&body.rel_type), + from, + to, + limit, + )?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -123,10 +109,7 @@ pub(crate) async fn get_relating_events_route( }, }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = body @@ -136,17 +119,14 @@ pub(crate) async fn get_relating_events_route( .try_into() .expect("0-100 should fit in usize"); - services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - None, - from, - to, - limit, - ) + services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + None, + from, + to, + limit, + ) } diff --git a/src/api/client_server/report.rs b/src/api/client_server/report.rs index 25e8f10a..7bcf3f25 100644 --- a/src/api/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -1,14 +1,14 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{error::ErrorKind, room::report_content}, events::room::message, int, }; +use crate::{services, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}` /// /// Reports an inappropriate event to homeserver admins -/// pub(crate) async fn report_event_route( body: Ruma, ) -> Result { diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index b34f04bc..b0a17e6a 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,6 +1,5 @@ -use crate::{ - api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma, -}; +use std::{cmp::max, collections::BTreeMap, sync::Arc}; + use ruma::{ api::client::{ error::ErrorKind, @@ -11,7 +10,9 @@ use ruma::{ canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + history_visibility::{ + HistoryVisibility, RoomHistoryVisibilityEventContent, + }, join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, name::RoomNameEventContent, @@ -26,9 +27,13 @@ use ruma::{ CanonicalJsonObject, OwnedRoomAliasId, RoomAliasId, RoomId, RoomVersionId, }; use serde_json::{json, value::to_raw_value}; -use std::{cmp::max, collections::BTreeMap, sync::Arc}; use tracing::{info, warn}; +use crate::{ + api::client_server::invite_helper, service::pdu::PduBuilder, services, + Error, Result, Ruma, +}; + /// # `POST /_matrix/client/r0/createRoom` /// /// Creates a new room. @@ -79,32 +84,27 @@ pub(crate) async fn create_room_route( } let alias: Option = - body.room_alias_name - .as_ref() - .map_or(Ok(None), |localpart| { - // TODO: Check for invalid characters and maximum length - let alias = RoomAliasId::parse(format!( - "#{}:{}", - localpart, - services().globals.server_name() - )) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - - if services() - .rooms - .alias - .resolve_local_alias(&alias)? - .is_some() - { - Err(Error::BadRequest( - ErrorKind::RoomInUse, - "Room alias already exists.", - )) - } else { - Ok(Some(alias)) - } + body.room_alias_name.as_ref().map_or(Ok(None), |localpart| { + // TODO: Check for invalid characters and maximum length + let alias = RoomAliasId::parse(format!( + "#{}:{}", + localpart, + services().globals.server_name() + )) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") })?; + if services().rooms.alias.resolve_local_alias(&alias)?.is_some() { + Err(Error::BadRequest( + ErrorKind::RoomInUse, + "Room alias already exists.", + )) + } else { + Ok(Some(alias)) + } + })?; + if let Some(alias) = &alias { if let Some(info) = &body.appservice_info { if !info.aliases.is_match(alias.as_str()) { @@ -159,7 +159,10 @@ pub(crate) async fn create_room_route( content.insert( "creator".into(), json!(&sender_user).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") + Error::BadRequest( + ErrorKind::BadJson, + "Invalid creation content", + ) })?, ); } @@ -171,7 +174,10 @@ pub(crate) async fn create_room_route( content.insert( "room_version".into(), json!(room_version.as_str()).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") + Error::BadRequest( + ErrorKind::BadJson, + "Invalid creation content", + ) })?, ); content @@ -187,20 +193,30 @@ pub(crate) async fn create_room_route( | RoomVersionId::V7 | RoomVersionId::V8 | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), + | RoomVersionId::V10 => { + RoomCreateEventContent::new_v1(sender_user.clone()) + } RoomVersionId::V11 => RoomCreateEventContent::new_v11(), _ => unreachable!("Validity of room version already checked"), }; let mut content = serde_json::from_str::( to_raw_value(&content) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))? + .map_err(|_| { + Error::BadRequest( + ErrorKind::BadJson, + "Invalid creation content", + ) + })? .get(), ) .unwrap(); content.insert( "room_version".into(), json!(room_version.as_str()).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") + Error::BadRequest( + ErrorKind::BadJson, + "Invalid creation content", + ) })?, ); content @@ -209,9 +225,7 @@ pub(crate) async fn create_room_route( // Validate creation content let de_result = serde_json::from_str::( - to_raw_value(&content) - .expect("Invalid creation content") - .get(), + to_raw_value(&content).expect("Invalid creation content").get(), ); if de_result.is_err() { @@ -228,7 +242,8 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), + content: to_raw_value(&content) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), redacts: None, @@ -285,17 +300,24 @@ pub(crate) async fn create_room_route( } } - let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"); + let mut power_levels_content = + serde_json::to_value(RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"); - if let Some(power_level_content_override) = &body.power_level_content_override { - let json: JsonObject = serde_json::from_str(power_level_content_override.json().get()) - .map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.") - })?; + if let Some(power_level_content_override) = + &body.power_level_content_override + { + let json: JsonObject = + serde_json::from_str(power_level_content_override.json().get()) + .map_err(|_| { + Error::BadRequest( + ErrorKind::BadJson, + "Invalid power_level_content_override.", + ) + })?; for (key, value) in json { power_levels_content[key] = value; @@ -353,11 +375,13 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { - RoomPreset::PublicChat => JoinRule::Public, - // according to spec "invite" is the default - _ => JoinRule::Invite, - })) + content: to_raw_value(&RoomJoinRulesEventContent::new( + match preset { + RoomPreset::PublicChat => JoinRule::Public, + // according to spec "invite" is the default + _ => JoinRule::Invite, + }, + )) .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), @@ -397,10 +421,12 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { - RoomPreset::PublicChat => GuestAccess::Forbidden, - _ => GuestAccess::CanJoin, - })) + content: to_raw_value(&RoomGuestAccessEventContent::new( + match preset { + RoomPreset::PublicChat => GuestAccess::Forbidden, + _ => GuestAccess::CanJoin, + }, + )) .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), @@ -414,10 +440,14 @@ pub(crate) async fn create_room_route( // 6. Events listed in initial_state for event in &body.initial_state { - let mut pdu_builder = event.deserialize_as::().map_err(|e| { - warn!("Invalid initial state event: {:?}", e); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.") - })?; + let mut pdu_builder = + event.deserialize_as::().map_err(|e| { + warn!("Invalid initial state event: {:?}", e); + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid initial state event.", + ) + })?; // Implicit state key defaults to "" pdu_builder.state_key.get_or_insert_with(String::new); @@ -432,7 +462,12 @@ pub(crate) async fn create_room_route( services() .rooms .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + ) .await?; } @@ -444,8 +479,10 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(name.clone())) - .expect("event is valid, we just created it"), + content: to_raw_value(&RoomNameEventContent::new( + name.clone(), + )) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), redacts: None, @@ -483,7 +520,8 @@ pub(crate) async fn create_room_route( drop(state_lock); for user_id in &body.invite { if let Err(error) = - invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await + invite_helper(sender_user, user_id, &room_id, None, body.is_direct) + .await { warn!(%error, "invite helper failed"); }; @@ -507,20 +545,19 @@ pub(crate) async fn create_room_route( /// /// Gets a single event. /// -/// - You have to currently be joined to the room (TODO: Respect history visibility) +/// - You have to currently be joined to the room (TODO: Respect history +/// visibility) pub(crate) async fn get_room_event_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() - .rooms - .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| { + let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else( + || { warn!("Event not found, event ID: {:?}", &body.event_id); Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + }, + )?; if !services().rooms.state_accessor.user_can_see_event( sender_user, @@ -545,17 +582,14 @@ pub(crate) async fn get_room_event_route( /// /// Lists all aliases of the room. /// -/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if `history_visibility` is world readable +/// - Only users joined to the room are allowed to call this TODO: Allow any +/// user to call it if `history_visibility` is world readable pub(crate) async fn get_room_aliases_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_cache - .is_joined(sender_user, &body.room_id)? - { + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -588,10 +622,7 @@ pub(crate) async fn upgrade_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .globals - .supported_room_versions() - .contains(&body.new_version) + if !services().globals.supported_room_versions().contains(&body.new_version) { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, @@ -601,10 +632,7 @@ pub(crate) async fn upgrade_room_route( // Create a replacement room let replacement_room = RoomId::new(services().globals.server_name()); - services() - .rooms - .short - .get_or_create_shortroomid(&replacement_room)?; + services().rooms.short.get_or_create_shortroomid(&replacement_room)?; let mutex_state = Arc::clone( services() @@ -617,8 +645,9 @@ pub(crate) async fn upgrade_room_route( ); let state_lock = mutex_state.lock().await; - // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further - // Fail if the sender does not have the required permissions + // Send a m.room.tombstone event to the old room to indicate that it is not + // intended to be used any further Fail if the sender does not have the + // required permissions let tombstone_event_id = services() .rooms .timeline @@ -659,7 +688,9 @@ pub(crate) async fn upgrade_room_route( .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .ok_or_else(|| { + Error::bad_database("Found room without m.room.create event.") + })? .content .get(), ) @@ -671,7 +702,8 @@ pub(crate) async fn upgrade_room_route( (*tombstone_event_id).to_owned(), )); - // Send a m.room.create event containing a predecessor field and the applicable room_version + // Send a m.room.create event containing a predecessor field and the + // applicable room_version match body.new_version { RoomVersionId::V1 | RoomVersionId::V2 @@ -686,7 +718,10 @@ pub(crate) async fn upgrade_room_route( create_event_content.insert( "creator".into(), json!(&sender_user).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + Error::BadRequest( + ErrorKind::BadJson, + "Error forming creation event", + ) })?, ); } @@ -698,15 +733,21 @@ pub(crate) async fn upgrade_room_route( } create_event_content.insert( "room_version".into(), - json!(&body.new_version) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + json!(&body.new_version).try_into().map_err(|_| { + Error::BadRequest( + ErrorKind::BadJson, + "Error forming creation event", + ) + })?, ); create_event_content.insert( "predecessor".into(), - json!(predecessor) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + json!(predecessor).try_into().map_err(|_| { + Error::BadRequest( + ErrorKind::BadJson, + "Error forming creation event", + ) + })?, ); // Validate creation event content @@ -784,16 +825,15 @@ pub(crate) async fn upgrade_room_route( // Replicate transferable state events to the new room for event_type in transferable_state_events { - let event_content = - match services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &event_type, "")? - { - Some(v) => v.content.clone(), - // Skipping missing events. - None => continue, - }; + let event_content = match services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &event_type, "")? + { + Some(v) => v.content.clone(), + // Skipping missing events. + None => continue, + }; services() .rooms @@ -820,30 +860,39 @@ pub(crate) async fn upgrade_room_route( .local_aliases_for_room(&body.room_id) .filter_map(Result::ok) { - services() - .rooms - .alias - .set_alias(&alias, &replacement_room)?; + services().rooms.alias.set_alias(&alias, &replacement_room)?; } // Get the old room power levels - let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + let mut power_levels_event_content: RoomPowerLevelsEventContent = + serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get( + &body.room_id, + &StateEventType::RoomPowerLevels, + "", + )? + .ok_or_else(|| { + Error::bad_database( + "Found room without m.room.create event.", + ) + })? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid room event in database."))?; - // Setting events_default and invite to the greater of 50 and users_default + 1 - let new_level = max(int!(50), power_levels_event_content.users_default + int!(1)); + // Setting events_default and invite to the greater of 50 and users_default + // + 1 + let new_level = + max(int!(50), power_levels_event_content.users_default + int!(1)); power_levels_event_content.events_default = new_level; power_levels_event_content.invite = new_level; - // Modify the power levels in the old room to prevent sending of events and inviting new users + // Modify the power levels in the old room to prevent sending of events and + // inviting new users let _ = services() .rooms .timeline @@ -865,5 +914,7 @@ pub(crate) async fn upgrade_room_route( drop(state_lock); // Return the replacement room id - Ok(upgrade_room::v3::Response { replacement_room }) + Ok(upgrade_room::v3::Response { + replacement_room, + }) } diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index 2b6c9f12..340eb839 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -1,22 +1,27 @@ -use crate::{services, Error, Result, Ruma}; +use std::collections::BTreeMap; + use ruma::{ api::client::{ error::ErrorKind, search::search_events::{ self, - v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, + v3::{ + EventContextResult, ResultCategories, ResultRoomEvents, + SearchResult, + }, }, }, uint, }; -use std::collections::BTreeMap; +use crate::{services, Error, Result, Ruma}; /// # `POST /_matrix/client/r0/search` /// /// Searches rooms for messages. /// -/// - Only works if the user is currently joined to the room (TODO: Respect history visibility) +/// - Only works if the user is currently joined to the room (TODO: Respect +/// history visibility) #[allow(clippy::too_many_lines)] pub(crate) async fn search_events_route( body: Ruma, @@ -46,11 +51,7 @@ pub(crate) async fn search_events_route( let mut searches = Vec::new(); for room_id in room_ids { - if !services() - .rooms - .state_cache - .is_joined(sender_user, &room_id)? - { + if !services().rooms.state_cache.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -102,7 +103,11 @@ pub(crate) async fn search_events_route( services() .rooms .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .user_can_see_event( + sender_user, + &pdu.room_id, + &pdu.event_id, + ) .unwrap_or(false) }) .map(|pdu| pdu.to_room_event()) diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 1fb6ee43..3894a05f 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,3 @@ -use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -17,6 +15,9 @@ use ruma::{ use serde::Deserialize; use tracing::{info, warn}; +use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; +use crate::{services, utils, Error, Result, Ruma}; + #[derive(Debug, Deserialize)] struct Claims { sub: String, @@ -24,30 +25,36 @@ struct Claims { /// # `GET /_matrix/client/r0/login` /// -/// Get the supported login types of this server. One of these should be used as the `type` field -/// when logging in. +/// Get the supported login types of this server. One of these should be used as +/// the `type` field when logging in. pub(crate) async fn get_login_types_route( _body: Ruma, ) -> Result { Ok(get_login_types::v3::Response::new(vec![ get_login_types::v3::LoginType::Password(PasswordLoginType::default()), - get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()), + get_login_types::v3::LoginType::ApplicationService( + ApplicationServiceLoginType::default(), + ), ])) } /// # `POST /_matrix/client/r0/login` /// -/// Authenticates the user and returns an access token it can use in subsequent requests. +/// Authenticates the user and returns an access token it can use in subsequent +/// requests. /// -/// - The user needs to authenticate using their password (or if enabled using a json web token) +/// - The user needs to authenticate using their password (or if enabled using a +/// json web token) /// - If `device_id` is known: invalidates old access token of that device /// - If `device_id` is unknown: creates a new device /// - Returns access token that is associated with the user and device /// -/// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to see -/// supported login types. +/// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to +/// see supported login types. #[allow(clippy::too_many_lines)] -pub(crate) async fn login_route(body: Ruma) -> Result { +pub(crate) async fn login_route( + body: Ruma, +) -> Result { // To allow deprecated login methods #![allow(deprecated)] // Validate login method @@ -59,18 +66,29 @@ pub(crate) async fn login_route(body: Ruma) -> Result { - let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - UserId::parse_with_server_name( - user_id.to_lowercase(), - services().globals.server_name(), - ) - } else if let Some(user) = user { - UserId::parse(user) - } else { - warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); - } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; + let user_id = + if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = + identifier + { + UserId::parse_with_server_name( + user_id.to_lowercase(), + services().globals.server_name(), + ) + } else if let Some(user) = user { + UserId::parse(user) + } else { + warn!("Bad login type: {:?}", &body.login_info); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Bad login type.", + )); + } + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; if services().appservice.is_exclusive_user_id(&user_id).await { return Err(Error::BadRequest( @@ -79,13 +97,12 @@ pub(crate) async fn login_route(body: Ruma) -> Result) -> Result) -> Result { - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { + login::v3::LoginInfo::Token(login::v3::Token { + token, + }) => { + if let Some(jwt_decoding_key) = + services().globals.jwt_decoding_key() + { let token = jsonwebtoken::decode::( token, jwt_decoding_key, &jsonwebtoken::Validation::default(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Token is invalid.", + ) + })?; let username = token.claims.sub.to_lowercase(); - let user_id = - UserId::parse_with_server_name(username, services().globals.server_name()) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })?; + let user_id = UserId::parse_with_server_name( + username, + services().globals.server_name(), + ) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; if services().appservice.is_exclusive_user_id(&user_id).await { return Err(Error::BadRequest( @@ -131,26 +164,40 @@ pub(crate) async fn login_route(body: Ruma) -> Result { - let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - UserId::parse_with_server_name( - user_id.to_lowercase(), - services().globals.server_name(), - ) - } else if let Some(user) = user { - UserId::parse(user) - } else { - warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); - } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; + login::v3::LoginInfo::ApplicationService( + login::v3::ApplicationService { + identifier, + user, + }, + ) => { + let user_id = + if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = + identifier + { + UserId::parse_with_server_name( + user_id.to_lowercase(), + services().globals.server_name(), + ) + } else if let Some(user) = user { + UserId::parse(user) + } else { + warn!("Bad login type: {:?}", &body.login_info); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Bad login type.", + )); + } + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; if let Some(info) = &body.appservice_info { if !info.is_user_match(&user_id) { @@ -225,12 +272,16 @@ pub(crate) async fn login_route(body: Ruma) -> Result) -> Result { +pub(crate) async fn logout_route( + body: Ruma, +) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_device = + body.sender_device.as_ref().expect("user is authenticated"); if let Some(info) = &body.appservice_info { if !info.is_user_match(sender_user) { @@ -251,12 +302,13 @@ pub(crate) async fn logout_route(body: Ruma) -> Result, ) -> Result { diff --git a/src/api/client_server/space.rs b/src/api/client_server/space.rs index f01d888c..a9b1af35 100644 --- a/src/api/client_server/space.rs +++ b/src/api/client_server/space.rs @@ -1,19 +1,18 @@ -use crate::{services, Result, Ruma}; use ruma::{api::client::space::get_hierarchy, uint}; +use crate::{services, Result, Ruma}; + /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// -/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space. +/// Paginates over the space tree in a depth-first manner to locate child rooms +/// of a given space. pub(crate) async fn get_hierarchy_route( body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let skip = body - .from - .as_ref() - .and_then(|s| s.parse::().ok()) - .unwrap_or(0); + let skip = + body.from.as_ref().and_then(|s| s.parse::().ok()).unwrap_or(0); let limit = body .limit @@ -23,8 +22,10 @@ pub(crate) async fn get_hierarchy_route( .expect("0-100 should fit in usize"); // Plus one to skip the space room itself - let max_depth = usize::try_from(body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3))) - .expect("0-10 should fit in usize") + let max_depth = usize::try_from( + body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3)), + ) + .expect("0-10 should fit in usize") + 1; services() diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 518ae621..ace26880 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -1,25 +1,30 @@ use std::sync::Arc; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse}; use ruma::{ api::client::{ error::ErrorKind, state::{get_state_events, get_state_events_for_key, send_state_event}, }, events::{ - room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType, + room::canonical_alias::RoomCanonicalAliasEventContent, + AnyStateEventContent, StateEventType, }, serde::Raw, EventId, RoomId, UserId, }; use tracing::log::warn; +use crate::{ + service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse, +}; + /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}` /// /// Sends a state event into the room. /// /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed /// - If event is new `canonical_alias`: Rejects if alias is incorrect pub(crate) async fn send_state_event_for_key_route( body: Ruma, @@ -37,7 +42,9 @@ pub(crate) async fn send_state_event_for_key_route( .await?; let event_id = (*event_id).to_owned(); - Ok(send_state_event::v3::Response { event_id }) + Ok(send_state_event::v3::Response { + event_id, + }) } /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}` @@ -45,7 +52,8 @@ pub(crate) async fn send_state_event_for_key_route( /// Sends a state event into the room. /// /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed /// - If event is new `canonical_alias`: Rejects if alias is incorrect pub(crate) async fn send_state_event_for_empty_key_route( body: Ruma, @@ -53,7 +61,9 @@ pub(crate) async fn send_state_event_for_empty_key_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Forbid m.room.encryption if encryption is disabled - if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { + if body.event_type == StateEventType::RoomEncryption + && !services().globals.allow_encryption() + { return Err(Error::BadRequest( ErrorKind::Forbidden, "Encryption has been disabled", @@ -70,14 +80,18 @@ pub(crate) async fn send_state_event_for_empty_key_route( .await?; let event_id = (*event_id).to_owned(); - Ok(send_state_event::v3::Response { event_id }.into()) + Ok(send_state_event::v3::Response { + event_id, + } + .into()) } /// # `GET /_matrix/client/r0/rooms/{roomid}/state` /// /// Get all state events for a room. /// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub(crate) async fn get_state_events_route( body: Ruma, ) -> Result { @@ -110,7 +124,8 @@ pub(crate) async fn get_state_events_route( /// /// Get single state event of a room. /// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub(crate) async fn get_state_events_for_key_route( body: Ruma, ) -> Result { @@ -140,8 +155,9 @@ pub(crate) async fn get_state_events_for_key_route( })?; Ok(get_state_events_for_key::v3::Response { - content: serde_json::from_str(event.content.get()) - .map_err(|_| Error::bad_database("Invalid event content in database"))?, + content: serde_json::from_str(event.content.get()).map_err(|_| { + Error::bad_database("Invalid event content in database") + })?, }) } @@ -149,7 +165,8 @@ pub(crate) async fn get_state_events_for_key_route( /// /// Get single state event of a room. /// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub(crate) async fn get_state_events_for_empty_key_route( body: Ruma, ) -> Result> { @@ -179,8 +196,9 @@ pub(crate) async fn get_state_events_for_empty_key_route( })?; Ok(get_state_events_for_key::v3::Response { - content: serde_json::from_str(event.content.get()) - .map_err(|_| Error::bad_database("Invalid event content in database"))?, + content: serde_json::from_str(event.content.get()).map_err(|_| { + Error::bad_database("Invalid event content in database") + })?, } .into()) } @@ -194,10 +212,11 @@ async fn send_state_event_for_key_helper( ) -> Result> { let sender_user = sender; - // TODO: Review this check, error if event is unparsable, use event type, allow alias if it - // previously existed - if let Ok(canonical_alias) = - serde_json::from_str::(json.json().get()) + // TODO: Review this check, error if event is unparsable, use event type, + // allow alias if it previously existed + if let Ok(canonical_alias) = serde_json::from_str::< + RoomCanonicalAliasEventContent, + >(json.json().get()) { let mut aliases = canonical_alias.alt_aliases.clone(); @@ -216,8 +235,8 @@ async fn send_state_event_for_key_helper( { return Err(Error::BadRequest( ErrorKind::Forbidden, - "You are only allowed to send canonical_alias \ - events when it's aliases already exists", + "You are only allowed to send canonical_alias events when \ + it's aliases already exists", )); } } @@ -240,7 +259,8 @@ async fn send_state_event_for_key_helper( .build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), - content: serde_json::from_str(json.json().get()).expect("content is valid json"), + content: serde_json::from_str(json.json().get()) + .expect("content is valid json"), unsigned: None, state_key: Some(state_key), redacts: None, diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index cc2d56d5..71ec8497 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,6 +1,7 @@ -use crate::{ - service::{pdu::EventHash, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, Ruma, RumaResponse, +use std::{ + collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, + sync::Arc, + time::Duration, }; use ruma::{ @@ -9,8 +10,9 @@ use ruma::{ sync::sync_events::{ self, v3::{ - Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, - LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, + Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, + JoinedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, + Rooms, State, Timeline, ToDevice, }, v4::SlidingOp, DeviceLists, UnreadNotificationsCount, @@ -21,21 +23,23 @@ use ruma::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, TimelineEventType, }, - uint, DeviceId, EventId, JsOption, OwnedDeviceId, OwnedUserId, RoomId, UInt, UserId, -}; -use std::{ - collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, - sync::Arc, - time::Duration, + uint, DeviceId, EventId, JsOption, OwnedDeviceId, OwnedUserId, RoomId, + UInt, UserId, }; use tokio::sync::watch::Sender; use tracing::{debug, error, info}; +use crate::{ + service::{pdu::EventHash, rooms::timeline::PduCount}, + services, utils, Error, PduEvent, Result, Ruma, RumaResponse, +}; + /// # `GET /_matrix/client/r0/sync` /// /// Synchronize the client's state with the latest state on the server. /// -/// - This endpoint takes a `since` parameter which should be the `next_batch` value from a +/// - This endpoint takes a `since` parameter which should be the `next_batch` +/// value from a /// previous request for incremental syncs. /// /// Calling this endpoint without a `since` parameter returns: @@ -44,26 +48,34 @@ use tracing::{debug, error, info}; /// - Joined and invited member counts, heroes /// - All state events /// -/// Calling this endpoint with a `since` parameter from a previous `next_batch` returns: -/// For joined rooms: -/// - Some of the most recent events of each timeline that happened after `since` -/// - If user joined the room after `since`: All state events (unless lazy loading is activated) and +/// Calling this endpoint with a `since` parameter from a previous `next_batch` +/// returns: For joined rooms: +/// - Some of the most recent events of each timeline that happened after +/// `since` +/// - If user joined the room after `since`: All state events (unless lazy +/// loading is activated) and /// all device list updates in that room -/// - If the user was already in the room: A list of all events that are in the state now, but were +/// - If the user was already in the room: A list of all events that are in the +/// state now, but were /// not in the state at `since` -/// - If the state we send contains a member event: Joined and invited member counts, heroes +/// - If the state we send contains a member event: Joined and invited member +/// counts, heroes /// - Device list updates that happened after `since` -/// - If there are events in the timeline we send or the user send updated their read mark: Notification counts +/// - If there are events in the timeline we send or the user send updated their +/// read mark: Notification counts /// - EDUs that are active now (read receipts, typing updates, presence) /// - TODO: Allow multiple sync streams to support Pantalaimon /// /// For invited rooms: -/// - If the user was invited after `since`: A subset of the state of the room at the point of the invite +/// - If the user was invited after `since`: A subset of the state of the room +/// at the point of the invite /// /// For left rooms: -/// - If the user left after `since`: `prev_batch` token, empty state (TODO: subset of the state at the point of the leave) +/// - If the user left after `since`: `prev_batch` token, empty state (TODO: +/// subset of the state at the point of the leave) /// -/// - Sync is handled in an async task, multiple requests from the same device with the same +/// - Sync is handled in an async task, multiple requests from the same device +/// with the same /// `since` will be cached pub(crate) async fn sync_events_route( body: Ruma, @@ -154,7 +166,8 @@ async fn sync_helper_wrapper( .entry((sender_user, sender_device)) { Entry::Occupied(o) => { - // Only remove if the device didn't start a different /sync already + // Only remove if the device didn't start a different /sync + // already if o.get().0 == since { o.remove(); } @@ -164,8 +177,7 @@ async fn sync_helper_wrapper( } } - tx.send(Some(r.map(|(r, _)| r))) - .expect("receiver should not be dropped"); + tx.send(Some(r.map(|(r, _)| r))).expect("receiver should not be dropped"); } #[allow(clippy::too_many_lines)] @@ -192,21 +204,19 @@ async fn sync_helper( .unwrap_or_default(), }; - let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members: redundant, - } => (true, redundant), - LazyLoadOptions::Disabled => (false, false), - }; + let (lazy_load_enabled, lazy_load_send_redundant) = + match filter.room.state.lazy_load_options { + LazyLoadOptions::Enabled { + include_redundant_members: redundant, + } => (true, redundant), + LazyLoadOptions::Disabled => (false, false), + }; let full_state = body.full_state; let mut joined_rooms = BTreeMap::new(); - let since = body - .since - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); + let since = + body.since.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); let sincecount = PduCount::Normal(since); // Users that have left any encrypted rooms the sender was in @@ -252,11 +262,8 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_left(&sender_user) - .collect(); + let all_left_rooms: Vec<_> = + services().rooms.state_cache.rooms_left(&sender_user).collect(); for result in all_left_rooms { let (room_id, _) = result?; @@ -294,7 +301,8 @@ async fn sync_helper( .try_into() .expect("Timestamp is valid js_int value"), kind: TimelineEventType::RoomMember, - content: serde_json::from_str(r#"{ "membership": "leave"}"#).unwrap(), + content: serde_json::from_str(r#"{ "membership": "leave"}"#) + .unwrap(), state_key: Some(sender_user.to_string()), unsigned: None, // The following keys are dropped on conversion @@ -312,7 +320,9 @@ async fn sync_helper( left_rooms.insert( room_id, LeftRoom { - account_data: RoomAccountData { events: Vec::new() }, + account_data: RoomAccountData { + events: Vec::new(), + }, timeline: Timeline { limited: false, prev_batch: Some(next_batch_string.clone()), @@ -329,21 +339,22 @@ async fn sync_helper( let mut left_state_events = Vec::new(); - let since_shortstatehash = services() - .rooms - .user - .get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = + services().rooms.user.get_token_shortstatehash(&room_id, since)?; let since_state_ids = match since_shortstatehash { - Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, + Some(s) => { + services().rooms.state_accessor.state_full_ids(s).await? + } None => HashMap::new(), }; - let Some(left_event_id) = services().rooms.state_accessor.room_state_get_id( - &room_id, - &StateEventType::RoomMember, - sender_user.as_str(), - )? + let Some(left_event_id) = + services().rooms.state_accessor.room_state_get_id( + &room_id, + &StateEventType::RoomMember, + sender_user.as_str(), + )? else { error!("Left room but no left state event"); continue; @@ -364,10 +375,11 @@ async fn sync_helper( .state_full_ids(left_shortstatehash) .await?; - let leave_shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + let leave_shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &StateEventType::RoomMember, + sender_user.as_str(), + )?; left_state_ids.insert(leave_shortstatekey, left_event_id); @@ -383,7 +395,8 @@ async fn sync_helper( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || *sender_user == state_key { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services().rooms.timeline.get_pdu(&id)? + else { error!("Pdu in state not found: {}", id); continue; }; @@ -401,7 +414,9 @@ async fn sync_helper( left_rooms.insert( room_id.clone(), LeftRoom { - account_data: RoomAccountData { events: Vec::new() }, + account_data: RoomAccountData { + events: Vec::new(), + }, timeline: Timeline { limited: false, prev_batch: Some(next_batch_string.clone()), @@ -415,11 +430,8 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_invited(&sender_user) - .collect(); + let all_invited_rooms: Vec<_> = + services().rooms.state_cache.rooms_invited(&sender_user).collect(); for result in all_invited_rooms { let (room_id, invite_state_events) = result?; @@ -469,23 +481,29 @@ async fn sync_helper( services() .rooms .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .room_state_get( + &other_room_id, + &StateEventType::RoomEncryption, + "", + ) .ok()? .is_some(), ) }) .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need to tell - // them + // If the user doesn't share an encrypted room with the target anymore, + // we need to tell them if dont_share_encrypted_room { device_list_left.insert(user_id); } } // Remove all to-device events the device received *last time* - services() - .users - .remove_to_device_events(&sender_user, &sender_device, since)?; + services().users.remove_to_device_events( + &sender_user, + &sender_device, + since, + )?; let response = sync_events::v3::Response { next_batch: next_batch_string, @@ -504,7 +522,11 @@ async fn sync_helper( .into_iter() .filter_map(|(_, v)| { serde_json::from_str(v.json().get()) - .map_err(|_| Error::bad_database("Invalid account event in database.")) + .map_err(|_| { + Error::bad_database( + "Invalid account event in database.", + ) + }) .ok() }) .collect(), @@ -581,7 +603,8 @@ async fn load_joined_room( drop(insert_lock); } - let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = + load_timeline(sender_user, room_id, sincecount, 10)?; let send_notification_counts = !timeline_pdus.is_empty() || services() @@ -598,104 +621,126 @@ async fn load_joined_room( services() .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) + .lazy_load_confirm_delivery( + sender_user, + sender_device, + room_id, + sincecount, + ) .await?; // Database queries: - let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? + let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? else { error!("Room {} has no state", room_id); return Err(Error::BadDatabase("Room has no state")); }; - let since_shortstatehash = services() - .rooms - .user - .get_token_shortstatehash(room_id, since)?; + let since_shortstatehash = + services().rooms.user.get_token_shortstatehash(room_id, since)?; - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services() + let ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) = if timeline_pdus.is_empty() + && since_shortstatehash == Some(current_shortstatehash) + { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || { + let joined_member_count = services() + .rooms + .state_cache + .room_joined_count(room_id)? + .unwrap_or(0); + let invited_member_count = services() + .rooms + .state_cache + .room_invited_count(room_id)? + .unwrap_or(0); + + // Recalculate heroes (first 5 members) + let mut heroes = Vec::new(); + + if joined_member_count + invited_member_count <= 5 { + // Go through all PDUs and for each member event, check if the + // user is still joined or invited until we have + // 5 or we reach the end + + for hero in services() .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services() - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0); + .timeline + .all_pdus(sender_user, room_id)? + .filter_map(Result::ok) + .filter(|(_, pdu)| { + pdu.kind == TimelineEventType::RoomMember + }) + .map(|(_, pdu)| { + let content: RoomMemberEventContent = + serde_json::from_str(pdu.content.get()).map_err( + |_| { + Error::bad_database( + "Invalid member event in database.", + ) + }, + )?; - // Recalculate heroes (first 5 members) - let mut heroes = Vec::new(); - - if joined_member_count + invited_member_count <= 5 { - // Go through all PDUs and for each member event, check if the user is still joined or - // invited until we have 5 or we reach the end - - for hero in services() - .rooms - .timeline - .all_pdus(sender_user, room_id)? - .filter_map(Result::ok) - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = - serde_json::from_str(pdu.content.get()).map_err(|_| { - Error::bad_database("Invalid member event in database.") + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| { + Error::bad_database( + "Invalid UserId in member PDU.", + ) })?; - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; - - // The membership was and still is invite or join - if matches!( - content.membership, - MembershipState::Join | MembershipState::Invite - ) && (services() + // The membership was and still is invite or join + if matches!( + content.membership, + MembershipState::Join | MembershipState::Invite + ) && (services() + .rooms + .state_cache + .is_joined(&user_id, room_id)? + || services() .rooms .state_cache - .is_joined(&user_id, room_id)? - || services() - .rooms - .state_cache - .is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(state_key.clone())) - } else { - Ok(None) - } + .is_invited(&user_id, room_id)?) + { + Ok::<_, Error>(Some(state_key.clone())) } else { Ok(None) } - }) - .filter_map(Result::ok) - .flatten() - { - if heroes.contains(&hero) || hero == sender_user.as_str() { - continue; + } else { + Ok(None) } - - heroes.push(hero); + }) + .filter_map(Result::ok) + .flatten() + { + if heroes.contains(&hero) || hero == sender_user.as_str() { + continue; } + + heroes.push(hero); } + } - Ok::<_, Error>(( - Some(joined_member_count), - Some(invited_member_count), - heroes, - )) - }; + Ok::<_, Error>(( + Some(joined_member_count), + Some(invited_member_count), + heroes, + )) + }; - let since_sender_member: Option = since_shortstatehash + let since_sender_member: Option = + since_shortstatehash .and_then(|shortstatehash| { services() .rooms @@ -710,270 +755,303 @@ async fn load_joined_room( .transpose()? .and_then(|pdu| { serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .map_err(|_| { + Error::bad_database("Invalid PDU in database.") + }) .ok() }); - let joined_since_last_sync = since_sender_member - .map_or(true, |member| member.membership != MembershipState::Join); + let joined_since_last_sync = since_sender_member + .map_or(true, |member| member.membership != MembershipState::Join); - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + let (joined_member_count, invited_member_count, heroes) = + calculate_counts()?; + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + let mut i = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services() + .rooms + .short + .get_statekey_from_short(shortstatekey)?; + + if event_type != StateEventType::RoomMember { + let Some(pdu) = services().rooms.timeline.get_pdu(&id)? + else { + error!("Pdu in state not found: {}", id); + continue; + }; + state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled + || full_state + || timeline_users.contains(&state_key) + // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 + || *sender_user == state_key + { + let Some(pdu) = services().rooms.timeline.get_pdu(&id)? + else { + error!("Pdu in state not found: {}", id); + continue; + }; + + // This check is in case a bad user ID made it into the + // database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); + } + state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + } + + // Reset lazy loading because this is an initial sync + services().rooms.lazy_loading.lazy_load_reset( + sender_user, + sender_device, + room_id, + )?; + + // The state_events above should contain all timeline_users, let's + // mark them as lazy loaded. + services() + .rooms + .lazy_loading + .lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ) + .await; + + ( + heroes, + joined_member_count, + invited_member_count, + true, + state_events, + ) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.unwrap(); + + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + if since_shortstatehash != current_shortstatehash { let current_state_ids = services() .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; + let since_state_ids = services() + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - let mut i = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services() - .rooms - .short - .get_statekey_from_short(shortstatekey)?; - - if event_type != StateEventType::RoomMember { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - state_events.push(pdu); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled - || full_state - || timeline_users.contains(&state_key) - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || *sender_user == state_key - { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let Some(pdu) = + services().rooms.timeline.get_pdu(&id)? + else { error!("Pdu in state not found: {}", id); continue; }; - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + } + Err(e) => error!( + "Invalid state key for member event: {}", + e + ), + } } - state_events.push(pdu); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } + state_events.push(pdu); + tokio::task::yield_now().await; } } + } - // Reset lazy loading because this is an initial sync - services().rooms.lazy_loading.lazy_load_reset( + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } + + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, room_id, - )?; - - // The state_events above should contain all timeline_users, let's mark them as lazy - // loaded. - services() - .rooms - .lazy_loading - .lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ) - .await; - - ( - heroes, - joined_member_count, - invited_member_count, - true, - state_events, - ) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); - - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services() - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; - - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - } - Err(e) => error!("Invalid state key for member event: {}", e), - } - } - - state_events.push(pdu); - tokio::task::yield_now().await; - } - } - } - - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } - - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant - { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( + &event.sender, + )? || lazy_load_send_redundant + { + if let Some(member_event) = + services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomMember, event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); - } + )? + { + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); } } + } - services() - .rooms - .lazy_loading - .lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ) - .await; + services() + .rooms + .lazy_loading + .lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ) + .await; - let encrypted_room = services() - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); - - let since_encryption = services().rooms.state_accessor.state_get( - since_shortstatehash, + let encrypted_room = services() + .rooms + .state_accessor + .state_get( + current_shortstatehash, &StateEventType::RoomEncryption, "", - )?; + )? + .is_some(); - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + let since_encryption = services().rooms.state_accessor.state_get( + since_shortstatehash, + &StateEventType::RoomEncryption, + "", + )?; - let send_member_count = state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); + // Calculations: + let new_encrypted_room = + encrypted_room && since_encryption.is_none(); - if encrypted_room { - for state_event in &state_events { - if state_event.kind != TimelineEventType::RoomMember { + let send_member_count = state_events + .iter() + .any(|event| event.kind == TimelineEventType::RoomMember); + + if encrypted_room { + for state_event in &state_events { + if state_event.kind != TimelineEventType::RoomMember { + continue; + } + + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| { + Error::bad_database( + "Invalid UserId in member PDU.", + ) + })?; + + if user_id == sender_user { continue; } - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; - - if user_id == sender_user { - continue; - } - - let new_membership = serde_json::from_str::( + let new_membership = + serde_json::from_str::( state_event.content.get(), ) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .map_err(|_| { + Error::bad_database("Invalid PDU in database.") + })? .membership; - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(sender_user, &user_id, room_id)? { - device_list_updates.insert(user_id); - } + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room( + sender_user, + &user_id, + room_id, + )? { + device_list_updates.insert(user_id); } - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - } - _ => {} } + MembershipState::Leave => { + // Write down users that have left encrypted + // rooms we are in + left_encrypted_users.insert(user_id); + } + _ => {} } } } + } - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services() - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(sender_user, user_id, room_id) - .unwrap_or(false) - }), - ); - } + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined + // users + device_list_updates.extend( + services() + .rooms + .state_cache + .room_members(room_id) + .flatten() + .filter(|user_id| { + // Don't send key updates from the sender to the + // sender + sender_user != user_id + }) + .filter(|user_id| { + // Only send keys if the sender doesn't share an + // encrypted room with the target already + !share_encrypted_room(sender_user, user_id, room_id) + .unwrap_or(false) + }), + ); + } - let (joined_member_count, invited_member_count, heroes) = if send_member_count { + let (joined_member_count, invited_member_count, heroes) = + if send_member_count { calculate_counts()? } else { (None, None, Vec::new()) }; - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; // Look for device list updates in this room device_list_updates.extend( @@ -984,12 +1062,7 @@ async fn load_joined_room( ); let notification_count = send_notification_counts - .then(|| { - services() - .rooms - .user - .notification_count(sender_user, room_id) - }) + .then(|| services().rooms.user.notification_count(sender_user, room_id)) .transpose()? .map(|x| x.try_into().expect("notification count can't go that high")); @@ -998,9 +1071,9 @@ async fn load_joined_room( .transpose()? .map(|x| x.try_into().expect("highlight count can't go that high")); - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + let prev_batch = timeline_pdus.first().map_or( + Ok::<_, Error>(None), + |(pdu_count, _)| { Ok(Some(match pdu_count { PduCount::Backfilled(_) => { error!("timeline in backfill state?!"); @@ -1008,12 +1081,11 @@ async fn load_joined_room( } PduCount::Normal(c) => c.to_string(), })) - })?; + }, + )?; - let room_events: Vec<_> = timeline_pdus - .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); + let room_events: Vec<_> = + timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); let mut edus: Vec<_> = services() .rooms @@ -1024,24 +1096,20 @@ async fn load_joined_room( .map(|(_, _, v)| v) .collect(); - if services() - .rooms - .edus - .typing - .last_typing_update(room_id) - .await? - > since - { + if services().rooms.edus.typing.last_typing_update(room_id).await? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id).await?) - .expect("event is valid, we just created it"), + &serde_json::to_string( + &services().rooms.edus.typing.typings_all(room_id).await?, + ) + .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), ); } - // Save the state after this sync so we can send the correct state diff next sync + // Save the state after this sync so we can send the correct state diff next + // sync services().rooms.user.associate_token_shortstatehash( room_id, next_batch, @@ -1056,7 +1124,11 @@ async fn load_joined_room( .into_iter() .filter_map(|(_, v)| { serde_json::from_str(v.json().get()) - .map_err(|_| Error::bad_database("Invalid account event in database.")) + .map_err(|_| { + Error::bad_database( + "Invalid account event in database.", + ) + }) .ok() }) .collect(), @@ -1064,7 +1136,8 @@ async fn load_joined_room( summary: RoomSummary { heroes, joined_member_count: joined_member_count.map(UInt::new_saturating), - invited_member_count: invited_member_count.map(UInt::new_saturating), + invited_member_count: invited_member_count + .map(UInt::new_saturating), }, unread_notifications: UnreadNotificationsCount { highlight_count, @@ -1081,7 +1154,9 @@ async fn load_joined_room( .map(|pdu| pdu.to_sync_state_event()) .collect(), }, - ephemeral: Ephemeral { events: edus }, + ephemeral: Ephemeral { + events: edus, + }, unread_thread_notifications: BTreeMap::new(), }) } @@ -1094,10 +1169,7 @@ fn load_timeline( ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; let limited; - if services() - .rooms - .timeline - .last_timeline_count(sender_user, room_id)? + if services().rooms.timeline.last_timeline_count(sender_user, room_id)? > roomsincecount { let mut non_timeline_pdus = services() @@ -1121,8 +1193,9 @@ fn load_timeline( .rev() .collect::>(); - // They /sync response doesn't always return all messages, so we say the output is - // limited unless there are events in non_timeline_pdus + // They /sync response doesn't always return all messages, so we say the + // output is limited unless there are events in + // non_timeline_pdus limited = non_timeline_pdus.next().is_some(); } else { timeline_pdus = Vec::new(); @@ -1147,7 +1220,11 @@ fn share_encrypted_room( services() .rooms .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .room_state_get( + &other_room_id, + &StateEventType::RoomEncryption, + "", + ) .ok()? .is_some(), ) @@ -1167,11 +1244,8 @@ pub(crate) async fn sync_events_v4_route( let next_batch = services().globals.next_count()?; - let globalsince = body - .pos - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); + let globalsince = + body.pos.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); if globalsince == 0 { if let Some(conn_id) = &body.conn_id { @@ -1198,9 +1272,11 @@ pub(crate) async fn sync_events_v4_route( .collect::>(); if body.extensions.to_device.enabled.unwrap_or(false) { - services() - .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; + services().users.remove_to_device_events( + &sender_user, + &sender_device, + globalsince, + )?; } // Users that have left any encrypted rooms the sender was in @@ -1230,29 +1306,36 @@ pub(crate) async fn sync_events_v4_route( .user .get_token_shortstatehash(room_id, globalsince)?; - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services() - .rooms - .state_accessor - .state_get( - shortstatehash, - &StateEventType::RoomMember, - sender_user.as_str(), - ) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = + since_shortstatehash + .and_then(|shortstatehash| { + services() + .rooms + .state_accessor + .state_get( + shortstatehash, + &StateEventType::RoomMember, + sender_user.as_str(), + ) + .transpose() + }) + .transpose()? + .and_then(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| { + Error::bad_database("Invalid PDU in database.") + }) + .ok() + }); let encrypted_room = services() .rooms .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? + .state_get( + current_shortstatehash, + &StateEventType::RoomEncryption, + "", + )? .is_some(); if let Some(since_shortstatehash) = since_shortstatehash { @@ -1261,16 +1344,20 @@ pub(crate) async fn sync_events_v4_route( continue; } - let since_encryption = services().rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = + services().rooms.state_accessor.state_get( + since_shortstatehash, + &StateEventType::RoomEncryption, + "", + )?; let joined_since_last_sync = since_sender_member - .map_or(true, |member| member.membership != MembershipState::Join); + .map_or(true, |member| { + member.membership != MembershipState::Join + }); - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + let new_encrypted_room = + encrypted_room && since_encryption.is_none(); if encrypted_room { let current_state_ids = services() .rooms @@ -1285,43 +1372,58 @@ pub(crate) async fn sync_events_v4_route( for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = + services().rooms.timeline.get_pdu(&id)? + else { error!("Pdu in state not found: {}", id); continue; }; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { let user_id = - UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; + UserId::parse(state_key.clone()) + .map_err(|_| { + Error::bad_database( + "Invalid UserId in member \ + PDU.", + ) + })?; if user_id == sender_user { continue; } - let new_membership = serde_json::from_str::< - RoomMemberEventContent, - >( - pdu.content.get() - ) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let new_membership = + serde_json::from_str::< + RoomMemberEventContent, + >( + pdu.content.get() + ) + .map_err(|_| { + Error::bad_database( + "Invalid PDU in database.", + ) + })? + .membership; match new_membership { MembershipState::Join => { - // A new user joined an encrypted room + // A new user joined an encrypted + // room if !share_encrypted_room( &sender_user, &user_id, room_id, )? { - device_list_changes.insert(user_id); + device_list_changes + .insert(user_id); } } MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); + // Write down users that have left + // encrypted rooms we are in + left_encrypted_users + .insert(user_id); } _ => {} } @@ -1330,7 +1432,8 @@ pub(crate) async fn sync_events_v4_route( } } if joined_since_last_sync || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users + // If the user is in a new encrypted room, give them all + // joined users device_list_changes.extend( services() .rooms @@ -1338,13 +1441,20 @@ pub(crate) async fn sync_events_v4_route( .room_members(room_id) .flatten() .filter(|user_id| { - // Don't send key updates from the sender to the sender + // Don't send key updates from the sender to + // the sender &sender_user != user_id }) .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&sender_user, user_id, room_id) - .unwrap_or(false) + // Only send keys if the sender doesn't + // share an encrypted room with the target + // already + !share_encrypted_room( + &sender_user, + user_id, + room_id, + ) + .unwrap_or(false) }), ); } @@ -1369,14 +1479,18 @@ pub(crate) async fn sync_events_v4_route( services() .rooms .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .room_state_get( + &other_room_id, + &StateEventType::RoomEncryption, + "", + ) .ok()? .is_some(), ) }) .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need to tell - // them + // If the user doesn't share an encrypted room with the target + // anymore, we need to tell them if dont_share_encrypted_room { device_list_left.insert(user_id); } @@ -1403,30 +1517,36 @@ pub(crate) async fn sync_events_v4_route( .map(|mut r| { r.0 = r.0.clamp( uint!(0), - UInt::try_from(all_joined_rooms.len() - 1).unwrap_or(UInt::MAX), + UInt::try_from(all_joined_rooms.len() - 1) + .unwrap_or(UInt::MAX), ); r.1 = r.1.clamp( r.0, - UInt::try_from(all_joined_rooms.len() - 1).unwrap_or(UInt::MAX), + UInt::try_from(all_joined_rooms.len() - 1) + .unwrap_or(UInt::MAX), ); - let room_ids = all_joined_rooms[r.0.try_into().unwrap_or(usize::MAX) + let room_ids = all_joined_rooms[r + .0 + .try_into() + .unwrap_or(usize::MAX) ..=r.1.try_into().unwrap_or(usize::MAX)] .to_vec(); new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { - let todo_room = todo_rooms.entry(room_id.clone()).or_insert(( - BTreeSet::new(), - 0, - u64::MAX, - )); + let todo_room = todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0, u64::MAX)); let limit = list .room_details .timeline_limit .map_or(10, u64::from) .min(100); - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); + todo_room.0.extend( + list.room_details + .required_state + .iter() + .cloned(), + ); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( @@ -1446,7 +1566,8 @@ pub(crate) async fn sync_events_v4_route( } }) .collect(), - count: UInt::try_from(all_joined_rooms.len()).unwrap_or(UInt::MAX), + count: UInt::try_from(all_joined_rooms.len()) + .unwrap_or(UInt::MAX), }, ); @@ -1467,9 +1588,11 @@ pub(crate) async fn sync_events_v4_route( if !services().rooms.metadata.exists(room_id)? { continue; } - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); + let todo_room = todo_rooms.entry(room_id.clone()).or_insert(( + BTreeSet::new(), + 0, + u64::MAX, + )); let limit = room.timeline_limit.map_or(10, u64::from).min(100); todo_room.0.extend(room.required_state.iter().cloned()); todo_room.1 = todo_room.1.max(limit); @@ -1510,11 +1633,17 @@ pub(crate) async fn sync_events_v4_route( } let mut rooms = BTreeMap::new(); - for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { + for (room_id, (required_state_request, timeline_limit, roomsince)) in + &todo_rooms + { let roomsincecount = PduCount::Normal(*roomsince); - let (timeline_pdus, limited) = - load_timeline(&sender_user, room_id, roomsincecount, *timeline_limit)?; + let (timeline_pdus, limited) = load_timeline( + &sender_user, + room_id, + roomsincecount, + *timeline_limit, + )?; if roomsince != &0 && timeline_pdus.is_empty() { continue; @@ -1599,12 +1728,18 @@ pub(crate) async fn sync_events_v4_route( rooms.insert( room_id.clone(), sync_events::v4::SlidingSyncRoom { - name: services().rooms.state_accessor.get_name(room_id)?.or(name), + name: services() + .rooms + .state_accessor + .get_name(room_id)? + .or(name), avatar: if let Some(avatar) = avatar { JsOption::Some(avatar) } else { match services().rooms.state_accessor.get_avatar(room_id)? { - JsOption::Some(avatar) => JsOption::from_option(avatar.url), + JsOption::Some(avatar) => { + JsOption::from_option(avatar.url) + } JsOption::Null => JsOption::Null, JsOption::Undefined => JsOption::Undefined, } @@ -1707,7 +1842,8 @@ pub(crate) async fn sync_events_v4_route( device_unused_fallback_key_types: None, }, account_data: sync_events::v4::AccountData { - global: if body.extensions.account_data.enabled.unwrap_or(false) { + global: if body.extensions.account_data.enabled.unwrap_or(false) + { services() .account_data .changes_since(None, &sender_user, globalsince)? @@ -1715,7 +1851,9 @@ pub(crate) async fn sync_events_v4_route( .filter_map(|(_, v)| { serde_json::from_str(v.json().get()) .map_err(|_| { - Error::bad_database("Invalid account event in database.") + Error::bad_database( + "Invalid account event in database.", + ) }) .ok() }) diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 616fc618..22738af8 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -1,4 +1,5 @@ -use crate::{services, Error, Result, Ruma}; +use std::collections::BTreeMap; + use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ @@ -6,7 +7,8 @@ use ruma::{ RoomAccountDataEventType, }, }; -use std::collections::BTreeMap; + +use crate::{services, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// @@ -33,8 +35,9 @@ pub(crate) async fn update_tag_route( }) }, |e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) + serde_json::from_str(e.get()).map_err(|_| { + Error::bad_database("Invalid account data event in db.") + }) }, )?; @@ -78,8 +81,9 @@ pub(crate) async fn delete_tag_route( }) }, |e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) + serde_json::from_str(e.get()).map_err(|_| { + Error::bad_database("Invalid account data event in db.") + }) }, )?; @@ -120,8 +124,9 @@ pub(crate) async fn get_tags_route( }) }, |e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) + serde_json::from_str(e.get()).map_err(|_| { + Error::bad_database("Invalid account data event in db.") + }) }, )?; diff --git a/src/api/client_server/thirdparty.rs b/src/api/client_server/thirdparty.rs index 3c678acf..8cb77e4b 100644 --- a/src/api/client_server/thirdparty.rs +++ b/src/api/client_server/thirdparty.rs @@ -1,7 +1,8 @@ -use crate::{Result, Ruma}; +use std::collections::BTreeMap; + use ruma::api::client::thirdparty::get_protocols; -use std::collections::BTreeMap; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/thirdparty/protocols` /// diff --git a/src/api/client_server/threads.rs b/src/api/client_server/threads.rs index c9142610..9beb8aa2 100644 --- a/src/api/client_server/threads.rs +++ b/src/api/client_server/threads.rs @@ -9,11 +9,8 @@ pub(crate) async fn get_threads_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|l| l.try_into().ok()) - .unwrap_or(10) - .min(100); + let limit = + body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100); let from = if let Some(from) = &body.from { from.parse() diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index fbaad4d9..c29c41de 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -1,6 +1,5 @@ use std::collections::BTreeMap; -use crate::{services, Error, Result, Ruma}; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -9,6 +8,8 @@ use ruma::{ to_device::DeviceIdOrAllDevices, }; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// /// Send a to-device event to a set of client devices. @@ -29,7 +30,8 @@ pub(crate) async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != services().globals.server_name() { + if target_user_id.server_name() != services().globals.server_name() + { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); @@ -38,14 +40,16 @@ pub(crate) async fn send_event_to_device_route( services().sending.send_reliable_edu( target_user_id.server_name(), - serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( - DirectDeviceContent { - sender: sender_user.clone(), - ev_type: body.event_type.clone(), - message_id: count.to_string().into(), - messages, - }, - )) + serde_json::to_vec( + &federation::transactions::edu::Edu::DirectToDevice( + DirectDeviceContent { + sender: sender_user.clone(), + ev_type: body.event_type.clone(), + message_id: count.to_string().into(), + messages, + }, + ), + ) .expect("DirectToDevice EDU can be serialized"), count, )?; @@ -61,20 +65,28 @@ pub(crate) async fn send_event_to_device_route( target_device_id, &body.event_type.to_string(), event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) })?, )?; } DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { + for target_device_id in + services().users.all_device_ids(target_user_id) + { services().users.add_to_device_event( sender_user, target_user_id, &target_device_id?, &body.event_type.to_string(), event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) })?, )?; } @@ -84,9 +96,12 @@ pub(crate) async fn send_event_to_device_route( } // Save transaction id with empty data - services() - .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + services().transaction_ids.add_txnid( + sender_user, + sender_device, + &body.txn_id, + &[], + )?; Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client_server/typing.rs b/src/api/client_server/typing.rs index d9a24530..1e7b1b68 100644 --- a/src/api/client_server/typing.rs +++ b/src/api/client_server/typing.rs @@ -1,6 +1,7 @@ -use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; +use crate::{services, utils, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. @@ -11,11 +12,7 @@ pub(crate) async fn create_typing_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_cache - .is_joined(sender_user, &body.room_id)? - { + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You are not in this room.", diff --git a/src/api/client_server/unversioned.rs b/src/api/client_server/unversioned.rs index 27840af4..9d3dab29 100644 --- a/src/api/client_server/unversioned.rs +++ b/src/api/client_server/unversioned.rs @@ -6,14 +6,16 @@ use crate::{Result, Ruma}; /// # `GET /_matrix/client/versions` /// -/// Get the versions of the specification and unstable features supported by this server. +/// Get the versions of the specification and unstable features supported by +/// this server. /// /// - Versions take the form MAJOR.MINOR.PATCH /// - Only the latest PATCH release will be reported for each MAJOR.MINOR value -/// - Unstable features are namespaced and may include version information in their name +/// - Unstable features are namespaced and may include version information in +/// their name /// -/// Note: Unstable features are used while developing new features. Clients should avoid using -/// unstable features in their stable releases +/// Note: Unstable features are used while developing new features. Clients +/// should avoid using unstable features in their stable releases pub(crate) async fn get_supported_versions_route( _body: Ruma, ) -> Result { @@ -27,7 +29,10 @@ pub(crate) async fn get_supported_versions_route( "v1.4".to_owned(), "v1.5".to_owned(), ], - unstable_features: BTreeMap::from_iter([("org.matrix.e2e_cross_signing".to_owned(), true)]), + unstable_features: BTreeMap::from_iter([( + "org.matrix.e2e_cross_signing".to_owned(), + true, + )]), }; Ok(resp) diff --git a/src/api/client_server/user_directory.rs b/src/api/client_server/user_directory.rs index 66f40a86..fdac5c1b 100644 --- a/src/api/client_server/user_directory.rs +++ b/src/api/client_server/user_directory.rs @@ -1,4 +1,3 @@ -use crate::{services, Result, Ruma}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -7,11 +6,14 @@ use ruma::{ }, }; +use crate::{services, Result, Ruma}; + /// # `POST /_matrix/client/r0/user_directory/search` /// /// Searches all known users for a match. /// -/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) +/// - Hides any local users that aren't in any public rooms (i.e. those that +/// have the join rule set to public) /// and don't share a room with the sender pub(crate) async fn search_users_route( body: Ruma, @@ -38,8 +40,7 @@ pub(crate) async fn search_users_route( .display_name .as_ref() .filter(|name| { - name.to_lowercase() - .contains(&body.search_term.to_lowercase()) + name.to_lowercase().contains(&body.search_term.to_lowercase()) }) .is_some(); @@ -62,10 +63,12 @@ pub(crate) async fn search_users_route( .room_state_get(&room, &StateEventType::RoomJoinRules, "") .map_or(false, |event| { event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| { + serde_json::from_str(event.content.get()).map_or( + false, + |r: RoomJoinRulesEventContent| { r.join_rule == JoinRule::Public - }) + }, + ) }) }) }); @@ -96,5 +99,8 @@ pub(crate) async fn search_users_route( let results = users.by_ref().take(limit).collect(); let limited = users.next().is_some(); - Ok(search_users::v3::Response { results, limited }) + Ok(search_users::v3::Response { + results, + limited, + }) } diff --git a/src/api/client_server/voip.rs b/src/api/client_server/voip.rs index 657a6d20..1cd80d27 100644 --- a/src/api/client_server/voip.rs +++ b/src/api/client_server/voip.rs @@ -1,9 +1,11 @@ -use crate::{services, Result, Ruma}; +use std::time::{Duration, SystemTime}; + use base64::{engine::general_purpose, Engine as _}; use hmac::{Hmac, Mac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use sha1::Sha1; -use std::time::{Duration, SystemTime}; + +use crate::{services, Result, Ruma}; type HmacSha1 = Hmac; @@ -24,7 +26,8 @@ pub(crate) async fn turn_server_route( ) } else { let expiry = SecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), + SystemTime::now() + + Duration::from_secs(services().globals.turn_ttl()), ) .expect("time is valid"); @@ -34,7 +37,8 @@ pub(crate) async fn turn_server_route( .expect("HMAC can take key of any size"); mac.update(username.as_bytes()); - let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes()); + let password: String = + general_purpose::STANDARD.encode(mac.finalize().into_bytes()); (username, password) }; diff --git a/src/api/ruma_wrapper.rs b/src/api/ruma_wrapper.rs index e3486a73..85173fe9 100644 --- a/src/api/ruma_wrapper.rs +++ b/src/api/ruma_wrapper.rs @@ -1,10 +1,12 @@ -use crate::{service::appservice::RegistrationInfo, Error}; -use ruma::{ - api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, - OwnedUserId, -}; use std::ops::Deref; +use ruma::{ + api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, + OwnedServerName, OwnedUserId, +}; + +use crate::{service::appservice::RegistrationInfo, Error}; + mod axum; /// Extractor for Ruma request structs diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 55485023..f2e3e739 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -3,7 +3,9 @@ use std::{collections::BTreeMap, iter::FromIterator, str}; use axum::{ async_trait, body::{Full, HttpBody}, - extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, + extract::{ + rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader, + }, headers::{ authorization::{Bearer, Credentials}, Authorization, @@ -14,7 +16,9 @@ use axum::{ use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{Request, StatusCode}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + api::{ + client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse, + }, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; use serde::Deserialize; @@ -41,7 +45,10 @@ where type Rejection = Error; #[allow(clippy::too_many_lines)] - async fn from_request(req: Request, _state: &S) -> Result { + async fn from_request( + req: Request, + _state: &S, + ) -> Result { #[derive(Deserialize)] struct QueryParams { access_token: Option, @@ -51,22 +58,23 @@ where let (mut parts, mut body) = match req.with_limited_body() { Ok(limited_req) => { let (parts, body) = limited_req.into_parts(); - let body = to_bytes(body) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + let body = to_bytes(body).await.map_err(|_| { + Error::BadRequest(ErrorKind::MissingToken, "Missing token.") + })?; (parts, body) } Err(original_req) => { let (parts, body) = original_req.into_parts(); - let body = to_bytes(body) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + let body = to_bytes(body).await.map_err(|_| { + Error::BadRequest(ErrorKind::MissingToken, "Missing token.") + })?; (parts, body) } }; let metadata = T::METADATA; - let auth_header: Option>> = parts.extract().await?; + let auth_header: Option>> = + parts.extract().await?; let path_params: Path> = parts.extract().await?; let query = parts.uri.query().unwrap_or_default(); @@ -87,9 +95,13 @@ where }; let token = if let Some(token) = token { - if let Some(reg_info) = services().appservice.find_from_token(token).await { + if let Some(reg_info) = + services().appservice.find_from_token(token).await + { Token::Appservice(Box::new(reg_info.clone())) - } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { + } else if let Some((user_id, device_id)) = + services().users.find_from_token(token)? + { Token::User((user_id, OwnedDeviceId::from(device_id))) } else { Token::Invalid @@ -98,13 +110,16 @@ where Token::None }; - let mut json_body = serde_json::from_slice::(&body).ok(); + let mut json_body = + serde_json::from_slice::(&body).ok(); let (sender_user, sender_device, sender_servername, appservice_info) = match (metadata.authentication, token) { (_, Token::Invalid) => { return Err(Error::BadRequest( - ErrorKind::UnknownToken { soft_logout: false }, + ErrorKind::UnknownToken { + soft_logout: false, + }, "Unknown access token.", )) } @@ -121,7 +136,10 @@ where UserId::parse, ) .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) })?; if !info.is_user_match(&user_id) { @@ -153,7 +171,9 @@ where )); } ( - AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, + AuthScheme::AccessToken + | AuthScheme::AccessTokenOptional + | AuthScheme::None, Token::User((user_id, device_id)), ) => (Some(user_id), Some(device_id), None, None), (AuthScheme::ServerSignatures, Token::None) => { @@ -161,7 +181,10 @@ where .extract::>>() .await .map_err(|e| { - warn!("Missing or invalid Authorization header: {}", e); + warn!( + "Missing or invalid Authorization header: {}", + e + ); let msg = match e.reason() { TypedHeaderRejectionReason::Missing => { @@ -189,7 +212,9 @@ where let mut request_map = BTreeMap::from_iter([ ( "method".to_owned(), - CanonicalJsonValue::String(parts.method.to_string()), + CanonicalJsonValue::String( + parts.method.to_string(), + ), ), ( "uri".to_owned(), @@ -197,12 +222,18 @@ where ), ( "origin".to_owned(), - CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), + CanonicalJsonValue::String( + x_matrix.origin.as_str().to_owned(), + ), ), ( "destination".to_owned(), CanonicalJsonValue::String( - services().globals.server_name().as_str().to_owned(), + services() + .globals + .server_name() + .as_str() + .to_owned(), ), ), ( @@ -212,13 +243,17 @@ where ]); if let Some(json_body) = &json_body { - request_map.insert("content".to_owned(), json_body.clone()); + request_map + .insert("content".to_owned(), json_body.clone()); }; let keys_result = services() .rooms .event_handler - .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.clone()]) + .fetch_signing_keys( + &x_matrix.origin, + vec![x_matrix.key.clone()], + ) .await; let keys = match keys_result { @@ -232,22 +267,29 @@ where } }; - let pub_key_map = - BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); + let pub_key_map = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + keys, + )]); - match ruma::signatures::verify_json(&pub_key_map, &request_map) { + match ruma::signatures::verify_json( + &pub_key_map, + &request_map, + ) { Ok(()) => (None, None, Some(x_matrix.origin), None), Err(e) => { warn!( - "Failed to verify json request from {}: {}\n{:?}", + "Failed to verify json request from {}: \ + {}\n{:?}", x_matrix.origin, e, request_map ); if parts.uri.to_string().contains('@') { warn!( - "Request uri contained '@' character. Make sure your \ - reverse proxy gives Grapevine the raw uri (apache: use \ - nocanon)" + "Request uri contained '@' character. \ + Make sure your reverse proxy gives \ + Grapevine the raw uri (apache: use \ + nocanon)" ); } @@ -264,27 +306,36 @@ where | AuthScheme::AccessTokenOptional, Token::None, ) => (None, None, None, None), - (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { + ( + AuthScheme::ServerSignatures, + Token::Appservice(_) | Token::User(_), + ) => { return Err(Error::BadRequest( ErrorKind::Unauthorized, - "Only server signatures should be used on this endpoint.", + "Only server signatures should be used on this \ + endpoint.", )); } (AuthScheme::AppserviceToken, Token::User(_)) => { return Err(Error::BadRequest( ErrorKind::Unauthorized, - "Only appservice access tokens should be used on this endpoint.", + "Only appservice access tokens should be used on this \ + endpoint.", )); } }; - let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); + let mut http_request = + http::Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid") + UserId::parse_with_server_name( + "", + services().globals.server_name(), + ) + .expect("we know this is valid") }); let uiaa_request = json_body @@ -300,14 +351,17 @@ where ) }); - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + if let Some(CanonicalJsonValue::Object(initial_request)) = + uiaa_request + { for (key, value) in initial_request { json_body.entry(key).or_insert(value); } } let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); + serde_json::to_writer(&mut buf, json_body) + .expect("value serialization can't fail"); body = buf.into_inner().freeze(); } @@ -315,11 +369,15 @@ where debug!("{:?}", http_request); - let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { - warn!("try_from_http_request failed: {:?}", e); - debug!("JSON body: {:?}", json_body); - Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") - })?; + let body = T::try_from_http_request(http_request, &path_params) + .map_err(|e| { + warn!("try_from_http_request failed: {:?}", e); + debug!("JSON body: {:?}", json_body); + Error::BadRequest( + ErrorKind::BadJson, + "Failed to deserialize request.", + ) + })?; Ok(Ruma { body, @@ -345,7 +403,8 @@ impl Credentials for XMatrix { fn decode(value: &http::HeaderValue) -> Option { debug_assert!( value.as_bytes().starts_with(b"X-Matrix "), - "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", + "HeaderValue to decode should start with \"X-Matrix ..\", \ + received = {value:?}", ); let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) @@ -359,8 +418,9 @@ impl Credentials for XMatrix { for entry in parameters.split_terminator(',') { let (name, value) = entry.split_once('=')?; - // It's not at all clear why some fields are quoted and others not in the spec, - // let's simply accept either form for every field. + // It's not at all clear why some fields are quoted and others not + // in the spec, let's simply accept either form for + // every field. let value = value .strip_prefix('"') .and_then(|rest| rest.strip_suffix('"')) diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 86c34276..5ae6af46 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1,14 +1,17 @@ #![allow(deprecated)] -use crate::{ - api::client_server::{self, claim_keys_helper, get_keys_helper}, - service::pdu::{gen_event_id_canonical_json, PduBuilder}, - services, utils, Error, PduEvent, Result, Ruma, +use std::{ + collections::BTreeMap, + fmt::Debug, + mem, + net::{IpAddr, SocketAddr}, + sync::Arc, + time::{Duration, Instant, SystemTime}, }; + use axum::{response::IntoResponse, Json}; use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION}; - use ruma::{ api::{ client::error::{Error as RumaError, ErrorKind}, @@ -17,18 +20,29 @@ use ruma::{ backfill::get_backfill, device::get_devices::{self, v1::UserDevice}, directory::{get_public_rooms, get_public_rooms_filtered}, - discovery::{get_server_keys, get_server_version, ServerSigningKeys, VerifyKey}, - event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, + discovery::{ + get_server_keys, get_server_version, ServerSigningKeys, + VerifyKey, + }, + event::{ + get_event, get_missing_events, get_room_state, + get_room_state_ids, + }, keys::{claim_keys, get_keys}, - membership::{create_invite, create_join_event, prepare_join_event}, + membership::{ + create_invite, create_join_event, prepare_join_event, + }, query::{get_profile_information, get_room_information}, transactions::{ - edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, + edu::{ + DeviceListUpdateContent, DirectDeviceContent, Edu, + SigningKeyUpdateContent, + }, send_transaction_message, }, }, - EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, - SendAccessToken, + EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, + OutgoingResponse, SendAccessToken, }, directory::{Filter, RoomNetwork}, events::{ @@ -41,28 +55,25 @@ use ruma::{ }, serde::{Base64, JsonObject, Raw}, to_device::DeviceIdOrAllDevices, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, - ServerName, + uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, + MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedServerName, + OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::BTreeMap, - fmt::Debug, - mem, - net::{IpAddr, SocketAddr}, - sync::Arc, - time::{Duration, Instant, SystemTime}, -}; use tokio::sync::RwLock; - use tracing::{debug, error, warn}; +use crate::{ + api::client_server::{self, claim_keys_helper, get_keys_helper}, + service::pdu::{gen_event_id_canonical_json, PduBuilder}, + services, utils, Error, PduEvent, Result, Ruma, +}; + /// Wraps either an literal IP address plus port, or a hostname plus complement /// (colon-plus-port if it was specified). /// -/// Note: A [`FedDest::Named`] might contain an IP address in string form if there -/// was no port specified to construct a [`SocketAddr`] with. +/// Note: A [`FedDest::Named`] might contain an IP address in string form if +/// there was no port specified to construct a [`SocketAddr`] with. /// /// # Examples: /// ```rust @@ -107,7 +118,9 @@ impl FedDest { fn port(&self) -> Option { match &self { Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port.strip_prefix(':').and_then(|x| x.parse().ok()), + Self::Named(_, port) => { + port.strip_prefix(':').and_then(|x| x.parse().ok()) + } } } } @@ -178,7 +191,8 @@ where ); }; - request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); + request_map + .insert("method".to_owned(), T::METADATA.method.to_string().into()); request_map.insert( "uri".to_owned(), http_request @@ -194,8 +208,8 @@ where ); request_map.insert("destination".to_owned(), destination.as_str().into()); - let mut request_json = - serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); + let mut request_json = serde_json::from_value(request_map.into()) + .expect("valid JSON is valid BTreeMap"); ruma::signatures::sign_json( services().globals.server_name().as_str(), @@ -205,17 +219,12 @@ where .expect("our request json is what ruma expects"); let request_json: serde_json::Map = - serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); + serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()) + .unwrap(); - let signatures = request_json["signatures"] - .as_object() - .unwrap() - .values() - .map(|v| { - v.as_object() - .unwrap() - .iter() - .map(|(k, v)| (k, v.as_str().unwrap())) + let signatures = + request_json["signatures"].as_object().unwrap().values().map(|v| { + v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap())) }); for signature_server in signatures { @@ -238,11 +247,8 @@ where let url = reqwest_request.url().clone(); debug!("Sending request to {destination} at {url}"); - let response = services() - .globals - .federation_client() - .execute(reqwest_request) - .await; + let response = + services().globals.federation_client().execute(reqwest_request).await; debug!("Received response from {destination} at {url}"); match response { @@ -285,7 +291,8 @@ where if status == 200 { debug!("Parsing response bytes from {destination}"); - let response = T::IncomingResponse::try_from_http_response(http_response); + let response = + T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && write_destination_to_cache { services() .globals @@ -303,7 +310,9 @@ where "Invalid 200 response from {} on: {} {}", &destination, url, e ); - Error::BadServerResponse("Server returned bad 200 response.") + Error::BadServerResponse( + "Server returned bad 200 response.", + ) }) } else { debug!("Returning error from {destination}"); @@ -343,9 +352,12 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest { /// Returns: `actual_destination`, `Host` header /// Implemented according to the specification at -/// Numbers in comments below refer to bullet points in linked section of specification +/// Numbers in comments below refer to bullet points in linked section of +/// specification #[allow(clippy::too_many_lines)] -async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { +async fn find_actual_destination( + destination: &'_ ServerName, +) -> (FedDest, FedDest) { debug!("Finding actual destination for {destination}"); let destination_str = destination.as_str().to_owned(); let mut hostname = destination_str.clone(); @@ -361,10 +373,15 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe FedDest::Named(host.to_owned(), port.to_owned()) } else { debug!("Requesting well known for {destination}"); - if let Some(delegated_hostname) = request_well_known(destination.as_str()).await { + if let Some(delegated_hostname) = + request_well_known(destination.as_str()).await + { debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - if let Some(host_and_port) = get_ip_with_port(&delegated_hostname) { + hostname = add_port_to_hostname(&delegated_hostname) + .into_uri_string(); + if let Some(host_and_port) = + get_ip_with_port(&delegated_hostname) + { host_and_port } else if let Some(pos) = delegated_hostname.find(':') { debug!("3.2: Hostname with port in .well-known file"); @@ -372,7 +389,8 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe FedDest::Named(host.to_owned(), port.to_owned()) } else { debug!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = query_srv_record(&delegated_hostname).await + if let Some(hostname_override) = + query_srv_record(&delegated_hostname).await { debug!("3.3: SRV lookup successful"); let force_port = hostname_override.port(); @@ -390,25 +408,39 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe .unwrap() .insert( delegated_hostname.clone(), - (override_ip.iter().collect(), force_port.unwrap_or(8448)), + ( + override_ip.iter().collect(), + force_port.unwrap_or(8448), + ), ); } else { - warn!("Using SRV record, but could not resolve to IP"); + warn!( + "Using SRV record, but could not resolve \ + to IP" + ); } if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{port}")) + FedDest::Named( + delegated_hostname, + format!(":{port}"), + ) } else { add_port_to_hostname(&delegated_hostname) } } else { - debug!("3.4: No SRV records, just use the hostname from .well-known"); + debug!( + "3.4: No SRV records, just use the hostname \ + from .well-known" + ); add_port_to_hostname(&delegated_hostname) } } } else { debug!("4: No .well-known or an error occured"); - if let Some(hostname_override) = query_srv_record(&destination_str).await { + if let Some(hostname_override) = + query_srv_record(&destination_str).await + { debug!("4: SRV record found"); let force_port = hostname_override.port(); @@ -425,10 +457,15 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe .unwrap() .insert( hostname.clone(), - (override_ip.iter().collect(), force_port.unwrap_or(8448)), + ( + override_ip.iter().collect(), + force_port.unwrap_or(8448), + ), ); } else { - warn!("Using SRV record, but could not resolve to IP"); + warn!( + "Using SRV record, but could not resolve to IP" + ); } if let Some(port) = force_port { @@ -470,7 +507,11 @@ async fn query_given_srv_record(record: &str) -> Option { .map(|srv| { srv.iter().next().map(|result| { FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), + result + .target() + .to_string() + .trim_end_matches('.') + .to_owned(), format!(":{}", result.port()), ) }) @@ -481,7 +522,8 @@ async fn query_given_srv_record(record: &str) -> Option { async fn query_srv_record(hostname: &'_ str) -> Option { let hostname = hostname.trim_end_matches('.'); - if let Some(host_port) = query_given_srv_record(&format!("_matrix-fed._tcp.{hostname}.")).await + if let Some(host_port) = + query_given_srv_record(&format!("_matrix-fed._tcp.{hostname}.")).await { Some(host_port) } else { @@ -525,17 +567,22 @@ pub(crate) async fn get_server_version_route( /// /// Gets the public signing keys of this server. /// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// - Matrix does not support invalidating public keys, so the key returned by +/// this will be valid /// forever. -// Response type for this endpoint is Json because we need to calculate a signature for the response +// Response type for this endpoint is Json because we need to calculate a +// signature for the response pub(crate) async fn get_server_keys_route() -> Result { - let mut verify_keys: BTreeMap = BTreeMap::new(); + let mut verify_keys: BTreeMap = + BTreeMap::new(); verify_keys.insert( format!("ed25519:{}", services().globals.keypair().version()) .try_into() .expect("found invalid server signing keys in DB"), VerifyKey { - key: Base64::new(services().globals.keypair().public_key().to_vec()), + key: Base64::new( + services().globals.keypair().public_key().to_vec(), + ), }, ); let mut response = serde_json::from_slice( @@ -572,7 +619,8 @@ pub(crate) async fn get_server_keys_route() -> Result { /// /// Gets the public signing keys of this server. /// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// - Matrix does not support invalidating public keys, so the key returned by +/// this will be valid /// forever. pub(crate) async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_server_keys_route().await @@ -627,10 +675,11 @@ pub(crate) async fn get_public_rooms_route( pub(crate) fn parse_incoming_pdu( pdu: &RawJsonValue, ) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let value: CanonicalJsonObject = + serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; let room_id: OwnedRoomId = value .get("room_id") @@ -642,7 +691,9 @@ pub(crate) fn parse_incoming_pdu( let room_version_id = services().rooms.state.get_room_version(&room_id)?; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { + let Ok((event_id, value)) = + gen_event_id_canonical_json(pdu, &room_version_id) + else { // Event could not be converted to canonical json return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -659,20 +710,19 @@ pub(crate) fn parse_incoming_pdu( pub(crate) async fn send_transaction_message_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); let mut resolved_map = BTreeMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in &body.pdus { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) + .map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; let room_id: OwnedRoomId = value .get("room_id") .and_then(|id| RoomId::parse(id.as_str()?).ok()) @@ -695,7 +745,8 @@ pub(crate) async fn send_transaction_message_route( continue; } }; - // We do not add the event_id field to the pdu here because of signature and hashes checks + // We do not add the event_id field to the pdu here because of signature + // and hashes checks let mutex = Arc::clone( services() @@ -767,13 +818,15 @@ pub(crate) async fn send_transaction_message_route( .max_by_key(|(_, count)| *count) { let mut user_receipts = BTreeMap::new(); - user_receipts.insert(user_id.clone(), user_updates.data); + user_receipts + .insert(user_id.clone(), user_updates.data); let mut receipts = BTreeMap::new(); receipts.insert(ReceiptType::Read, user_receipts); let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event_id.to_owned(), receipts); + receipt_content + .insert(event_id.to_owned(), receipts); let event = ReceiptEvent { content: ReceiptEventContent(receipt_content), @@ -783,10 +836,15 @@ pub(crate) async fn send_transaction_message_route( .rooms .edus .read_receipt - .readreceipt_update(&user_id, &room_id, event)?; + .readreceipt_update( + &user_id, &room_id, event, + )?; } else { // TODO fetch missing events - debug!("No known event ids in read receipt: {:?}", user_updates); + debug!( + "No known event ids in read receipt: {:?}", + user_updates + ); } } } @@ -818,7 +876,10 @@ pub(crate) async fn send_transaction_message_route( } } } - Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { + Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id, + .. + }) => { services().users.mark_device_key_update(&user_id)?; } Edu::DirectToDevice(DirectDeviceContent { @@ -839,14 +900,19 @@ pub(crate) async fn send_transaction_message_route( for (target_user_id, map) in &messages { for (target_device_id_maybe, event) in map { match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => { + DeviceIdOrAllDevices::DeviceId( + target_device_id, + ) => { services().users.add_to_device_event( &sender, target_user_id, target_device_id, &ev_type.to_string(), event.deserialize_as().map_err(|e| { - warn!("To-Device event is invalid: {event:?} {e}"); + warn!( + "To-Device event is invalid: \ + {event:?} {e}" + ); Error::BadRequest( ErrorKind::InvalidParam, "Event is invalid", @@ -856,20 +922,23 @@ pub(crate) async fn send_transaction_message_route( } DeviceIdOrAllDevices::AllDevices => { - for target_device_id in - services().users.all_device_ids(target_user_id) + for target_device_id in services() + .users + .all_device_ids(target_user_id) { services().users.add_to_device_event( &sender, target_user_id, &target_device_id?, &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, + event.deserialize_as().map_err( + |_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + }, + )?, )?; } } @@ -878,9 +947,12 @@ pub(crate) async fn send_transaction_message_route( } // Save transaction id with empty data - services() - .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; + services().transaction_ids.add_txnid( + &sender, + None, + &message_id, + &[], + )?; } Edu::SigningKeyUpdate(SigningKeyUpdateContent { user_id, @@ -916,31 +988,30 @@ pub(crate) async fn send_transaction_message_route( /// /// Retrieves a single event from the server. /// -/// - Only works if a user of this server is currently invited or joined the room +/// - Only works if a user of this server is currently invited or joined the +/// room pub(crate) async fn get_event_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); - let event = services() - .rooms - .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = + services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else( + || { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + }, + )?; let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database("Invalid room id field in event in database") + })?; if !services() .rooms @@ -978,10 +1049,8 @@ pub(crate) async fn get_event_route( pub(crate) async fn get_backfill_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); debug!("Got backfill request from: {}", sender_servername); @@ -1050,10 +1119,8 @@ pub(crate) async fn get_backfill_route( pub(crate) async fn get_missing_events_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); if !services() .rooms @@ -1075,19 +1142,28 @@ pub(crate) async fn get_missing_events_route( let mut events = Vec::new(); let mut i = 0; - while i < queued_events.len() && events.len() < body.limit.try_into().unwrap_or(usize::MAX) { - if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { - let room_id_str = pdu - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + while i < queued_events.len() + && events.len() < body.limit.try_into().unwrap_or(usize::MAX) + { + if let Some(pdu) = + services().rooms.timeline.get_pdu_json(&queued_events[i])? + { + let room_id_str = + pdu.get("room_id").and_then(|val| val.as_str()).ok_or_else( + || Error::bad_database("Invalid event in database"), + )?; - let event_room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let event_room_id = + <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database( + "Invalid room id field in event in database", + ) + })?; if event_room_id != body.room_id { warn!( - "Evil event detected: Event {} found while searching in room {}", + "Evil event detected: Event {} found while searching in \ + room {}", queued_events[i], body.room_id ); return Err(Error::BadRequest( @@ -1112,19 +1188,29 @@ pub(crate) async fn get_missing_events_route( queued_events.extend_from_slice( &serde_json::from_value::>( - serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { - Error::bad_database("Event in db has no prev_events field.") - })?) + serde_json::to_value( + pdu.get("prev_events").cloned().ok_or_else(|| { + Error::bad_database( + "Event in db has no prev_events field.", + ) + })?, + ) .expect("canonical json is valid json value"), ) - .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, + .map_err(|_| { + Error::bad_database( + "Invalid prev_events content in pdu in db.", + ) + })?, ); events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); } i += 1; } - Ok(get_missing_events::v1::Response { events }) + Ok(get_missing_events::v1::Response { + events, + }) } /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` @@ -1135,10 +1221,8 @@ pub(crate) async fn get_missing_events_route( pub(crate) async fn get_event_authorization_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); if !services() .rooms @@ -1156,22 +1240,22 @@ pub(crate) async fn get_event_authorization_route( .event_handler .acl_check(sender_servername, &body.room_id)?; - let event = services() - .rooms - .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = + services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else( + || { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + }, + )?; let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database("Invalid room id field in event in database") + })?; let auth_chain_ids = services() .rooms @@ -1181,7 +1265,9 @@ pub(crate) async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) + .filter_map(|id| { + services().rooms.timeline.get_pdu_json(&id).ok()? + }) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1193,10 +1279,8 @@ pub(crate) async fn get_event_authorization_route( pub(crate) async fn get_room_state_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); if !services() .rooms @@ -1231,12 +1315,7 @@ pub(crate) async fn get_room_state_route( .into_values() .map(|id| { PduEvent::convert_to_outgoing_federation_event( - services() - .rooms - .timeline - .get_pdu_json(&id) - .unwrap() - .unwrap(), + services().rooms.timeline.get_pdu_json(&id).unwrap().unwrap(), ) }) .collect(); @@ -1250,7 +1329,9 @@ pub(crate) async fn get_room_state_route( Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids .filter_map(|id| { - if let Some(json) = services().rooms.timeline.get_pdu_json(&id).ok()? { + if let Some(json) = + services().rooms.timeline.get_pdu_json(&id).ok()? + { Some(PduEvent::convert_to_outgoing_federation_event(json)) } else { error!("Could not find event json for {id} in db."); @@ -1268,10 +1349,8 @@ pub(crate) async fn get_room_state_route( pub(crate) async fn get_room_state_ids_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); if !services() .rooms @@ -1332,10 +1411,8 @@ pub(crate) async fn create_join_event_template_route( )); } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); services() .rooms @@ -1353,22 +1430,26 @@ pub(crate) async fn create_join_event_template_route( ); let state_lock = mutex_state.lock().await; - // TODO: Grapevine does not implement restricted join rules yet, we always reject + // TODO: Grapevine does not implement restricted join rules yet, we always + // reject let join_rules_event = services().rooms.state_accessor.room_state_get( &body.room_id, &StateEventType::RoomJoinRules, "", )?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") + let join_rules_event_content: Option = + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err( + |e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }, + ) }) - }) - .transpose()?; + .transpose()?; if let Some(join_rules_event_content) = join_rules_event_content { if matches!( @@ -1382,7 +1463,8 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + let room_version_id = + services().rooms.state.get_room_version(&body.room_id)?; if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { @@ -1404,18 +1486,19 @@ pub(crate) async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = + services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + )?; drop(state_lock); @@ -1423,7 +1506,8 @@ pub(crate) async fn create_join_event_template_route( Ok(prepare_join_event::v1::Response { room_version: Some(room_version_id), - event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), + event: to_raw_value(&pdu_json) + .expect("CanonicalJson can be serialized to JSON"), }) } @@ -1440,27 +1524,28 @@ async fn create_join_event( )); } - services() - .rooms - .event_handler - .acl_check(sender_servername, room_id)?; + services().rooms.event_handler.acl_check(sender_servername, room_id)?; - // TODO: Grapevine does not implement restricted join rules yet, we always reject + // TODO: Grapevine does not implement restricted join rules yet, we always + // reject let join_rules_event = services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomJoinRules, "", )?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") + let join_rules_event_content: Option = + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err( + |e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }, + ) }) - }) - .transpose()?; + .transpose()?; if let Some(join_rules_event_content) = join_rules_event_content { if matches!( @@ -1474,21 +1559,21 @@ async fn create_join_event( } } - // We need to return the state prior to joining, let's keep a reference to that here - let shortstatehash = services() - .rooms - .state - .get_room_shortstatehash(room_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; + // We need to return the state prior to joining, let's keep a reference to + // that here + let shortstatehash = + services().rooms.state.get_room_shortstatehash(room_id)?.ok_or( + Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."), + )?; let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and hashes checks + // We do not add the event_id field to the pdu here because of signature and + // hashes checks let room_version_id = services().rooms.state.get_room_version(room_id)?; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { + let Ok((event_id, value)) = + gen_event_id_canonical_json(pdu, &room_version_id) + else { // Event could not be converted to canonical json return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -1503,7 +1588,9 @@ async fn create_join_event( ))?) .expect("CanonicalJson is valid json value"), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid.") + })?; let mutex = Arc::clone( services() @@ -1518,7 +1605,14 @@ async fn create_join_event( let pdu_id: Vec = services() .rooms .event_handler - .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .handle_incoming_pdu( + &origin, + &event_id, + room_id, + value, + true, + &pub_key_map, + ) .await? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, @@ -1526,11 +1620,8 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; + let state_ids = + services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let auth_chain_ids = services() .rooms .auth_chain @@ -1548,12 +1639,16 @@ async fn create_join_event( Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .filter_map(|id| { + services().rooms.timeline.get_pdu_json(&id).ok().flatten() + }) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() - .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .filter_map(|(_, id)| { + services().rooms.timeline.get_pdu_json(id).ok().flatten() + }) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), // TODO: handle restricted joins @@ -1567,14 +1662,15 @@ async fn create_join_event( pub(crate) async fn create_join_event_v1_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); - let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + let room_state = + create_join_event(sender_servername, &body.room_id, &body.pdu).await?; - Ok(create_join_event::v1::Response { room_state }) + Ok(create_join_event::v1::Response { + room_state, + }) } /// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}` @@ -1583,10 +1679,8 @@ pub(crate) async fn create_join_event_v1_route( pub(crate) async fn create_join_event_v2_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); let create_join_event::v1::RoomState { auth_chain, @@ -1601,19 +1695,20 @@ pub(crate) async fn create_join_event_v2_route( servers_in_room: None, }; - Ok(create_join_event::v2::Response { room_state }) + Ok(create_join_event::v2::Response { + room_state, + }) } /// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` /// /// Invites a remote user to a room. +#[allow(clippy::too_many_lines)] pub(crate) async fn create_invite_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); services() .rooms @@ -1633,8 +1728,13 @@ pub(crate) async fn create_invite_route( )); } - let mut signed_event = utils::to_canonical_object(&body.event) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; + let mut signed_event = + utils::to_canonical_object(&body.event).map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invite event is invalid.", + ) + })?; ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), @@ -1642,7 +1742,9 @@ pub(crate) async fn create_invite_route( &mut signed_event, &body.room_version, ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event.") + })?; // Generate event id let event_id = EventId::parse(format!( @@ -1668,7 +1770,9 @@ pub(crate) async fn create_invite_route( .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id.") + })?; let invited_user: Box<_> = serde_json::from_value( signed_event @@ -1680,12 +1784,22 @@ pub(crate) async fn create_invite_route( .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "state_key is not a user id.", + ) + })?; let mut invite_state = body.invite_room_state.clone(); let mut event: JsonObject = serde_json::from_str(body.event.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid invite event bytes.", + ) + })?; event.insert("event_id".to_owned(), "$dummy".into()); @@ -1696,7 +1810,8 @@ pub(crate) async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); - // If we are active in the room, the remote server will notify us about the join via /send + // If we are active in the room, the remote server will notify us about the + // join via /send if !services() .rooms .state_cache @@ -1730,10 +1845,8 @@ pub(crate) async fn get_devices_route( )); } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = + body.sender_servername.as_ref().expect("server is authenticated"); Ok(get_devices::v1::Response { user_id: body.user_id.clone(), @@ -1758,14 +1871,16 @@ pub(crate) async fn get_devices_route( }) }) .collect(), - master_key: services().users.get_master_key(None, &body.user_id, &|u| { - u.server_name() == sender_servername - })?, - self_signing_key: services() - .users - .get_self_signing_key(None, &body.user_id, &|u| { - u.server_name() == sender_servername - })?, + master_key: services().users.get_master_key( + None, + &body.user_id, + &|u| u.server_name() == sender_servername, + )?, + self_signing_key: services().users.get_self_signing_key( + None, + &body.user_id, + &|u| u.server_name() == sender_servername, + )?, }) } @@ -1775,14 +1890,10 @@ pub(crate) async fn get_devices_route( pub(crate) async fn get_room_information_route( body: Ruma, ) -> Result { - let room_id = services() - .rooms - .alias - .resolve_local_alias(&body.room_alias)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Room alias not found.", - ))?; + let room_id = + services().rooms.alias.resolve_local_alias(&body.room_alias)?.ok_or( + Error::BadRequest(ErrorKind::NotFound, "Room alias not found."), + )?; Ok(get_room_information::v1::Response { room_id, diff --git a/src/config.rs b/src/config.rs index 0f9a9c9c..2ea4c313 100644 --- a/src/config.rs +++ b/src/config.rs @@ -108,7 +108,10 @@ impl Config { } if was_deprecated { - warn!("Read grapevine documentation and check your configuration if any new configuration parameters should be adjusted"); + warn!( + "Read grapevine documentation and check your configuration if \ + any new configuration parameters should be adjusted" + ); } } } diff --git a/src/config/proxy.rs b/src/config/proxy.rs index 91ada136..bf365265 100644 --- a/src/config/proxy.rs +++ b/src/config/proxy.rs @@ -24,9 +24,10 @@ use crate::Result; /// ## Include vs. Exclude /// If include is an empty list, it is assumed to be `["*"]`. /// -/// If a domain matches both the exclude and include list, the proxy will only be used if it was -/// included because of a more specific rule than it was excluded. In the above example, the proxy -/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. +/// If a domain matches both the exclude and include list, the proxy will only +/// be used if it was included because of a more specific rule than it was +/// excluded. In the above example, the proxy would be used for +/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. #[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "snake_case")] #[derive(Default)] @@ -43,7 +44,9 @@ impl ProxyConfig { pub(crate) fn to_proxy(&self) -> Result> { Ok(match self.clone() { ProxyConfig::None => None, - ProxyConfig::Global { url } => Some(Proxy::all(url)?), + ProxyConfig::Global { + url, + } => Some(Proxy::all(url)?), ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { // first matching proxy proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() @@ -112,25 +115,32 @@ impl WildCardedDomain { WildCardedDomain::Exact(d) => domain == d, } } + pub(crate) fn more_specific_than(&self, other: &Self) -> bool { match (self, other) { (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { - a != b && a.ends_with(b) + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => { + other.matches(a) } + ( + WildCardedDomain::WildCarded(a), + WildCardedDomain::WildCarded(b), + ) => a != b && a.ends_with(b), _ => false, } } } impl std::str::FromStr for WildCardedDomain { type Err = std::convert::Infallible; + fn from_str(s: &str) -> Result { // maybe do some domain validation? Ok(s.strip_prefix("*.") .map(|x| WildCardedDomain::WildCarded(x.to_owned())) - .or_else(|| (s == "*").then(|| WildCardedDomain::WildCarded(String::new()))) + .or_else(|| { + (s == "*").then(|| WildCardedDomain::WildCarded(String::new())) + }) .unwrap_or_else(|| WildCardedDomain::Exact(s.to_owned()))) } } diff --git a/src/database.rs b/src/database.rs index ea964a52..ffbf6ef2 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,23 +1,6 @@ pub(crate) mod abstraction; pub(crate) mod key_value; -use crate::{ - service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services, - SERVICES, -}; -use abstraction::{KeyValueDatabaseEngine, KvTree}; -use lru_cache::LruCache; - -use ruma::{ - events::{ - push_rules::{PushRulesEvent, PushRulesEventContent}, - room::message::RoomMessageEventContent, - GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, - }, - push::Ruleset, - CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, - UserId, -}; use std::{ collections::{BTreeMap, HashMap, HashSet}, fs, @@ -27,8 +10,25 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; +use abstraction::{KeyValueDatabaseEngine, KvTree}; +use lru_cache::LruCache; +use ruma::{ + events::{ + push_rules::{PushRulesEvent, PushRulesEventContent}, + room::message::RoomMessageEventContent, + GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, + }, + push::Ruleset, + CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, + OwnedUserId, RoomId, UserId, +}; use tracing::{debug, error, info, warn}; +use crate::{ + service::rooms::timeline::PduCount, services, utils, Config, Error, + PduEvent, Result, Services, SERVICES, +}; + pub(crate) struct KeyValueDatabase { db: Arc, @@ -74,8 +74,9 @@ pub(crate) struct KeyValueDatabase { // Trees "owned" by `self::key_value::uiaa` // User-interactive authentication pub(super) userdevicesessionid_uiaainfo: Arc, - pub(super) userdevicesessionid_uiaarequest: - RwLock>, + pub(super) userdevicesessionid_uiaarequest: RwLock< + BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>, + >, // Trees "owned" by `self::key_value::rooms::edus` // ReadReceiptId = RoomId + Count + UserId @@ -169,13 +170,15 @@ pub(crate) struct KeyValueDatabase { pub(super) statehash_shortstatehash: Arc, - // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) + // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + + // (shortstatekey+shorteventid--) pub(super) shortstatehash_statediff: Arc, pub(super) shorteventid_authchain: Arc, /// RoomId + EventId -> outlier PDU. - /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. + /// Any pdu that has passed the steps 1-8 in the incoming event + /// /federation/send/txn. pub(super) eventid_outlierpdu: Arc, pub(super) softfailedeventids: Arc, @@ -214,10 +217,12 @@ pub(crate) struct KeyValueDatabase { // EduCount: Count of last EDU sync pub(super) servername_educount: Arc, - // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content + // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id + // (for edus), Data = EDU content pub(super) servernameevent_data: Arc, - // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content + // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for + // edus), Data = EDU content pub(super) servercurrentevent_data: Arc, // Trees "owned" by `self::key_value::appservice` @@ -231,10 +236,14 @@ pub(crate) struct KeyValueDatabase { pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, pub(super) eventidshort_cache: Mutex>, - pub(super) statekeyshort_cache: Mutex>, - pub(super) shortstatekey_cache: Mutex>, - pub(super) our_real_users_cache: RwLock>>>, - pub(super) appservice_in_room_cache: RwLock>>, + pub(super) statekeyshort_cache: + Mutex>, + pub(super) shortstatekey_cache: + Mutex>, + pub(super) our_real_users_cache: + RwLock>>>, + pub(super) appservice_in_room_cache: + RwLock>>, pub(super) lasttimelinecount_cache: Mutex>, } @@ -271,13 +280,15 @@ impl KeyValueDatabase { if sqlite_exists && config.database_backend != "sqlite" { return Err(Error::bad_config( - "Found sqlite at database_path, but is not specified in config.", + "Found sqlite at database_path, but is not specified in \ + config.", )); } if rocksdb_exists && config.database_backend != "rocksdb" { return Err(Error::bad_config( - "Found rocksdb at database_path, but is not specified in config.", + "Found rocksdb at database_path, but is not specified in \ + config.", )); } @@ -294,19 +305,30 @@ impl KeyValueDatabase { Self::check_db_setup(&config)?; if !Path::new(&config.database_path).exists() { - std::fs::create_dir_all(&config.database_path) - .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; + std::fs::create_dir_all(&config.database_path).map_err(|_| { + Error::BadConfig( + "Database folder doesn't exists and couldn't be created \ + (e.g. due to missing permissions). Please create the \ + database folder yourself.", + ) + })?; } #[cfg_attr( not(any(feature = "rocksdb", feature = "sqlite")), allow(unused_variables) )] - let builder: Arc = match &*config.database_backend { + let builder: Arc = match &*config + .database_backend + { #[cfg(feature = "sqlite")] - "sqlite" => Arc::new(Arc::::open(&config)?), + "sqlite" => { + Arc::new(Arc::::open(&config)?) + } #[cfg(feature = "rocksdb")] - "rocksdb" => Arc::new(Arc::::open(&config)?), + "rocksdb" => { + Arc::new(Arc::::open(&config)?) + } _ => { return Err(Error::BadConfig("Database backend not found.")); } @@ -327,28 +349,38 @@ impl KeyValueDatabase { userid_avatarurl: builder.open_tree("userid_avatarurl")?, userid_blurhash: builder.open_tree("userid_blurhash")?, userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, + userdeviceid_metadata: builder + .open_tree("userdeviceid_metadata")?, + userid_devicelistversion: builder + .open_tree("userid_devicelistversion")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, + onetimekeyid_onetimekeys: builder + .open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: builder + .open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?, keyid_key: builder.open_tree("keyid_key")?, userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + userid_selfsigningkeyid: builder + .open_tree("userid_selfsigningkeyid")?, + userid_usersigningkeyid: builder + .open_tree("userid_usersigningkeyid")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaainfo: builder + .open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, + readreceiptid_readreceipt: builder + .open_tree("readreceiptid_readreceipt")?, // "Private" read receipt - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, + roomuserid_privateread: builder + .open_tree("roomuserid_privateread")?, roomuserid_lastprivatereadupdate: builder .open_tree("roomuserid_lastprivatereadupdate")?, presenceid_presence: builder.open_tree("presenceid_presence")?, - userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, + userid_lastpresenceupdate: builder + .open_tree("userid_lastpresenceupdate")?, pduid_pdu: builder.open_tree("pduid_pdu")?, eventid_pduid: builder.open_tree("eventid_pduid")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, @@ -367,9 +399,12 @@ impl KeyValueDatabase { roomuserid_joined: builder.open_tree("roomuserid_joined")?, roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + roomuseroncejoinedids: builder + .open_tree("roomuseroncejoinedids")?, + userroomid_invitestate: builder + .open_tree("userroomid_invitestate")?, + roomuserid_invitecount: builder + .open_tree("roomuserid_invitecount")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, @@ -377,41 +412,57 @@ impl KeyValueDatabase { lazyloadedids: builder.open_tree("lazyloadedids")?, - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, + userroomid_notificationcount: builder + .open_tree("userroomid_notificationcount")?, + userroomid_highlightcount: builder + .open_tree("userroomid_highlightcount")?, + roomuserid_lastnotificationread: builder + .open_tree("userroomid_highlightcount")?, - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + statekey_shortstatekey: builder + .open_tree("statekey_shortstatekey")?, + shortstatekey_statekey: builder + .open_tree("shortstatekey_statekey")?, - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + shorteventid_authchain: builder + .open_tree("shorteventid_authchain")?, roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, + shortstatehash_statediff: builder + .open_tree("shortstatehash_statediff")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + shorteventid_shortstatehash: builder + .open_tree("shorteventid_shortstatehash")?, + roomid_shortstatehash: builder + .open_tree("roomid_shortstatehash")?, + roomsynctoken_shortstatehash: builder + .open_tree("roomsynctoken_shortstatehash")?, + statehash_shortstatehash: builder + .open_tree("statehash_shortstatehash")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, softfailedeventids: builder.open_tree("softfailedeventids")?, tofrom_relation: builder.open_tree("tofrom_relation")?, referencedevents: builder.open_tree("referencedevents")?, - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, + roomuserdataid_accountdata: builder + .open_tree("roomuserdataid_accountdata")?, + roomusertype_roomuserdataid: builder + .open_tree("roomusertype_roomuserdataid")?, mediaid_file: builder.open_tree("mediaid_file")?, backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_etag: builder.open_tree("backupid_etag")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, + userdevicetxnid_response: builder + .open_tree("userdevicetxnid_response")?, servername_educount: builder.open_tree("servername_educount")?, servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, + servercurrentevent_data: builder + .open_tree("servercurrentevent_data")?, + id_appserviceregistrations: builder + .open_tree("id_appserviceregistrations")?, senderkey_pusher: builder.open_tree("senderkey_pusher")?, global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, @@ -489,11 +540,13 @@ impl KeyValueDatabase { if !services().users.exists(&grapevine_user)? { error!( - "The {} server user does not exist, and the database is not new.", + "The {} server user does not exist, and the database is \ + not new.", grapevine_user ); return Err(Error::bad_database( - "Cannot reuse an existing database after changing the server name, please delete the old one first." + "Cannot reuse an existing database after changing the \ + server name, please delete the old one first.", )); } } @@ -505,14 +558,15 @@ impl KeyValueDatabase { // MIGRATIONS if services().globals.database_version()? < 1 { for (roomserverid, _) in db.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); + let mut parts = roomserverid.split(|&b| b == 0xFF); + let room_id = + parts.next().expect("split always returns one element"); let Some(servername) = parts.next() else { error!("Migration: Invalid roomserverid in db."); continue; }; let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); + serverroomid.push(0xFF); serverroomid.extend_from_slice(room_id); db.serverroomids.insert(&serverroomid, &[])?; @@ -524,13 +578,16 @@ impl KeyValueDatabase { } if services().globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" + // We accidentally inserted hashed versions of "" into the db + // instead of just "" for (userid, password) in db.userid_password.iter() { let password = utils::string_from_bytes(&password); - let empty_hashed_password = password.map_or(false, |password| { - argon2::verify_encoded(&password, b"").unwrap_or(false) - }); + let empty_hashed_password = + password.map_or(false, |password| { + argon2::verify_encoded(&password, b"") + .unwrap_or(false) + }); if empty_hashed_password { db.userid_password.insert(&userid, b"")?; @@ -567,10 +624,16 @@ impl KeyValueDatabase { if services().users.is_deactivated(&our_user)? { continue; } - for room in services().rooms.state_cache.rooms_joined(&our_user) { - for user in services().rooms.state_cache.room_members(&room?) { + for room in + services().rooms.state_cache.rooms_joined(&our_user) + { + for user in + services().rooms.state_cache.room_members(&room?) + { let user = user?; - if user.server_name() != services().globals.server_name() { + if user.server_name() + != services().globals.server_name() + { info!(?user, "Migration: creating user"); services().users.create(&user, None)?; } @@ -585,16 +648,18 @@ impl KeyValueDatabase { if services().globals.database_version()? < 5 { // Upgrade user data store - for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xff); + for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() + { + let mut parts = roomuserdataid.split(|&b| b == 0xFF); let room_id = parts.next().unwrap(); let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap(); + let event_type = + roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); let mut key = room_id.to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(event_type); db.roomusertype_roomuserdataid @@ -611,7 +676,10 @@ impl KeyValueDatabase { for (roomid, _) in db.roomid_shortstatehash.iter() { let string = utils::string_from_bytes(&roomid).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services().rooms.state_cache.update_joined_count(room_id)?; + services() + .rooms + .state_cache + .update_joined_count(room_id)?; } services().globals.bump_database_version(6)?; @@ -621,7 +689,8 @@ impl KeyValueDatabase { if services().globals.database_version()? < 7 { // Upgrade state store - let mut last_roomstates: HashMap = HashMap::new(); + let mut last_roomstates: HashMap = + HashMap::new(); let mut current_sstatehash: Option = None; let mut current_room = None; let mut current_state = HashSet::new(); @@ -633,7 +702,8 @@ impl KeyValueDatabase { current_state: HashSet<_>, last_roomstates: &mut HashMap<_, _>| { counter += 1; - let last_roomsstatehash = last_roomstates.get(current_room); + let last_roomsstatehash = + last_roomstates.get(current_room); let states_parents = last_roomsstatehash.map_or_else( || Ok(Vec::new()), @@ -641,12 +711,16 @@ impl KeyValueDatabase { services() .rooms .state_compressor - .load_shortstatehash_info(last_roomsstatehash) + .load_shortstatehash_info( + last_roomsstatehash, + ) }, )?; let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { + if let Some(parent_stateinfo) = + states_parents.last() + { let statediffnew = current_state .difference(&parent_stateinfo.1) .copied() @@ -663,21 +737,28 @@ impl KeyValueDatabase { (current_state, HashSet::new()) }; - services().rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - // every state change is 2 event changes on average - 2, - states_parents, - )?; + services() + .rooms + .state_compressor + .save_state_from_diff( + current_sstatehash, + Arc::new(statediffnew), + Arc::new(statediffremoved), + // every state change is 2 event changes on + // average + 2, + states_parents, + )?; Ok::<_, Error>(()) }; - for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) - .expect("number of bytes is correct"); + for (k, seventid) in + db.db.open_tree("stateid_shorteventid")?.iter() + { + let sstatehash = + utils::u64_from_bytes(&k[0..size_of::()]) + .expect("number of bytes is correct"); let sstatekey = k[size_of::()..].to_vec(); if Some(sstatehash) != current_sstatehash { if let Some(current_sstatehash) = current_sstatehash { @@ -687,15 +768,23 @@ impl KeyValueDatabase { current_state, &mut last_roomstates, )?; - last_roomstates - .insert(current_room.clone().unwrap(), current_sstatehash); + last_roomstates.insert( + current_room.clone().unwrap(), + current_sstatehash, + ); } current_state = HashSet::new(); current_sstatehash = Some(sstatehash); - let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); + let event_id = db + .shorteventid_eventid + .get(&seventid) + .unwrap() + .unwrap(); + let string = + utils::string_from_bytes(&event_id).unwrap(); + let event_id = + <&EventId>::try_from(string.as_str()).unwrap(); let pdu = services() .rooms .timeline @@ -710,7 +799,8 @@ impl KeyValueDatabase { let mut val = sstatekey; val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); + current_state + .insert(val.try_into().expect("size is correct")); } if let Some(current_sstatehash) = current_sstatehash { @@ -730,7 +820,8 @@ impl KeyValueDatabase { if services().globals.database_version()? < 8 { // Generate short room ids for all rooms for (room_id, _) in db.roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); + let shortroomid = + services().globals.next_count()?.to_be_bytes(); db.roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } @@ -739,7 +830,7 @@ impl KeyValueDatabase { if !key.starts_with(b"!") { return None; } - let mut parts = key.splitn(2, |&b| b == 0xff); + let mut parts = key.splitn(2, |&b| b == 0xFF); let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); @@ -757,25 +848,26 @@ impl KeyValueDatabase { db.pduid_pdu.insert_batch(&mut batch)?; - let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); + let mut batch2 = + db.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; + } + let mut parts = value.splitn(2, |&b| b == 0xFF); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); - let mut new_value = short_room_id; - new_value.extend_from_slice(count); + let mut new_value = short_room_id; + new_value.extend_from_slice(count); - Some((k, new_value)) - }); + Some((k, new_value)) + }); db.eventid_pduid.insert_batch(&mut batch2)?; @@ -793,7 +885,7 @@ impl KeyValueDatabase { if !key.starts_with(b"!") { return None; } - let mut parts = key.splitn(4, |&b| b == 0xff); + let mut parts = key.splitn(4, |&b| b == 0xFF); let room_id = parts.next().unwrap(); let word = parts.next().unwrap(); let _pdu_id_room = parts.next().unwrap(); @@ -806,7 +898,7 @@ impl KeyValueDatabase { .expect("shortroomid should exist"); let mut new_key = short_room_id; new_key.extend_from_slice(word); - new_key.push(0xff); + new_key.push(0xFF); new_key.extend_from_slice(pdu_id_count); Some((new_key, Vec::new())) }) @@ -836,12 +928,15 @@ impl KeyValueDatabase { if services().globals.database_version()? < 10 { // Add other direction for shortstatekeys - for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { + for (statekey, shortstatekey) in + db.statekey_shortstatekey.iter() + { db.shortstatekey_statekey .insert(&shortstatekey, &statekey)?; } - // Force E2EE device list updates so we can send them over federation + // Force E2EE device list updates so we can send them over + // federation for user_id in services().users.iter().filter_map(Result::ok) { services().users.mark_device_key_update(&user_id)?; } @@ -852,9 +947,7 @@ impl KeyValueDatabase { } if services().globals.database_version()? < 11 { - db.db - .open_tree("userdevicesessionid_uiaarequest")? - .clear()?; + db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; services().globals.bump_database_version(11)?; warn!("Migration: 10 -> 11 finished"); @@ -878,24 +971,34 @@ impl KeyValueDatabase { .get( None, &user, - GlobalAccountDataEventType::PushRules.to_string().into(), + GlobalAccountDataEventType::PushRules + .to_string() + .into(), ) .unwrap() .expect("Username is invalid"); let mut account_data = - serde_json::from_str::(raw_rules_list.get()).unwrap(); + serde_json::from_str::( + raw_rules_list.get(), + ) + .unwrap(); let rules_list = &mut account_data.content.global; //content rule { - let content_rule_transformation = - [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; + let content_rule_transformation = [ + ".m.rules.contains_user_name", + ".m.rule.contains_user_name", + ]; - let rule = rules_list.content.get(content_rule_transformation[0]); + let rule = rules_list + .content + .get(content_rule_transformation[0]); if rule.is_some() { let mut rule = rule.unwrap().clone(); - rule.rule_id = content_rule_transformation[1].to_owned(); + rule.rule_id = + content_rule_transformation[1].to_owned(); rules_list .content .shift_remove(content_rule_transformation[0]); @@ -907,7 +1010,10 @@ impl KeyValueDatabase { { let underride_rule_transformation = [ [".m.rules.call", ".m.rule.call"], - [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], + [ + ".m.rules.room_one_to_one", + ".m.rule.room_one_to_one", + ], [ ".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one", @@ -917,11 +1023,14 @@ impl KeyValueDatabase { ]; for transformation in underride_rule_transformation { - let rule = rules_list.underride.get(transformation[0]); + let rule = + rules_list.underride.get(transformation[0]); if let Some(rule) = rule { let mut rule = rule.clone(); rule.rule_id = transformation[1].to_owned(); - rules_list.underride.shift_remove(transformation[0]); + rules_list + .underride + .shift_remove(transformation[0]); rules_list.underride.insert(rule); } } @@ -930,8 +1039,11 @@ impl KeyValueDatabase { services().account_data.update( None, &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; } @@ -940,7 +1052,8 @@ impl KeyValueDatabase { warn!("Migration: 11 -> 12 finished"); } - // This migration can be reused as-is anytime the server-default rules are updated. + // This migration can be reused as-is anytime the server-default + // rules are updated. if services().globals.database_version()? < 13 { for username in services().users.list_local_users()? { let user = match UserId::parse_with_server_name( @@ -959,15 +1072,21 @@ impl KeyValueDatabase { .get( None, &user, - GlobalAccountDataEventType::PushRules.to_string().into(), + GlobalAccountDataEventType::PushRules + .to_string() + .into(), ) .unwrap() .expect("Username is invalid"); let mut account_data = - serde_json::from_str::(raw_rules_list.get()).unwrap(); + serde_json::from_str::( + raw_rules_list.get(), + ) + .unwrap(); - let user_default_rules = ruma::push::Ruleset::server_default(&user); + let user_default_rules = + ruma::push::Ruleset::server_default(&user); account_data .content .global @@ -976,8 +1095,11 @@ impl KeyValueDatabase { services().account_data.update( None, &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), + GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), )?; } @@ -1018,13 +1140,24 @@ impl KeyValueDatabase { match set_emergency_access() { Ok(pwd_set) => { if pwd_set { - warn!("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!"); - services().admin.send_message(RoomMessageEventContent::text_plain("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!")); + warn!( + "The Grapevine account emergency password is set! \ + Please unset it as soon as you finish admin account \ + recovery!" + ); + services().admin.send_message( + RoomMessageEventContent::text_plain( + "The Grapevine account emergency password is set! \ + Please unset it as soon as you finish admin \ + account recovery!", + ), + ); } } Err(e) => { error!( - "Could not set the configured emergency password for the grapevine user: {}", + "Could not set the configured emergency password for the \ + grapevine user: {}", e ); } @@ -1050,15 +1183,15 @@ impl KeyValueDatabase { #[tracing::instrument] pub(crate) async fn start_cleanup_task() { - use tokio::time::interval; + use std::time::{Duration, Instant}; #[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; + use tokio::time::interval; - use std::time::{Duration, Instant}; - - let timer_interval = - Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval)); + let timer_interval = Duration::from_secs(u64::from( + services().globals.config.cleanup_second_interval, + )); tokio::spawn(async move { let mut i = interval(timer_interval); @@ -1092,11 +1225,14 @@ impl KeyValueDatabase { } } -/// Sets the emergency password and push rules for the @grapevine account in case emergency password is set +/// Sets the emergency password and push rules for the @grapevine account in +/// case emergency password is set fn set_emergency_access() -> Result { - let grapevine_user = - UserId::parse_with_server_name("grapevine", services().globals.server_name()) - .expect("@grapevine:server_name is a valid UserId"); + let grapevine_user = UserId::parse_with_server_name( + "grapevine", + services().globals.server_name(), + ) + .expect("@grapevine:server_name is a valid UserId"); services().users.set_password( &grapevine_user, @@ -1113,7 +1249,9 @@ fn set_emergency_access() -> Result { &grapevine_user, GlobalAccountDataEventType::PushRules.to_string().into(), &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { global: ruleset }, + content: PushRulesEventContent { + global: ruleset, + }, }) .expect("to json value always works"), )?; diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 3b64332e..09f17720 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,8 +1,8 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + use super::Config; use crate::Result; -use std::{future::Future, pin::Pin, sync::Arc}; - #[cfg(feature = "sqlite")] pub(crate) mod sqlite; @@ -22,7 +22,8 @@ pub(crate) trait KeyValueDatabaseEngine: Send + Sync { Ok(()) } fn memory_usage(&self) -> Result { - Ok("Current database engine does not support memory usage reporting.".to_owned()) + Ok("Current database engine does not support memory usage reporting." + .to_owned()) } fn clear_caches(&self) {} } @@ -31,7 +32,10 @@ pub(crate) trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; - fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()>; + fn insert_batch( + &self, + iter: &mut dyn Iterator, Vec)>, + ) -> Result<()>; fn remove(&self, key: &[u8]) -> Result<()>; @@ -44,14 +48,20 @@ pub(crate) trait KvTree: Send + Sync { ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; - fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; + fn increment_batch( + &self, + iter: &mut dyn Iterator>, + ) -> Result<()>; fn scan_prefix<'a>( &'a self, prefix: Vec, ) -> Box, Vec)> + 'a>; - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; + fn watch_prefix<'a>( + &'a self, + prefix: &[u8], + ) -> Pin + Send + 'a>>; fn clear(&self) -> Result<()> { for (key, _) in self.iter() { diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 3c9c87b7..f73a4d0f 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,17 +1,21 @@ -use rocksdb::{ - perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache, - ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType, DBRecoveryMode, DBWithThreadMode, - Direction, IteratorMode, MultiThreaded, Options, ReadOptions, WriteOptions, -}; - -use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; -use crate::{utils, Result}; use std::{ future::Future, pin::Pin, sync::{Arc, RwLock}, }; +use rocksdb::{ + perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache, + ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType, + DBRecoveryMode, DBWithThreadMode, Direction, IteratorMode, MultiThreaded, + Options, ReadOptions, WriteOptions, +}; + +use super::{ + super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree, +}; +use crate::{utils, Result}; + pub(crate) struct Engine { rocks: DBWithThreadMode, max_open_files: i32, @@ -38,7 +42,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &Cache) -> Options { let mut db_opts = Options::default(); db_opts.set_block_based_table_factory(&block_based_options); db_opts.create_if_missing(true); - db_opts.increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX)); + db_opts + .increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX)); db_opts.set_max_open_files(max_open_files); db_opts.set_compression_type(DBCompressionType::Lz4); db_opts.set_bottommost_compression_type(DBCompressionType::Zstd); @@ -69,13 +74,17 @@ impl KeyValueDatabaseEngine for Arc { clippy::cast_sign_loss, clippy::cast_possible_truncation )] - let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; + let cache_capacity_bytes = + (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let rocksdb_cache = Cache::new_lru_cache(cache_capacity_bytes); let db_opts = db_options(config.rocksdb_max_open_files, &rocksdb_cache); - let cfs = DBWithThreadMode::::list_cf(&db_opts, &config.database_path) - .unwrap_or_default(); + let cfs = DBWithThreadMode::::list_cf( + &db_opts, + &config.database_path, + ) + .unwrap_or_default(); let db = DBWithThreadMode::::open_cf_descriptors( &db_opts, @@ -119,14 +128,14 @@ impl KeyValueDatabaseEngine for Arc { #[allow(clippy::as_conversions, clippy::cast_precision_loss)] fn memory_usage(&self) -> Result { - let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; + let stats = + get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; Ok(format!( - "Approximate memory usage of all the mem-tables: {:.3} MB\n\ - Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\ - Approximate memory usage of all the table readers: {:.3} MB\n\ - Approximate memory usage by cache: {:.3} MB\n\ - Approximate memory usage by cache pinned: {:.3} MB\n\ - ", + "Approximate memory usage of all the mem-tables: {:.3} \ + MB\nApproximate memory usage of un-flushed mem-tables: {:.3} \ + MB\nApproximate memory usage of all the table readers: {:.3} \ + MB\nApproximate memory usage by cache: {:.3} MB\nApproximate \ + memory usage by cache pinned: {:.3} MB\n", stats.mem_table_total as f64 / 1024.0 / 1024.0, stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, @@ -154,9 +163,7 @@ impl KvTree for RocksDbEngineTree<'_> { fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let writeoptions = WriteOptions::default(); let lock = self.write_lock.read().unwrap(); - self.db - .rocks - .put_cf_opt(&self.cf(), key, value, &writeoptions)?; + self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?; drop(lock); self.watchers.wake(key); @@ -164,12 +171,13 @@ impl KvTree for RocksDbEngineTree<'_> { Ok(()) } - fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + fn insert_batch( + &self, + iter: &mut dyn Iterator, Vec)>, + ) -> Result<()> { let writeoptions = WriteOptions::default(); for (key, value) in iter { - self.db - .rocks - .put_cf_opt(&self.cf(), key, value, &writeoptions)?; + self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?; } Ok(()) @@ -177,10 +185,7 @@ impl KvTree for RocksDbEngineTree<'_> { fn remove(&self, key: &[u8]) -> Result<()> { let writeoptions = WriteOptions::default(); - Ok(self - .db - .rocks - .delete_cf_opt(&self.cf(), key, &writeoptions)?) + Ok(self.db.rocks.delete_cf_opt(&self.cf(), key, &writeoptions)?) } fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { @@ -230,26 +235,26 @@ impl KvTree for RocksDbEngineTree<'_> { let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?; let new = utils::increment(old.as_deref()); - self.db - .rocks - .put_cf_opt(&self.cf(), key, &new, &writeoptions)?; + self.db.rocks.put_cf_opt(&self.cf(), key, &new, &writeoptions)?; drop(lock); Ok(new) } - fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { + fn increment_batch( + &self, + iter: &mut dyn Iterator>, + ) -> Result<()> { let readoptions = ReadOptions::default(); let writeoptions = WriteOptions::default(); let lock = self.write_lock.write().unwrap(); for key in iter { - let old = self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?; + let old = + self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?; let new = utils::increment(old.as_deref()); - self.db - .rocks - .put_cf_opt(&self.cf(), key, new, &writeoptions)?; + self.db.rocks.put_cf_opt(&self.cf(), key, new, &writeoptions)?; } drop(lock); @@ -277,7 +282,10 @@ impl KvTree for RocksDbEngineTree<'_> { ) } - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + fn watch_prefix<'a>( + &'a self, + prefix: &[u8], + ) -> Pin + Send + 'a>> { self.watchers.watch(prefix) } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 1dc8a47d..f73ba509 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,7 +1,3 @@ -use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; -use crate::{database::Config, Result}; -use parking_lot::{Mutex, MutexGuard}; -use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ cell::RefCell, future::Future, @@ -9,9 +5,15 @@ use std::{ pin::Pin, sync::Arc, }; + +use parking_lot::{Mutex, MutexGuard}; +use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use thread_local::ThreadLocal; use tracing::debug; +use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; +use crate::{database::Config, Result}; + thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); @@ -68,7 +70,11 @@ impl Engine { conn.pragma_update(Some(Main), "page_size", 2048)?; conn.pragma_update(Some(Main), "journal_mode", "WAL")?; conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; - conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; + conn.pragma_update( + Some(Main), + "cache_size", + -i64::from(cache_size_kb), + )?; conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; Ok(conn) @@ -79,18 +85,23 @@ impl Engine { } fn read_lock(&self) -> &Connection { - self.read_conn_tls - .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + self.read_conn_tls.get_or(|| { + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap() + }) } fn read_lock_iterator(&self) -> &Connection { - self.read_iterator_conn_tls - .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + self.read_iterator_conn_tls.get_or(|| { + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap() + }) } pub(crate) fn flush_wal(self: &Arc) -> Result<()> { - self.write_lock() - .pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; + self.write_lock().pragma_update( + Some(Main), + "wal_checkpoint", + "RESTART", + )?; Ok(()) } } @@ -108,7 +119,8 @@ impl KeyValueDatabaseEngine for Arc { // calculates cache-size per permanent connection // 1. convert MB to KiB - // 2. divide by permanent connections + permanent iter connections + write connection + // 2. divide by permanent connections + permanent iter connections + + // write connection // 3. round down to nearest integer #[allow( clippy::as_conversions, @@ -117,9 +129,11 @@ impl KeyValueDatabaseEngine for Arc { clippy::cast_sign_loss )] let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0) - / ((num_cpus::get() as f64 * 2.0) + 1.0)) as u32; + / ((num_cpus::get() as f64 * 2.0) + 1.0)) + as u32; - let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); + let writer = + Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); let arc = Arc::new(Engine { writer, @@ -133,7 +147,13 @@ impl KeyValueDatabaseEngine for Arc { } fn open_tree(&self, name: &str) -> Result> { - self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; + self.write_lock().execute( + &format!( + "CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY \ + KEY, \"value\" BLOB NOT NULL )" + ), + [], + )?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), @@ -161,14 +181,26 @@ pub(crate) struct SqliteTable { type TupleOfBytes = (Vec, Vec); impl SqliteTable { - fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { + fn get_with_guard( + &self, + guard: &Connection, + key: &[u8], + ) -> Result>> { Ok(guard - .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? + .prepare( + format!("SELECT value FROM {} WHERE key = ?", self.name) + .as_str(), + )? .query_row([key], |row| row.get(0)) .optional()?) } - fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { + fn insert_with_guard( + &self, + guard: &Connection, + key: &[u8], + value: &[u8], + ) -> Result<()> { guard.execute( format!( "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", @@ -222,7 +254,10 @@ impl KvTree for SqliteTable { Ok(()) } - fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + fn insert_batch( + &self, + iter: &mut dyn Iterator, Vec)>, + ) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; @@ -236,7 +271,10 @@ impl KvTree for SqliteTable { Ok(()) } - fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { + fn increment_batch( + &self, + iter: &mut dyn Iterator>, + ) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; @@ -282,7 +320,8 @@ impl KvTree for SqliteTable { let statement = Box::leak(Box::new( guard .prepare(&format!( - "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", + "SELECT key, value FROM {} WHERE key <= ? ORDER BY \ + key DESC", &self.name )) .unwrap(), @@ -292,7 +331,9 @@ impl KvTree for SqliteTable { let iterator = Box::new( statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .query_map([from], |row| { + Ok((row.get_unwrap(0), row.get_unwrap(1))) + }) .unwrap() .map(Result::unwrap), ); @@ -304,7 +345,8 @@ impl KvTree for SqliteTable { let statement = Box::leak(Box::new( guard .prepare(&format!( - "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + "SELECT key, value FROM {} WHERE key >= ? ORDER BY \ + key ASC", &self.name )) .unwrap(), @@ -314,7 +356,9 @@ impl KvTree for SqliteTable { let iterator = Box::new( statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .query_map([from], |row| { + Ok((row.get_unwrap(0), row.get_unwrap(1))) + }) .unwrap() .map(Result::unwrap), ); @@ -338,14 +382,20 @@ impl KvTree for SqliteTable { Ok(new) } - fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), ) } - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + fn watch_prefix<'a>( + &'a self, + prefix: &[u8], + ) -> Pin + Send + 'a>> { self.watchers.watch(prefix) } diff --git a/src/database/abstraction/watchers.rs b/src/database/abstraction/watchers.rs index c8737583..9f6e5a00 100644 --- a/src/database/abstraction/watchers.rs +++ b/src/database/abstraction/watchers.rs @@ -4,12 +4,14 @@ use std::{ pin::Pin, sync::RwLock, }; + use tokio::sync::watch; #[derive(Default)] pub(super) struct Watchers { #[allow(clippy::type_complexity)] - watchers: RwLock, (watch::Sender<()>, watch::Receiver<()>)>>, + watchers: + RwLock, (watch::Sender<()>, watch::Receiver<()>)>>, } impl Watchers { @@ -17,7 +19,8 @@ impl Watchers { &'a self, prefix: &[u8], ) -> Pin + Send + 'a>> { - let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { + let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) + { hash_map::Entry::Occupied(o) => o.get().1.clone(), hash_map::Entry::Vacant(v) => { let (tx, rx) = tokio::sync::watch::channel(()); @@ -31,6 +34,7 @@ impl Watchers { rx.changed().await.unwrap(); }) } + pub(super) fn wake(&self, key: &[u8]) { let watchers = self.watchers.read().unwrap(); let mut triggered = Vec::new(); diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index f08ade13..9033ac90 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -7,10 +7,13 @@ use ruma::{ RoomId, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::account_data::Data for KeyValueDatabase { - /// Places one event in the account data of the user and removes the previous entry. + /// Places one event in the account data of the user and removes the + /// previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] fn update( &self, @@ -24,13 +27,14 @@ impl service::account_data::Data for KeyValueDatabase { .unwrap_or_default() .as_bytes() .to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xff); + roomuserdataid + .extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.push(0xFF); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); let mut key = prefix; @@ -45,13 +49,13 @@ impl service::account_data::Data for KeyValueDatabase { self.roomuserdataid_accountdata.insert( &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), + &serde_json::to_vec(&data) + .expect("to_vec always works on json values"), )?; let prev = self.roomusertype_roomuserdataid.get(&key)?; - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; + self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; // Remove old entry if let Some(prev) = prev { @@ -74,17 +78,15 @@ impl service::account_data::Data for KeyValueDatabase { .unwrap_or_default() .as_bytes() .to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(kind.to_string().as_bytes()); self.roomusertype_roomuserdataid .get(&key)? .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() + self.roomuserdataid_accountdata.get(&roomuserdataid).transpose() }) .transpose()? .map(|data| { @@ -101,7 +103,8 @@ impl service::account_data::Data for KeyValueDatabase { room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>> { + ) -> Result>> + { let mut userdata = HashMap::new(); let mut prefix = room_id @@ -109,9 +112,9 @@ impl service::account_data::Data for KeyValueDatabase { .unwrap_or_default() .as_bytes() .to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); // Skip the data that's exactly at since, because we sent that last time let mut first_possible = prefix.clone(); @@ -124,14 +127,27 @@ impl service::account_data::Data for KeyValueDatabase { .map(|(k, v)| { Ok::<_, Error>(( RoomAccountDataEventType::from( - utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else( - || Error::bad_database("RoomUserData ID in db is invalid."), - )?) - .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, + utils::string_from_bytes( + k.rsplit(|&b| b == 0xFF).next().ok_or_else( + || { + Error::bad_database( + "RoomUserData ID in db is invalid.", + ) + }, + )?, + ) + .map_err(|_| { + Error::bad_database( + "RoomUserData ID in db is invalid.", + ) + })?, ), - serde_json::from_slice::>(&v).map_err(|_| { - Error::bad_database("Database contains invalid account data.") - })?, + serde_json::from_slice::>(&v) + .map_err(|_| { + Error::bad_database( + "Database contains invalid account data.", + ) + })?, )) }) { diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 57907f48..b1ef80c0 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -20,8 +20,7 @@ impl service::appservice::Data for KeyValueDatabase { /// /// * `service_name` - the name you send to register the service previously fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes())?; + self.id_appserviceregistrations.remove(service_name.as_bytes())?; Ok(()) } @@ -30,20 +29,25 @@ impl service::appservice::Data for KeyValueDatabase { .get(id.as_bytes())? .map(|bytes| { serde_yaml::from_slice(&bytes).map_err(|_| { - Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") + Error::bad_database( + "Invalid registration bytes in \ + id_appserviceregistrations.", + ) }) }) .transpose() } - fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map( - |(id, _)| { - utils::string_from_bytes(&id).map_err(|_| { - Error::bad_database("Invalid id bytes in id_appserviceregistrations.") - }) - }, - ))) + fn iter_ids<'a>( + &'a self, + ) -> Result> + 'a>> { + Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { + utils::string_from_bytes(&id).map_err(|_| { + Error::bad_database( + "Invalid id bytes in id_appserviceregistrations.", + ) + }) + }))) } fn all(&self) -> Result> { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 090efcb3..aab8e101 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -6,10 +6,13 @@ use lru_cache::LruCache; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, + UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; pub(crate) const COUNTER: &[u8] = b"c"; @@ -27,14 +30,18 @@ impl service::globals::Data for KeyValueDatabase { }) } - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + async fn watch( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xff); + userid_prefix.push(0xFF); let mut userdeviceid_prefix = userid_prefix.clone(); userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xff); + userdeviceid_prefix.push(0xFF); let mut futures = FuturesUnordered::new(); @@ -46,10 +53,10 @@ impl service::globals::Data for KeyValueDatabase { futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), + self.userroomid_notificationcount.watch_prefix(&userid_prefix), ); - futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); + futures + .push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in for room_id in services() @@ -70,17 +77,24 @@ impl service::globals::Data for KeyValueDatabase { let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xff); + roomid_prefix.push(0xFF); // PDUs futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); // EDUs futures.push(Box::pin(async move { - let _result = services().rooms.edus.typing.wait_for_update(&room_id).await; + let _result = services() + .rooms + .edus + .typing + .wait_for_update(&room_id) + .await; })); - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); + futures.push( + self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix), + ); // Key changes futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); @@ -90,12 +104,11 @@ impl service::globals::Data for KeyValueDatabase { roomuser_prefix.extend_from_slice(&userid_prefix); futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), + self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix), ); } - let mut globaluserdata_prefix = vec![0xff]; + let mut globaluserdata_prefix = vec![0xFF]; globaluserdata_prefix.extend_from_slice(&userid_prefix); futures.push( @@ -107,7 +120,8 @@ impl service::globals::Data for KeyValueDatabase { futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); // One time keys - futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); + futures + .push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(Box::pin(services().globals.rotate.watch())); @@ -126,10 +140,14 @@ impl service::globals::Data for KeyValueDatabase { let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len(); - let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len(); - let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); - let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); - let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); + let statekeyshort_cache = + self.statekeyshort_cache.lock().unwrap().len(); + let our_real_users_cache = + self.our_real_users_cache.read().unwrap().len(); + let appservice_in_room_cache = + self.appservice_in_room_cache.read().unwrap().len(); + let lasttimelinecount_cache = + self.lasttimelinecount_cache.lock().unwrap().len(); let mut response = format!( "\ @@ -194,27 +212,29 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" |s| Ok(s.clone()), )?; - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); utils::string_from_bytes( // 1. version - parts - .next() - .expect("splitn always returns at least one element"), + parts.next().expect("splitn always returns at least one element"), ) .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .and_then(|version| { // 2. key parts .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .ok_or_else(|| { + Error::bad_database("Invalid keypair format in database.") + }) .map(|key| (version, key)) }) .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + Ed25519KeyPair::from_der(key, version).map_err(|_| { + Error::bad_database("Private or public keys are invalid.") + }) }) } + fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } @@ -231,7 +251,10 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" .and_then(|keys| serde_json::from_slice(&keys).ok()) .unwrap_or_else(|| { // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + ServerSigningKeys::new( + origin.to_owned(), + MilliSecondsSinceUnixEpoch::now(), + ) }); let ServerSigningKeys { @@ -245,7 +268,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" self.server_signingkeys.insert( origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + &serde_json::to_vec(&keys) + .expect("serversigningkeys can be serialized"), )?; let mut tree = keys.verify_keys; @@ -258,7 +282,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" Ok(tree) } - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. fn signing_keys_for( &self, origin: &ServerName, @@ -283,8 +308,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" fn database_version(&self) -> Result { self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version) - .map_err(|_| Error::bad_database("Database version id is invalid.")) + utils::u64_from_bytes(&version).map_err(|_| { + Error::bad_database("Database version id is invalid.") + }) }) } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index c2b9a57f..84140c93 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -9,7 +9,9 @@ use ruma::{ OwnedRoomId, RoomId, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::key_backups::Data for KeyValueDatabase { fn create_backup( @@ -20,12 +22,13 @@ impl service::key_backups::Data for KeyValueDatabase { let version = services().globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); self.backupid_algorithm.insert( &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + &serde_json::to_vec(backup_metadata) + .expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag .insert(&key, &services().globals.next_count()?.to_be_bytes())?; @@ -34,13 +37,13 @@ impl service::key_backups::Data for KeyValueDatabase { fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); self.backupid_algorithm.remove(&key)?; self.backupid_etag.remove(&key)?; - key.push(0xff); + key.push(0xFF); for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { self.backupkeyid_backup.remove(&outdated_key)?; @@ -56,7 +59,7 @@ impl service::key_backups::Data for KeyValueDatabase { backup_metadata: &Raw, ) -> Result { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); if self.backupid_algorithm.get(&key)?.is_none() { @@ -73,9 +76,12 @@ impl service::key_backups::Data for KeyValueDatabase { Ok(version.to_owned()) } - fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + fn get_latest_backup_version( + &self, + user_id: &UserId, + ) -> Result> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -85,11 +91,13 @@ impl service::key_backups::Data for KeyValueDatabase { .next() .map(|(key, _)| { utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + .map_err(|_| { + Error::bad_database("backupid_algorithm key is invalid.") + }) }) .transpose() } @@ -99,7 +107,7 @@ impl service::key_backups::Data for KeyValueDatabase { user_id: &UserId, ) -> Result)>> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -109,33 +117,42 @@ impl service::key_backups::Data for KeyValueDatabase { .next() .map(|(key, value)| { let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + .map_err(|_| { + Error::bad_database("backupid_algorithm key is invalid.") + })?; Ok(( version, serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("Algorithm in backupid_algorithm is invalid.") + Error::bad_database( + "Algorithm in backupid_algorithm is invalid.", + ) })?, )) }) .transpose() } - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + fn get_backup( + &self, + user_id: &UserId, + version: &str, + ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) + self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| { + serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database( + "Algorithm in backupid_algorithm is invalid.", + ) }) + }) } fn add_key( @@ -147,7 +164,7 @@ impl service::key_backups::Data for KeyValueDatabase { key_data: &Raw, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); if self.backupid_algorithm.get(&key)?.is_none() { @@ -160,9 +177,9 @@ impl service::key_backups::Data for KeyValueDatabase { self.backupid_etag .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(session_id.as_bytes()); self.backupkeyid_backup @@ -173,7 +190,7 @@ impl service::key_backups::Data for KeyValueDatabase { fn count_keys(&self, user_id: &UserId, version: &str) -> Result { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(version.as_bytes()); Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) @@ -181,7 +198,7 @@ impl service::key_backups::Data for KeyValueDatabase { fn get_etag(&self, user_id: &UserId, version: &str) -> Result { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); Ok(utils::u64_from_bytes( @@ -200,40 +217,56 @@ impl service::key_backups::Data for KeyValueDatabase { version: &str, ) -> Result> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); let mut rooms = BTreeMap::::new(); - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xff); + for result in + self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); - let session_id = - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) - .map_err(|_| { - Error::bad_database("backupkeyid_backup session_id is invalid.") - })?; - - let room_id = RoomId::parse( - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, + let session_id = utils::string_from_bytes( + parts.next().ok_or_else(|| { + Error::bad_database( + "backupkeyid_backup key is invalid.", + ) + })?, ) .map_err(|_| { - Error::bad_database("backupkeyid_backup room_id is invalid room id.") + Error::bad_database( + "backupkeyid_backup session_id is invalid.", + ) })?; - let key_data = serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + let room_id = RoomId::parse( + utils::string_from_bytes(parts.next().ok_or_else( + || { + Error::bad_database( + "backupkeyid_backup key is invalid.", + ) + }, + )?) + .map_err(|_| { + Error::bad_database( + "backupkeyid_backup room_id is invalid.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "backupkeyid_backup room_id is invalid room id.", + ) })?; + let key_data = + serde_json::from_slice(&value).map_err(|_| { + Error::bad_database( + "KeyBackupData in backupkeyid_backup is invalid.", + ) + })?; + Ok::<_, Error>((room_id, session_id, key_data)) }) { @@ -257,30 +290,38 @@ impl service::key_backups::Data for KeyValueDatabase { room_id: &RoomId, ) -> Result>> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); Ok(self .backupkeyid_backup .scan_prefix(prefix) .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xff); + let mut parts = key.rsplit(|&b| b == 0xFF); - let session_id = - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) - .map_err(|_| { - Error::bad_database("backupkeyid_backup session_id is invalid.") - })?; - - let key_data = serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + let session_id = utils::string_from_bytes( + parts.next().ok_or_else(|| { + Error::bad_database( + "backupkeyid_backup key is invalid.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "backupkeyid_backup session_id is invalid.", + ) })?; + let key_data = + serde_json::from_slice(&value).map_err(|_| { + Error::bad_database( + "KeyBackupData in backupkeyid_backup is invalid.", + ) + })?; + Ok::<_, Error>((session_id, key_data)) }) .filter_map(Result::ok) @@ -295,18 +336,20 @@ impl service::key_backups::Data for KeyValueDatabase { session_id: &str, ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(session_id.as_bytes()); self.backupkeyid_backup .get(&key)? .map(|value| { serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + Error::bad_database( + "KeyBackupData in backupkeyid_backup is invalid.", + ) }) }) .transpose() @@ -314,9 +357,9 @@ impl service::key_backups::Data for KeyValueDatabase { fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); - key.push(0xff); + key.push(0xFF); for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { self.backupkeyid_backup.remove(&outdated_key)?; @@ -325,13 +368,18 @@ impl service::key_backups::Data for KeyValueDatabase { Ok(()) } - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + fn delete_room_keys( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + key.push(0xFF); for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { self.backupkeyid_backup.remove(&outdated_key)?; @@ -348,11 +396,11 @@ impl service::key_backups::Data for KeyValueDatabase { session_id: &str, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(version.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(session_id.as_bytes()); for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 6abe5ba5..57154a51 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -12,22 +12,19 @@ impl service::media::Data for KeyValueDatabase { content_type: Option<&str>, ) -> Result> { let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice( content_disposition .as_ref() .map(|f| f.as_bytes()) .unwrap_or_default(), ); - key.push(0xff); + key.push(0xFF); key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), + content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default(), ); self.mediaid_file.insert(&key, &[])?; @@ -42,24 +39,25 @@ impl service::media::Data for KeyValueDatabase { height: u32, ) -> Result<(Option, Option, Vec)> { let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(&width.to_be_bytes()); prefix.extend_from_slice(&height.to_be_bytes()); - prefix.push(0xff); + prefix.push(0xFF); - let (key, _) = self - .mediaid_file - .scan_prefix(prefix) - .next() - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + let (key, _) = + self.mediaid_file.scan_prefix(prefix).next().ok_or( + Error::BadRequest(ErrorKind::NotFound, "Media not found"), + )?; - let mut parts = key.rsplit(|&b| b == 0xff); + let mut parts = key.rsplit(|&b| b == 0xFF); let content_type = parts .next() .map(|bytes| { utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") + Error::bad_database( + "Content type in mediaid_file is invalid unicode.", + ) }) }) .transpose()?; @@ -71,11 +69,14 @@ impl service::media::Data for KeyValueDatabase { let content_disposition = if content_disposition_bytes.is_empty() { None } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database("Content Disposition in mediaid_file is invalid unicode.") - })?, - ) + Some(utils::string_from_bytes(content_disposition_bytes).map_err( + |_| { + Error::bad_database( + "Content Disposition in mediaid_file is invalid \ + unicode.", + ) + }, + )?) }; Ok((content_disposition, content_type, key)) } diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index cf61a4a0..bd80288a 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -6,30 +6,39 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::pusher::Data for KeyValueDatabase { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + fn set_pusher( + &self, + sender: &UserId, + pusher: set_pusher::v3::PusherAction, + ) -> Result<()> { match &pusher { set_pusher::v3::PusherAction::Post(data) => { let mut key = sender.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); self.senderkey_pusher.insert( &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + &serde_json::to_vec(&pusher) + .expect("Pusher is valid JSON value"), )?; Ok(()) } set_pusher::v3::PusherAction::Delete(ids) => { let mut key = sender.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(ids.pushkey.as_bytes()); self.senderkey_pusher.remove(&key).map_err(Into::into) } } } - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + fn get_pusher( + &self, + sender: &UserId, + pushkey: &str, + ) -> Result> { let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xff); + senderkey.push(0xFF); senderkey.extend_from_slice(pushkey.as_bytes()); self.senderkey_pusher @@ -43,7 +52,7 @@ impl service::pusher::Data for KeyValueDatabase { fn get_pushers(&self, sender: &UserId) -> Result> { let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); self.senderkey_pusher .scan_prefix(prefix) @@ -59,16 +68,20 @@ impl service::pusher::Data for KeyValueDatabase { sender: &UserId, ) -> Box> + 'a> { let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xff); + let mut parts = k.splitn(2, |&b| b == 0xFF); let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; + let push_key = parts.next().ok_or_else(|| { + Error::bad_database("Invalid senderkey_pusher in db") + })?; + let push_key_string = + utils::string_from_bytes(push_key).map_err(|_| { + Error::bad_database( + "Invalid pusher bytes in senderkey_pusher", + ) + })?; Ok(push_key_string) })) diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 5058cb87..a1355ed2 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,6 +1,11 @@ -use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; +use ruma::{ + api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, + RoomId, +}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::rooms::alias::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] @@ -8,17 +13,20 @@ impl service::rooms::alias::Data for KeyValueDatabase { self.alias_roomid .insert(alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xff); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + aliasid.push(0xFF); + aliasid + .extend_from_slice(&services().globals.next_count()?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; Ok(()) } #[tracing::instrument(skip(self))] fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { + if let Some(room_id) = + self.alias_roomid.get(alias.alias().as_bytes())? + { let mut prefix = room_id.clone(); - prefix.push(0xff); + prefix.push(0xFF); for (key, _) in self.aliasid_alias.scan_prefix(prefix) { self.aliasid_alias.remove(&key)?; @@ -34,14 +42,23 @@ impl service::rooms::alias::Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + fn resolve_local_alias( + &self, + alias: &RoomAliasId, + ) -> Result> { self.alias_roomid .get(alias.alias().as_bytes())? .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in alias_roomid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) + RoomId::parse(utils::string_from_bytes(&bytes).map_err( + |_| { + Error::bad_database( + "Room ID in alias_roomid is invalid unicode.", + ) + }, + )?) + .map_err(|_| { + Error::bad_database("Room ID in alias_roomid is invalid.") + }) }) .transpose() } @@ -52,13 +69,17 @@ impl service::rooms::alias::Data for KeyValueDatabase { room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? + .map_err(|_| { + Error::bad_database("Invalid alias bytes in aliasid_alias.") + })? .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) + .map_err(|_| { + Error::bad_database("Invalid alias in aliasid_alias.") + }) })) } } diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 60057ac1..dee4269e 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -3,9 +3,13 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{database::KeyValueDatabase, service, utils, Result}; impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { + fn get_cached_eventid_authchain( + &self, + key: &[u64], + ) -> Result>>> { // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) + { return Ok(Some(Arc::clone(result))); } @@ -18,7 +22,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { .map(|chain| { chain .chunks_exact(size_of::()) - .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) + .map(|chunk| { + utils::u64_from_bytes(chunk) + .expect("byte length is correct") + }) .collect() }); @@ -38,7 +45,11 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { Ok(None) } - fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + fn cache_auth_chain( + &self, + key: Vec, + auth_chain: Arc>, + ) -> Result<()> { // Only persist single events in db if key.len() == 1 { self.shorteventid_authchain.insert( @@ -51,10 +62,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { } // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(key, auth_chain); + self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); Ok(()) } diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 9ed62582..58964927 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -19,14 +19,18 @@ impl service::rooms::directory::Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn public_rooms<'a>(&'a self) -> Box> + 'a> { + fn public_rooms<'a>( + &'a self, + ) -> Box> + 'a> { Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) + RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Room ID in publicroomids is invalid unicode.", + ) + })?) + .map_err(|_| { + Error::bad_database("Room ID in publicroomids is invalid.") + }) })) } } diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 5b5b54aa..28b66c27 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,10 +1,13 @@ use std::mem; use ruma::{ - events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, + events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, + OwnedUserId, RoomId, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { fn readreceipt_update( @@ -14,7 +17,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { event: ReceiptEvent, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -25,7 +28,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { .iter_from(&last_possible_key, true) .take_while(|(key, _)| key.starts_with(&prefix)) .find(|(key, _)| { - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element") == user_id.as_bytes() @@ -36,13 +39,15 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xff); + room_latest_id + .extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); self.readreceiptid_readreceipt.insert( &room_latest_id, - &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), + &serde_json::to_vec(&event) + .expect("EduEvent::to_string always works"), )?; Ok(()) @@ -64,7 +69,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { > + 'a, > { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let prefix2 = prefix.clone(); let mut first_possible_edu = prefix.clone(); @@ -79,21 +84,35 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { let count = utils::u64_from_bytes( &k[prefix.len()..prefix.len() + mem::size_of::()], ) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + .map_err(|_| { + Error::bad_database( + "Invalid readreceiptid count in db.", + ) + })?; let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) - .map_err(|_| { - Error::bad_database("Invalid readreceiptid userid bytes in db.") - })?, + utils::string_from_bytes( + &k[prefix.len() + mem::size_of::() + 1..], + ) + .map_err(|_| { + Error::bad_database( + "Invalid readreceiptid userid bytes in db.", + ) + })?, ) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; + .map_err(|_| { + Error::bad_database( + "Invalid readreceiptid userid in db.", + ) + })?; let mut json = - serde_json::from_slice::(&v).map_err(|_| { - Error::bad_database( - "Read receipt in roomlatestid_roomlatest is invalid json.", - ) - })?; + serde_json::from_slice::(&v) + .map_err(|_| { + Error::bad_database( + "Read receipt in roomlatestid_roomlatest \ + is invalid json.", + ) + })?; json.remove("room_id"); Ok(( @@ -109,36 +128,46 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + fn private_read_set( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<()> { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate .insert(&key, &services().globals.next_count()?.to_be_bytes()) } #[tracing::instrument(skip(self))] - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + fn private_read_get( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result> { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { - Error::bad_database("Invalid private read marker bytes") - })?)) - }) + self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| { + Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { + Error::bad_database("Invalid private read marker bytes") + })?)) + }) } - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn last_privateread_update( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); Ok(self @@ -146,7 +175,9 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { .get(&key)? .map(|bytes| { utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") + Error::bad_database( + "Count in roomuserid_lastprivatereadupdate is invalid.", + ) }) }) .transpose()? diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index a19d52cb..01952379 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -11,11 +11,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase { ll_user: &UserId, ) -> Result { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(device_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(ll_user.as_bytes()); Ok(self.lazyloadedids.get(&key)?.is_some()) } @@ -28,11 +28,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase { confirmed_user_ids: &mut dyn Iterator, ) -> Result<()> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); for ll_id in confirmed_user_ids { let mut key = prefix.clone(); @@ -50,11 +50,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase { room_id: &RoomId, ) -> Result<()> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); for (key, _) in self.lazyloadedids.scan_prefix(prefix) { self.lazyloadedids.remove(&key)?; diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 6a38f08b..ab7a5cb8 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,6 +1,8 @@ use ruma::{OwnedRoomId, RoomId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::rooms::metadata::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] @@ -19,14 +21,18 @@ impl service::rooms::metadata::Data for KeyValueDatabase { .is_some()) } - fn iter_ids<'a>(&'a self) -> Box> + 'a> { + fn iter_ids<'a>( + &'a self, + ) -> Box> + 'a> { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Room ID in publicroomids is invalid unicode.", + ) + })?) + .map_err(|_| { + Error::bad_database("Room ID in roomid_shortroomid is invalid.") + }) })) } diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index f4269770..f41a02ee 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -3,24 +3,35 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; impl service::rooms::outlier::Data for KeyValueDatabase { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + fn get_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { + self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or( + Ok(None), + |pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + }, + ) } fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or( + Ok(None), + |pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + }, + ) } #[tracing::instrument(skip(self, pdu))] - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + fn add_pdu_outlier( + &self, + event_id: &EventId, + pdu: &CanonicalJsonObject, + ) -> Result<()> { self.eventid_outlierpdu.insert( event_id.as_bytes(), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 0641f9d8..b890985f 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -22,7 +22,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { shortroomid: u64, target: u64, until: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> + { let prefix = target.to_be_bytes().to_vec(); let mut current = prefix.clone(); @@ -40,8 +41,12 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { .iter_from(¤t, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; + let from = utils::u64_from_bytes( + &tofrom[(mem::size_of::())..], + ) + .map_err(|_| { + Error::bad_database("Invalid count in tofrom_relation.") + })?; let mut pduid = shortroomid.to_be_bytes().to_vec(); pduid.extend_from_slice(&from.to_be_bytes()); @@ -50,7 +55,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { .rooms .timeline .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; + .ok_or_else(|| { + Error::bad_database( + "Pdu in tofrom_relation is invalid.", + ) + })?; if pdu.sender != user_id { pdu.remove_transaction_id()?; } @@ -59,7 +68,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { )) } - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + fn mark_as_referenced( + &self, + room_id: &RoomId, + event_ids: &[Arc], + ) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); @@ -69,7 +82,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { Ok(()) } - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn is_event_referenced( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(event_id.as_bytes()); Ok(self.referencedevents.get(&key)?.is_some()) @@ -80,8 +97,6 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { } fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some()) } } diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index a2eba329..bd14f312 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -4,7 +4,12 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result}; impl service::rooms::search::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + fn index_pdu( + &self, + shortroomid: u64, + pdu_id: &[u8], + message_body: &str, + ) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) @@ -13,7 +18,7 @@ impl service::rooms::search::Data for KeyValueDatabase { .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); - key.push(0xff); + key.push(0xFF); // TODO: currently we save the room id a second time here key.extend_from_slice(pdu_id); (key, Vec::new()) @@ -28,7 +33,8 @@ impl service::rooms::search::Data for KeyValueDatabase { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a>, Vec)>> { + ) -> Result> + 'a>, Vec)>> + { let prefix = services() .rooms .short @@ -46,7 +52,7 @@ impl service::rooms::search::Data for KeyValueDatabase { let iterators = words.clone().into_iter().map(move |word| { let mut prefix2 = prefix.clone(); prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xff); + prefix2.push(0xFF); let prefix3 = prefix2.clone(); let mut last_possible_id = prefix2.clone(); @@ -60,7 +66,9 @@ impl service::rooms::search::Data for KeyValueDatabase { }); // We compare b with a because we reversed the iterator earlier - let Some(common_elements) = utils::common_elements(iterators, |a, b| b.cmp(a)) else { + let Some(common_elements) = + utils::common_elements(iterators, |a, b| b.cmp(a)) + else { return Ok(None); }; diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index b8186561..2dd04f5e 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -2,26 +2,32 @@ use std::sync::Arc; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::rooms::short::Data for KeyValueDatabase { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { + if let Some(short) = + self.eventidshort_cache.lock().unwrap().get_mut(event_id) + { return Ok(*short); } - let short = - if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { - utils::u64_from_bytes(&shorteventid) - .map_err(|_| Error::bad_database("Invalid shorteventid in db."))? - } else { - let shorteventid = services().globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - }; + let short = if let Some(shorteventid) = + self.eventid_shorteventid.get(event_id.as_bytes())? + { + utils::u64_from_bytes(&shorteventid).map_err(|_| { + Error::bad_database("Invalid shorteventid in db.") + })? + } else { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + }; self.eventidshort_cache .lock() @@ -46,15 +52,16 @@ impl service::rooms::short::Data for KeyValueDatabase { } let mut db_key = event_type.to_string().as_bytes().to_vec(); - db_key.push(0xff); + db_key.push(0xFF); db_key.extend_from_slice(state_key.as_bytes()); let short = self .statekey_shortstatekey .get(&db_key)? .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + utils::u64_from_bytes(&shortstatekey).map_err(|_| { + Error::bad_database("Invalid shortstatekey in db.") + }) }) .transpose()?; @@ -83,12 +90,15 @@ impl service::rooms::short::Data for KeyValueDatabase { } let mut db_key = event_type.to_string().as_bytes().to_vec(); - db_key.push(0xff); + db_key.push(0xFF); db_key.extend_from_slice(state_key.as_bytes()); - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&db_key)? { - utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? + let short = if let Some(shortstatekey) = + self.statekey_shortstatekey.get(&db_key)? + { + utils::u64_from_bytes(&shortstatekey).map_err(|_| { + Error::bad_database("Invalid shortstatekey in db.") + })? } else { let shortstatekey = services().globals.next_count()?; self.statekey_shortstatekey @@ -106,12 +116,12 @@ impl service::rooms::short::Data for KeyValueDatabase { Ok(short) } - fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - if let Some(id) = self - .shorteventid_cache - .lock() - .unwrap() - .get_mut(&shorteventid) + fn get_eventid_from_short( + &self, + shorteventid: u64, + ) -> Result> { + if let Some(id) = + self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { return Ok(Arc::clone(id)); } @@ -119,12 +129,20 @@ impl service::rooms::short::Data for KeyValueDatabase { let bytes = self .shorteventid_eventid .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + .ok_or_else(|| { + Error::bad_database("Shorteventid does not exist") + })?; - let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + let event_id = EventId::parse_arc( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "EventID in shorteventid_eventid is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database("EventId in shorteventid_eventid is invalid.") + })?; self.shorteventid_cache .lock() @@ -134,12 +152,12 @@ impl service::rooms::short::Data for KeyValueDatabase { Ok(event_id) } - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - if let Some(id) = self - .shortstatekey_cache - .lock() - .unwrap() - .get_mut(&shortstatekey) + fn get_statekey_from_short( + &self, + shortstatekey: u64, + ) -> Result<(StateEventType, String)> { + if let Some(id) = + self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) { return Ok(id.clone()); } @@ -147,23 +165,32 @@ impl service::rooms::short::Data for KeyValueDatabase { let bytes = self .shortstatekey_statekey .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + .ok_or_else(|| { + Error::bad_database("Shortstatekey does not exist") + })?; - let mut parts = bytes.splitn(2, |&b| b == 0xff); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = - StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { - Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") - })?); - - let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { - Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") + let mut parts = bytes.splitn(2, |&b| b == 0xFF); + let eventtype_bytes = + parts.next().expect("split always returns one entry"); + let statekey_bytes = parts.next().ok_or_else(|| { + Error::bad_database("Invalid statekey in shortstatekey_statekey.") })?; + let event_type = StateEventType::from( + utils::string_from_bytes(eventtype_bytes).map_err(|_| { + Error::bad_database( + "Event type in shortstatekey_statekey is invalid unicode.", + ) + })?, + ); + + let state_key = + utils::string_from_bytes(statekey_bytes).map_err(|_| { + Error::bad_database( + "Statekey in shortstatekey_statekey is invalid unicode.", + ) + })?; + let result = (event_type, state_key); self.shortstatekey_cache @@ -175,12 +202,18 @@ impl service::rooms::short::Data for KeyValueDatabase { } /// Returns `(shortstatehash, already_existed)` - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + fn get_or_create_shortstatehash( + &self, + state_hash: &[u8], + ) -> Result<(u64, bool)> { Ok( - if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { + if let Some(shortstatehash) = + self.statehash_shortstatehash.get(state_hash)? + { ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + utils::u64_from_bytes(&shortstatehash).map_err(|_| { + Error::bad_database("Invalid shortstatehash in db.") + })?, true, ) } else { @@ -196,17 +229,21 @@ impl service::rooms::short::Data for KeyValueDatabase { self.roomid_shortroomid .get(room_id.as_bytes())? .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortroomid in db.") + }) }) .transpose() } fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { Ok( - if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))? + if let Some(short) = + self.roomid_shortroomid.get(room_id.as_bytes())? + { + utils::u64_from_bytes(&short).map_err(|_| { + Error::bad_database("Invalid shortroomid in db.") + })? } else { let short = services().globals.next_count()?; self.roomid_shortroomid diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index c7e042d2..78863a75 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,20 +1,22 @@ -use ruma::{EventId, OwnedEventId, RoomId}; -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; -use std::sync::Arc; +use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { + self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or( + Ok(None), + |bytes| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + Error::bad_database( + "Invalid shortstatehash in roomid_shortstatehash", + ) })?)) - }) + }, + ) } fn set_room_state( @@ -29,23 +31,40 @@ impl service::rooms::state::Data for KeyValueDatabase { Ok(()) } - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + fn set_event_state( + &self, + shorteventid: u64, + shortstatehash: u64, + ) -> Result<()> { + self.shorteventid_shortstatehash.insert( + &shorteventid.to_be_bytes(), + &shortstatehash.to_be_bytes(), + )?; Ok(()) } - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + fn get_forward_extremities( + &self, + room_id: &RoomId, + ) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); self.roomid_pduleaves .scan_prefix(prefix) .map(|(_, bytes)| { - EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + EventId::parse_arc(utils::string_from_bytes(&bytes).map_err( + |_| { + Error::bad_database( + "EventID in roomid_pduleaves is invalid unicode.", + ) + }, + )?) + .map_err(|_| { + Error::bad_database( + "EventId in roomid_pduleaves is invalid.", + ) + }) }) .collect() } @@ -58,7 +77,7 @@ impl service::rooms::state::Data for KeyValueDatabase { _mutex_lock: &MutexGuard<'_, ()>, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { self.roomid_pduleaves.remove(&key)?; diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index dc5a112d..ed763bf3 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,12 +1,19 @@ use std::{collections::HashMap, sync::Arc}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, PduEvent, + Result, +}; + #[async_trait] impl service::rooms::state_accessor::Data for KeyValueDatabase { - async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + async fn state_full_ids( + &self, + shortstatehash: u64, + ) -> Result>> { let full_state = services() .rooms .state_compressor @@ -56,7 +63,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { pdu.kind.to_string().into(), pdu.state_key .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? + .ok_or_else(|| { + Error::bad_database( + "State event has no state key.", + ) + })? .clone(), ), pdu, @@ -72,17 +83,16 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { Ok(result) } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { - let Some(shortstatekey) = services() - .rooms - .short - .get_shortstatekey(event_type, state_key)? + let Some(shortstatekey) = + services().rooms.short.get_shortstatekey(event_type, state_key)? else { return Ok(None); }; @@ -106,7 +116,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { })) } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn state_get( &self, shortstatehash: u64, @@ -121,20 +132,22 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { /// Returns the state hash for this pdu. fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { + self.eventid_shorteventid.get(event_id.as_bytes())?.map_or( + Ok(None), + |shorteventid| { self.shorteventid_shortstatehash .get(&shorteventid)? .map(|bytes| { utils::u64_from_bytes(&bytes).map_err(|_| { Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", + "Invalid shortstatehash bytes in \ + shorteventid_shortstatehash", ) }) }) .transpose() - }) + }, + ) } /// Returns the full room state. @@ -151,7 +164,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { } } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn room_state_get_id( &self, room_id: &RoomId, @@ -167,7 +181,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { } } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn room_state_get( &self, room_id: &RoomId, diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index a1a5025a..776f10bd 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -13,20 +13,24 @@ use crate::{ }; impl service::rooms::state_cache::Data for KeyValueDatabase { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + fn mark_as_once_joined( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); self.roomuseroncejoinedids.insert(&userroom_id, &[]) } fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); self.userroomid_joined.insert(&userroom_id, &[])?; @@ -46,11 +50,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { last_state: Option>>, ) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); self.userroomid_invitestate.insert( @@ -72,11 +76,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); // TODO @@ -112,7 +116,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { joinedcount += 1; } - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { + for _invited in + self.room_members_invited(room_id).filter_map(Result::ok) + { invitedcount += 1; } @@ -127,15 +133,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .unwrap() .insert(room_id.to_owned(), Arc::new(real_users)); - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { + for old_joined_server in + self.room_servers(room_id).filter_map(Result::ok) + { if !joined_servers.remove(&old_joined_server) { // Server not in room anymore let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); + roomserver_id.push(0xFF); roomserver_id.extend_from_slice(old_joined_server.as_bytes()); let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xff); + serverroom_id.push(0xFF); serverroom_id.extend_from_slice(room_id.as_bytes()); self.roomserverids.remove(&roomserver_id)?; @@ -146,49 +154,45 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { // Now only new servers are in joined_servers anymore for server in joined_servers { let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); + roomserver_id.push(0xFF); roomserver_id.extend_from_slice(server.as_bytes()); let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xff); + serverroom_id.push(0xFF); serverroom_id.extend_from_slice(room_id.as_bytes()); self.roomserverids.insert(&roomserver_id, &[])?; self.serverroomids.insert(&serverroom_id, &[])?; } - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); + self.appservice_in_room_cache.write().unwrap().remove(room_id); Ok(()) } #[tracing::instrument(skip(self, room_id))] - fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { - let maybe = self - .our_real_users_cache - .read() - .unwrap() - .get(room_id) - .cloned(); + fn get_our_real_users( + &self, + room_id: &RoomId, + ) -> Result>> { + let maybe = + self.our_real_users_cache.read().unwrap().get(room_id).cloned(); if let Some(users) = maybe { Ok(users) } else { self.update_joined_count(room_id)?; Ok(Arc::clone( - self.our_real_users_cache - .read() - .unwrap() - .get(room_id) - .unwrap(), + self.our_real_users_cache.read().unwrap().get(room_id).unwrap(), )) } } #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { + fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &RegistrationInfo, + ) -> Result { let maybe = self .appservice_in_room_cache .read() @@ -206,11 +210,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { ) .ok(); - let in_room = bridge_user_id - .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self.room_members(room_id).any(|userid| { - userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())) - }); + let in_room = bridge_user_id.map_or(false, |id| { + self.is_joined(&id, room_id).unwrap_or(false) + }) || self.room_members(room_id).any(|userid| { + userid.map_or(false, |userid| { + appservice.users.is_match(userid.as_str()) + }) + }); self.appservice_in_room_cache .write() @@ -227,11 +233,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); self.userroomid_leftstate.remove(&userroom_id)?; @@ -247,51 +253,66 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { ServerName::parse( utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) .map_err(|_| { - Error::bad_database("Server name in roomserverids is invalid unicode.") + Error::bad_database( + "Server name in roomserverids is invalid unicode.", + ) })?, ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) + .map_err(|_| { + Error::bad_database("Server name in roomserverids is invalid.") + }) })) } #[tracing::instrument(skip(self))] - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + fn server_in_room( + &self, + server: &ServerName, + room_id: &RoomId, + ) -> Result { let mut key = server.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); self.serverroomids.get(&key).map(|o| o.is_some()) } - /// Returns an iterator of all rooms a server participates in (as far as we know). + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). #[tracing::instrument(skip(self))] fn server_rooms<'a>( &'a self, server: &ServerName, ) -> Box> + 'a> { let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { RoomId::parse( utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + .map_err(|_| { + Error::bad_database( + "RoomId in serverroomids is invalid unicode.", + ) + })?, ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) + .map_err(|_| { + Error::bad_database("RoomId in serverroomids is invalid.") + }) })) } @@ -302,20 +323,24 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { UserId::parse( utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) .map_err(|_| { - Error::bad_database("User ID in roomuserid_joined is invalid unicode.") + Error::bad_database( + "User ID in roomuserid_joined is invalid unicode.", + ) })?, ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) + .map_err(|_| { + Error::bad_database("User ID in roomuserid_joined is invalid.") + }) })) } @@ -324,8 +349,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomid_joinedcount .get(room_id.as_bytes())? .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + utils::u64_from_bytes(&b).map_err(|_| { + Error::bad_database("Invalid joinedcount in db.") + }) }) .transpose() } @@ -335,8 +361,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomid_invitedcount .get(room_id.as_bytes())? .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + utils::u64_from_bytes(&b).map_err(|_| { + Error::bad_database("Invalid joinedcount in db.") + }) }) .transpose() } @@ -348,27 +375,30 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database( - "User ID in room_useroncejoined is invalid unicode.", - ) - })?, + Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map( + |(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) + .map_err(|_| { + Error::bad_database( + "User ID in room_useroncejoined is invalid \ + unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "User ID in room_useroncejoined is invalid.", + ) + }) + }, + )) } /// Returns an iterator over all invited members of a room. @@ -378,53 +408,64 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_invited is invalid unicode.") - })?, + Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map( + |(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) + .map_err(|_| { + Error::bad_database( + "User ID in roomuserid_invited is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "User ID in roomuserid_invited is invalid.", + ) + }) + }, + )) } #[tracing::instrument(skip(self))] - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + fn get_invite_count( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result> { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid invitecount in db.") - })?)) - }) + self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid invitecount in db.") + })?)) + }) } #[tracing::instrument(skip(self))] - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + fn get_left_count( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result> { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); self.roomuserid_leftcount .get(&key)? .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid leftcount in db.")) + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid leftcount in db.") + }) }) .transpose() } @@ -441,15 +482,22 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .map(|(key, _)| { RoomId::parse( utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xFF) .next() .expect("rsplit always returns an element"), ) .map_err(|_| { - Error::bad_database("Room ID in userroomid_joined is invalid unicode.") + Error::bad_database( + "Room ID in userroomid_joined is invalid \ + unicode.", + ) })?, ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) + .map_err(|_| { + Error::bad_database( + "Room ID in userroomid_joined is invalid.", + ) + }) }), ) } @@ -460,35 +508,43 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn rooms_invited<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a> { + ) -> Box< + dyn Iterator< + Item = Result<(OwnedRoomId, Vec>)>, + > + 'a, + > { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, + Box::new(self.userroomid_invitestate.scan_prefix(prefix).map( + |(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), ) .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; + Error::bad_database( + "Room ID in userroomid_invited is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "Room ID in userroomid_invited is invalid.", + ) + })?; - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_invitestate.") - })?; + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database( + "Invalid state in userroomid_invitestate.", + ) + })?; - Ok((room_id, state)) - }), - ) + Ok((room_id, state)) + }, + )) } #[tracing::instrument(skip(self))] @@ -498,14 +554,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Result>>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); self.userroomid_invitestate .get(&key)? .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database( + "Invalid state in userroomid_invitestate.", + ) + })?; Ok(state) }) @@ -519,14 +578,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, ) -> Result>>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); self.userroomid_leftstate .get(&key)? .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database( + "Invalid state in userroomid_leftstate.", + ) + })?; Ok(state) }) @@ -539,41 +601,48 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn rooms_left<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a> { + ) -> Box< + dyn Iterator>)>> + + 'a, + > { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, + Box::new(self.userroomid_leftstate.scan_prefix(prefix).map( + |(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), ) .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; + Error::bad_database( + "Room ID in userroomid_invited is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database( + "Room ID in userroomid_invited is invalid.", + ) + })?; - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database( + "Invalid state in userroomid_leftstate.", + ) + })?; - Ok((room_id, state)) - }), - ) + Ok((room_id, state)) + }, + )) } #[tracing::instrument(skip(self))] fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) @@ -582,7 +651,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) @@ -591,7 +660,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) @@ -600,7 +669,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { #[tracing::instrument(skip(self))] fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index ab06d8f3..7fb24698 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -12,8 +12,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase { .shortstatehash_statediff .get(&shortstatehash.to_be_bytes())? .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = - utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + let parent = utils::u64_from_bytes(&value[0..size_of::()]) + .expect("bytes have right length"); let parent = (parent != 0).then_some(parent); let mut add_mode = true; @@ -30,7 +30,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase { if add_mode { added.insert(v.try_into().expect("we checked the size above")); } else { - removed.insert(v.try_into().expect("we checked the size above")); + removed + .insert(v.try_into().expect("we checked the size above")); } i += 2 * size_of::(); } @@ -42,7 +43,11 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase { }) } - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { + fn save_statediff( + &self, + shortstatehash: u64, + diff: StateDiff, + ) -> Result<()> { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { value.extend_from_slice(&new[..]); diff --git a/src/database/key_value/rooms/threads.rs b/src/database/key_value/rooms/threads.rs index 257828bb..9915198e 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/database/key_value/rooms/threads.rs @@ -1,8 +1,14 @@ use std::mem; -use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; +use ruma::{ + api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, + UserId, +}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, PduEvent, + Result, +}; impl service::rooms::threads::Data for KeyValueDatabase { fn threads_until<'a>( @@ -28,14 +34,22 @@ impl service::rooms::threads::Data for KeyValueDatabase { .iter_from(¤t, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; + let count = utils::u64_from_bytes( + &pduid[(mem::size_of::())..], + ) + .map_err(|_| { + Error::bad_database( + "Invalid pduid in threadid_userids.", + ) + })?; let mut pdu = services() .rooms .timeline .get_pdu_from_id(&pduid)? .ok_or_else(|| { - Error::bad_database("Invalid pduid reference in threadid_userids") + Error::bad_database( + "Invalid pduid reference in threadid_userids", + ) })?; if pdu.sender != user_id { pdu.remove_transaction_id()?; @@ -45,28 +59,43 @@ impl service::rooms::threads::Data for KeyValueDatabase { )) } - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + fn update_participants( + &self, + root_id: &[u8], + participants: &[OwnedUserId], + ) -> Result<()> { let users = participants .iter() .map(|user| user.as_bytes()) .collect::>() - .join(&[0xff][..]); + .join(&[0xFF][..]); self.threadid_userids.insert(root_id, &users)?; Ok(()) } - fn get_participants(&self, root_id: &[u8]) -> Result>> { + fn get_participants( + &self, + root_id: &[u8], + ) -> Result>> { if let Some(users) = self.threadid_userids.get(root_id)? { Ok(Some( users - .split(|b| *b == 0xff) + .split(|b| *b == 0xFF) .map(|bytes| { - UserId::parse(utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Invalid UserId bytes in threadid_userids.") - })?) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) + UserId::parse(utils::string_from_bytes(bytes).map_err( + |_| { + Error::bad_database( + "Invalid UserId bytes in threadid_userids.", + ) + }, + )?) + .map_err(|_| { + Error::bad_database( + "Invalid UserId in threadid_userids.", + ) + }) }) .filter_map(Result::ok) .collect(), diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 468d8a1f..9d125908 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -1,16 +1,23 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; use ruma::{ - api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, + api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, + RoomId, UserId, }; +use service::rooms::timeline::PduCount; use tracing::error; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; - -use service::rooms::timeline::PduCount; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, PduEvent, + Result, +}; impl service::rooms::timeline::Data for KeyValueDatabase { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + fn last_timeline_count( + &self, + sender_user: &UserId, + room_id: &RoomId, + ) -> Result { match self .lasttimelinecount_cache .lock() @@ -45,14 +52,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result> { + fn get_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { self.get_non_outlier_pdu_json(event_id)?.map_or_else( || { self.eventid_outlierpdu .get(event_id.as_bytes())? .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) + serde_json::from_slice(&pdu).map_err(|_| { + Error::bad_database("Invalid PDU in db.") + }) }) .transpose() }, @@ -61,17 +72,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + }) }) .transpose()? .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) }) .transpose() } @@ -82,17 +97,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the pdu. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + fn get_non_outlier_pdu( + &self, + event_id: &EventId, + ) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + }) }) .transpose()? .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) }) .transpose() } @@ -112,8 +131,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { self.eventid_outlierpdu .get(event_id.as_bytes())? .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) + serde_json::from_slice(&pdu).map_err(|_| { + Error::bad_database("Invalid PDU in db.") + }) }) .transpose() }, @@ -144,7 +164,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + fn get_pdu_json_from_id( + &self, + pdu_id: &[u8], + ) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) @@ -162,7 +185,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { ) -> Result<()> { self.pduid_pdu.insert( pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + &serde_json::to_vec(json) + .expect("CanonicalJsonObject is always a valid"), )?; self.lasttimelinecount_cache @@ -184,7 +208,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { ) -> Result<()> { self.pduid_pdu.insert( pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + &serde_json::to_vec(json) + .expect("CanonicalJsonObject is always a valid"), )?; self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; @@ -203,7 +228,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { if self.pduid_pdu.get(pdu_id)?.is_some() { self.pduid_pdu.insert( pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), + &serde_json::to_vec(pdu_json) + .expect("CanonicalJsonObject is always a valid"), )?; } else { return Err(Error::BadRequest( @@ -212,22 +238,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase { )); } - self.pdu_cache - .lock() - .unwrap() - .remove(&(*pdu.event_id).to_owned()); + self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned()); Ok(()) } - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> + { let (prefix, current) = count_to_id(room_id, until, 1, true)?; let user_id = user_id.to_owned(); @@ -238,7 +263,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + .map_err(|_| { + Error::bad_database("PDU in db is invalid.") + })?; if pdu.sender != user_id { pdu.remove_transaction_id()?; } @@ -254,7 +281,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> + { let (prefix, current) = count_to_id(room_id, from, 1, false)?; let user_id = user_id.to_owned(); @@ -265,7 +293,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + .map_err(|_| { + Error::bad_database("PDU in db is invalid.") + })?; if pdu.sender != user_id { pdu.remove_transaction_id()?; } @@ -286,13 +316,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let mut highlights_batch = Vec::new(); for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); notifies_batch.push(userroom_id); } for user in highlights { let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); highlights_batch.push(userroom_id); } @@ -307,10 +337,12 @@ impl service::rooms::timeline::Data for KeyValueDatabase { /// Returns the `count` of this pdu's id. fn pdu_count(pdu_id: &[u8]) -> Result { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; + let last_u64 = + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; let second_last_u64 = utils::u64_from_bytes( - &pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()], + &pdu_id[pdu_id.len() - 2 * size_of::() + ..pdu_id.len() - size_of::()], ); if matches!(second_last_u64, Ok(0)) { @@ -330,7 +362,9 @@ fn count_to_id( .rooms .short .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .ok_or_else(|| { + Error::bad_database("Looked for bad shortroomid in timeline") + })? .to_be_bytes() .to_vec(); let mut pdu_id = prefix.clone(); diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 7c253d3f..53461d27 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,14 +1,20 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, service, services, utils, Error, Result, +}; impl service::rooms::user::Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + fn reset_notification_counts( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); self.userroomid_notificationcount @@ -24,35 +30,51 @@ impl service::rooms::user::Data for KeyValueDatabase { Ok(()) } - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn notification_count( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) + self.userroomid_notificationcount.get(&userroom_id)?.map_or( + Ok(0), + |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid notification count in db.") + }) + }, + ) } - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn highlight_count( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); + userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) + self.userroomid_highlightcount.get(&userroom_id)?.map_or( + Ok(0), + |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid highlight count in db.") + }) + }, + ) } - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn last_notification_read( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); Ok(self @@ -60,7 +82,9 @@ impl service::rooms::user::Data for KeyValueDatabase { .get(&key)? .map(|bytes| { utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") + Error::bad_database( + "Count in roomuserid_lastprivatereadupdate is invalid.", + ) }) }) .transpose()? @@ -86,7 +110,11 @@ impl service::rooms::user::Data for KeyValueDatabase { .insert(&key, &shortstatehash.to_be_bytes()) } - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + fn get_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + ) -> Result> { let shortroomid = services() .rooms .short @@ -100,7 +128,10 @@ impl service::rooms::user::Data for KeyValueDatabase { .get(&key)? .map(|bytes| { utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") + Error::bad_database( + "Invalid shortstatehash in \ + roomsynctoken_shortstatehash", + ) }) }) .transpose() @@ -112,7 +143,7 @@ impl service::rooms::user::Data for KeyValueDatabase { ) -> Result> + 'a>> { let iterators = users.into_iter().map(move |user_id| { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); self.userroomid_joined .scan_prefix(prefix) @@ -121,8 +152,12 @@ impl service::rooms::user::Data for KeyValueDatabase { let roomid_index = key .iter() .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .find(|(_, &b)| b == 0xFF) + .ok_or_else(|| { + Error::bad_database( + "Invalid userroomid_joined in db.", + ) + })? .0 + 1; @@ -133,15 +168,24 @@ impl service::rooms::user::Data for KeyValueDatabase { .filter_map(Result::ok) }); - // We use the default compare function because keys are sorted correctly (not reversed) + // We use the default compare function because keys are sorted correctly + // (not reversed) Ok(Box::new( utils::common_elements(iterators, Ord::cmp) .expect("users is not empty") .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + RoomId::parse(utils::string_from_bytes(&bytes).map_err( + |_| { + Error::bad_database( + "Invalid RoomId bytes in userroomid_joined", + ) + }, + )?) + .map_err(|_| { + Error::bad_database( + "Invalid RoomId in userroomid_joined.", + ) + }) }), )) } diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 6c8e939b..bac027e6 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -12,31 +12,34 @@ use crate::{ impl service::sending::Data for KeyValueDatabase { fn active_requests<'a>( &'a self, - ) -> Box, OutgoingKind, SendingEventType)>> + 'a> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) + ) -> Box< + dyn Iterator, OutgoingKind, SendingEventType)>> + + 'a, + > { + Box::new(self.servercurrentevent_data.iter().map(|(key, v)| { + parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e)) + })) } fn active_requests_for<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box, SendingEventType)>> + 'a> { + ) -> Box, SendingEventType)>> + 'a> + { let prefix = outgoing_kind.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) + Box::new(self.servercurrentevent_data.scan_prefix(prefix).map( + |(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e)), + )) } fn delete_active_request(&self, key: Vec) -> Result<()> { self.servercurrentevent_data.remove(&key) } - fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + fn delete_all_active_requests_for( + &self, + outgoing_kind: &OutgoingKind, + ) -> Result<()> { let prefix = outgoing_kind.get_prefix(); for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { self.servercurrentevent_data.remove(&key)?; @@ -45,9 +48,13 @@ impl service::sending::Data for KeyValueDatabase { Ok(()) } - fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + fn delete_all_requests_for( + &self, + outgoing_kind: &OutgoingKind, + ) -> Result<()> { let prefix = outgoing_kind.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) + { self.servercurrentevent_data.remove(&key).unwrap(); } @@ -69,7 +76,9 @@ impl service::sending::Data for KeyValueDatabase { if let SendingEventType::Pdu(value) = &event { key.extend_from_slice(value); } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice( + &services().globals.next_count()?.to_be_bytes(), + ); } let value = if let SendingEventType::Edu(value) = &event { &**value @@ -79,24 +88,25 @@ impl service::sending::Data for KeyValueDatabase { batch.push((key.clone(), value.to_owned())); keys.push(key); } - self.servernameevent_data - .insert_batch(&mut batch.into_iter())?; + self.servernameevent_data.insert_batch(&mut batch.into_iter())?; Ok(keys) } fn queued_requests<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box)>> + 'a> { + ) -> Box)>> + 'a> + { let prefix = outgoing_kind.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); + return Box::new(self.servernameevent_data.scan_prefix(prefix).map( + |(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k)), + )); } - fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()> { + fn mark_as_active( + &self, + events: &[(SendingEventType, Vec)], + ) -> Result<()> { for (e, key) in events { let value = if let SendingEventType::Edu(value) = &e { &**value @@ -110,18 +120,24 @@ impl service::sending::Data for KeyValueDatabase { Ok(()) } - fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + fn set_latest_educount( + &self, + server_name: &ServerName, + last_count: u64, + ) -> Result<()> { self.servername_educount .insert(server_name.as_bytes(), &last_count.to_be_bytes()) } fn get_latest_educount(&self, server_name: &ServerName) -> Result { - self.servername_educount - .get(server_name.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) + self.servername_educount.get(server_name.as_bytes())?.map_or( + Ok(0), + |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid u64 in servername_educount.") + }) + }, + ) } } @@ -132,15 +148,17 @@ fn parse_servercurrentevent( ) -> Result<(OutgoingKind, SendingEventType)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { - let mut parts = key[1..].splitn(2, |&b| b == 0xff); + let mut parts = key[1..].splitn(2, |&b| b == 0xFF); let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts.next().ok_or_else(|| { + Error::bad_database("Invalid bytes in servercurrentpdus.") + })?; let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") + Error::bad_database( + "Invalid server bytes in server_currenttransaction", + ) })?; ( @@ -152,23 +170,27 @@ fn parse_servercurrentevent( }, ) } else if key.starts_with(b"$") { - let mut parts = key[1..].splitn(3, |&b| b == 0xff); + let mut parts = key[1..].splitn(3, |&b| b == 0xFF); let user = parts.next().expect("splitn always returns one element"); - let user_string = utils::string_from_bytes(user) - .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; - let user_id = UserId::parse(user_string) - .map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; + let user_string = utils::string_from_bytes(user).map_err(|_| { + Error::bad_database("Invalid user string in servercurrentevent") + })?; + let user_id = UserId::parse(user_string).map_err(|_| { + Error::bad_database("Invalid user id in servercurrentevent") + })?; - let pushkey = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let pushkey_string = utils::string_from_bytes(pushkey) - .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; + let pushkey = parts.next().ok_or_else(|| { + Error::bad_database("Invalid bytes in servercurrentpdus.") + })?; + let pushkey_string = + utils::string_from_bytes(pushkey).map_err(|_| { + Error::bad_database("Invalid pushkey in servercurrentevent") + })?; - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts.next().ok_or_else(|| { + Error::bad_database("Invalid bytes in servercurrentpdus.") + })?; ( OutgoingKind::Push(user_id, pushkey_string), @@ -180,20 +202,24 @@ fn parse_servercurrentevent( }, ) } else { - let mut parts = key.splitn(2, |&b| b == 0xff); + let mut parts = key.splitn(2, |&b| b == 0xFF); let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts.next().ok_or_else(|| { + Error::bad_database("Invalid bytes in servercurrentpdus.") + })?; let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") + Error::bad_database( + "Invalid server bytes in server_currenttransaction", + ) })?; ( OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { - Error::bad_database("Invalid server string in server_currenttransaction") + Error::bad_database( + "Invalid server string in server_currenttransaction", + ) })?), if value.is_empty() { SendingEventType::Pdu(event.to_vec()) diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index b3bd05f4..345b74a1 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -11,9 +11,11 @@ impl service::transaction_ids::Data for KeyValueDatabase { data: &[u8], ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xff); + key.push(0xFF); + key.extend_from_slice( + device_id.map(DeviceId::as_bytes).unwrap_or_default(), + ); + key.push(0xFF); key.extend_from_slice(txn_id.as_bytes()); self.userdevicetxnid_response.insert(&key, data)?; @@ -28,9 +30,11 @@ impl service::transaction_ids::Data for KeyValueDatabase { txn_id: &TransactionId, ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xff); + key.push(0xFF); + key.extend_from_slice( + device_id.map(DeviceId::as_bytes).unwrap_or_default(), + ); + key.push(0xFF); key.extend_from_slice(txn_id.as_bytes()); // If there's no entry, this is a new transaction diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 20a0357d..3452b990 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -13,13 +13,10 @@ impl service::uiaa::Data for KeyValueDatabase { session: &str, request: &CanonicalJsonValue, ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); + self.userdevicesessionid_uiaarequest.write().unwrap().insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); Ok(()) } @@ -33,7 +30,11 @@ impl service::uiaa::Data for KeyValueDatabase { self.userdevicesessionid_uiaarequest .read() .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .get(&( + user_id.to_owned(), + device_id.to_owned(), + session.to_owned(), + )) .map(ToOwned::to_owned) } @@ -45,19 +46,19 @@ impl service::uiaa::Data for KeyValueDatabase { uiaainfo: Option<&UiaaInfo>, ) -> Result<()> { let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); + userdevicesessionid.push(0xFF); userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); + userdevicesessionid.push(0xFF); userdevicesessionid.extend_from_slice(session.as_bytes()); if let Some(uiaainfo) = uiaainfo { self.userdevicesessionid_uiaainfo.insert( &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + &serde_json::to_vec(&uiaainfo) + .expect("UiaaInfo::to_vec always works"), )?; } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; + self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?; } Ok(()) @@ -70,9 +71,9 @@ impl service::uiaa::Data for KeyValueDatabase { session: &str, ) -> Result { let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); + userdevicesessionid.push(0xFF); userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); + userdevicesessionid.push(0xFF); userdevicesessionid.extend_from_slice(session.as_bytes()); serde_json::from_slice( @@ -84,6 +85,8 @@ impl service::uiaa::Data for KeyValueDatabase { "UIAA session does not exist.", ))?, ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + .map_err(|_| { + Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.") + }) } } diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 840e4788..0cef5414 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -5,8 +5,8 @@ use ruma::{ encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, - OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, + OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, }; use tracing::warn; @@ -40,67 +40,100 @@ impl service::users::Data for KeyValueDatabase { } /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xff); + fn find_from_token( + &self, + token: &str, + ) -> Result> { + self.token_userdeviceid.get(token.as_bytes())?.map_or( + Ok(None), + |bytes| { + let mut parts = bytes.split(|&b| b == 0xFF); let user_bytes = parts.next().ok_or_else(|| { - Error::bad_database("User ID in token_userdeviceid is invalid.") + Error::bad_database( + "User ID in token_userdeviceid is invalid.", + ) })?; let device_bytes = parts.next().ok_or_else(|| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") + Error::bad_database( + "Device ID in token_userdeviceid is invalid.", + ) })?; Ok(Some(( - UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid unicode.") - })?) + UserId::parse( + utils::string_from_bytes(user_bytes).map_err(|_| { + Error::bad_database( + "User ID in token_userdeviceid is invalid \ + unicode.", + ) + })?, + ) .map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid.") + Error::bad_database( + "User ID in token_userdeviceid is invalid.", + ) })?, utils::string_from_bytes(device_bytes).map_err(|_| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") + Error::bad_database( + "Device ID in token_userdeviceid is invalid.", + ) })?, ))) - }) + }, + ) } /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a> { + fn iter<'a>( + &'a self, + ) -> Box> + 'a> { Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in userid_password is invalid unicode.") + Error::bad_database( + "User ID in userid_password is invalid unicode.", + ) })?) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + .map_err(|_| { + Error::bad_database("User ID in userid_password is invalid.") + }) })) } /// Returns a list of local users as list of usernames. /// - /// A user account is considered `local` if the length of it's password is greater then zero. + /// A user account is considered `local` if the length of it's password is + /// greater then zero. fn list_local_users(&self) -> Result> { let users: Vec = self .userid_password .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) + .filter_map(|(username, pw)| { + get_username_with_valid_password(&username, &pw) + }) .collect(); Ok(users) } /// Returns the password hash for the given user. fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { + self.userid_password.get(user_id.as_bytes())?.map_or( + Ok(None), + |bytes| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") + Error::bad_database( + "Password hash in db is not valid string.", + ) })?)) - }) + }, + ) } /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + fn set_password( + &self, + user_id: &UserId, + password: Option<&str>, + ) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_password_hash(password) { self.userid_password @@ -120,17 +153,23 @@ impl service::users::Data for KeyValueDatabase { /// Returns the `displayname` of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { + self.userid_displayname.get(user_id.as_bytes())?.map_or( + Ok(None), + |bytes| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("Displayname in db is invalid.") })?)) - }) + }, + ) } - /// Sets a new `displayname` or removes it if `displayname` is `None`. You still need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + /// Sets a new `displayname` or removes it if `displayname` is `None`. You + /// still need to nofify all rooms of this change. + fn set_displayname( + &self, + user_id: &UserId, + displayname: Option, + ) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname .insert(user_id.as_bytes(), displayname.as_bytes())?; @@ -147,17 +186,25 @@ impl service::users::Data for KeyValueDatabase { .get(user_id.as_bytes())? .map(|bytes| { utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) + .map_err(|_| { + Error::bad_database("Avatar URL in db is invalid.") + }) .map(Into::into) }) .transpose() } /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + fn set_avatar_url( + &self, + user_id: &UserId, + avatar_url: Option, + ) -> Result<()> { if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + self.userid_avatarurl.insert( + user_id.as_bytes(), + avatar_url.to_string().as_bytes(), + )?; } else { self.userid_avatarurl.remove(user_id.as_bytes())?; } @@ -170,8 +217,9 @@ impl service::users::Data for KeyValueDatabase { self.userid_blurhash .get(user_id.as_bytes())? .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + let s = utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Avatar URL in db is invalid.") + })?; Ok(s) }) @@ -179,7 +227,11 @@ impl service::users::Data for KeyValueDatabase { } /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + fn set_blurhash( + &self, + user_id: &UserId, + blurhash: Option, + ) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash .insert(user_id.as_bytes(), blurhash.as_bytes())?; @@ -204,11 +256,10 @@ impl service::users::Data for KeyValueDatabase { ); let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); - self.userid_devicelistversion - .increment(user_id.as_bytes())?; + self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( &userdeviceid, @@ -228,9 +279,13 @@ impl service::users::Data for KeyValueDatabase { } /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + fn remove_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); // Remove tokens @@ -241,7 +296,7 @@ impl service::users::Data for KeyValueDatabase { // Remove todevice events let mut prefix = userdeviceid.clone(); - prefix.push(0xff); + prefix.push(0xFF); for (key, _) in self.todeviceid_events.scan_prefix(prefix) { self.todeviceid_events.remove(&key)?; @@ -249,8 +304,7 @@ impl service::users::Data for KeyValueDatabase { // TODO: Remove onetimekeys - self.userid_devicelistversion - .increment(user_id.as_bytes())?; + self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userdeviceid_metadata.remove(&userdeviceid)?; @@ -263,29 +317,34 @@ impl service::users::Data for KeyValueDatabase { user_id: &UserId, ) -> Box> + 'a> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("UserDevice ID in db is invalid.") - })?, + Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map( + |(bytes, _)| { + Ok(utils::string_from_bytes( + bytes.rsplit(|&b| b == 0xFF).next().ok_or_else(|| { + Error::bad_database("UserDevice ID in db is invalid.") + })?, + ) + .map_err(|_| { + Error::bad_database( + "Device ID in userdeviceid_metadata is invalid.", ) - .map_err(|_| { - Error::bad_database("Device ID in userdeviceid_metadata is invalid.") - })? - .into()) - }), - ) + })? + .into()) + }, + )) } /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + fn set_token( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); assert!( @@ -300,10 +359,8 @@ impl service::users::Data for KeyValueDatabase { } // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; + self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?; Ok(()) } @@ -316,17 +373,19 @@ impl service::users::Data for KeyValueDatabase { one_time_key_value: &Raw, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(device_id.as_bytes()); assert!( self.userdeviceid_metadata.get(&key)?.is_some(), - "devices should have metadata and this method should only be called with existing devices" + "devices should have metadata and this method should only be \ + called with existing devices" ); - key.push(0xff); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update + // everything, because there are no wrapping quotation marks + // anymore) key.extend_from_slice( serde_json::to_string(one_time_key_key) .expect("DeviceKeyId::to_string always works") @@ -335,7 +394,8 @@ impl service::users::Data for KeyValueDatabase { self.onetimekeyid_onetimekeys.insert( &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + &serde_json::to_vec(&one_time_key_value) + .expect("OneTimeKey::to_vec always works"), )?; self.userid_lastonetimekeyupdate.insert( @@ -347,13 +407,16 @@ impl service::users::Data for KeyValueDatabase { } fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { + self.userid_lastonetimekeyupdate.get(user_id.as_bytes())?.map_or( + Ok(0), + |bytes| { utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") + Error::bad_database( + "Count in roomid_lastroomactiveupdate is invalid.", + ) }) - }) + }, + ) } fn take_one_time_key( @@ -363,9 +426,9 @@ impl service::users::Data for KeyValueDatabase { key_algorithm: &DeviceKeyAlgorithm, ) -> Result)>> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); // Annoying quotation mark prefix.push(b'"'); prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); @@ -384,13 +447,18 @@ impl service::users::Data for KeyValueDatabase { Ok(( serde_json::from_slice( - key.rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + key.rsplit(|&b| b == 0xFF).next().ok_or_else(|| { + Error::bad_database( + "OneTimeKeyId in db is invalid.", + ) + })?, ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + .map_err(|_| { + Error::bad_database("OneTimeKeyId in db is invalid.") + })?, + serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("OneTimeKeys in db are invalid.") + })?, )) }) .transpose() @@ -402,25 +470,31 @@ impl service::users::Data for KeyValueDatabase { device_id: &DeviceId, ) -> Result> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); let mut counts = BTreeMap::new(); - for algorithm in - self.onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("OneTimeKey ID in db is invalid.") - })?, - ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? - .algorithm(), + for algorithm in self + .onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::( + bytes.rsplit(|&b| b == 0xFF).next().ok_or_else( + || { + Error::bad_database( + "OneTimeKey ID in db is invalid.", + ) + }, + )?, ) - }) + .map_err(|_| { + Error::bad_database("DeviceKeyId in db is invalid.") + })? + .algorithm(), + ) + }) { *counts.entry(algorithm?).or_default() += UInt::from(1_u32); } @@ -435,12 +509,13 @@ impl service::users::Data for KeyValueDatabase { device_keys: &Raw, ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); self.keyid_key.insert( &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + &serde_json::to_vec(&device_keys) + .expect("DeviceKeys::to_vec always works"), )?; self.mark_device_key_update(user_id)?; @@ -458,30 +533,33 @@ impl service::users::Data for KeyValueDatabase { ) -> Result<()> { // TODO: Check signatures let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; self.keyid_key .insert(&master_key_key, master_key.json().get().as_bytes())?; - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; + self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?; // Self-signing key if let Some(self_signing_key) = self_signing_key { let mut self_signing_key_ids = self_signing_key .deserialize() .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid self signing key", + ) })? .keys .into_values(); - let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained no key.", - ))?; + let self_signing_key_id = + self_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained no key.", + ))?; if self_signing_key_ids.next().is_some() { return Err(Error::BadRequest( @@ -491,7 +569,8 @@ impl service::users::Data for KeyValueDatabase { } let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + self_signing_key_key + .extend_from_slice(self_signing_key_id.as_bytes()); self.keyid_key.insert( &self_signing_key_key, @@ -507,15 +586,19 @@ impl service::users::Data for KeyValueDatabase { let mut user_signing_key_ids = user_signing_key .deserialize() .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid user signing key", + ) })? .keys .into_values(); - let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained no key.", - ))?; + let user_signing_key_id = + user_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained no key.", + ))?; if user_signing_key_ids.next().is_some() { return Err(Error::BadRequest( @@ -525,7 +608,8 @@ impl service::users::Data for KeyValueDatabase { } let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + user_signing_key_key + .extend_from_slice(user_signing_key_id.as_bytes()); self.keyid_key.insert( &user_signing_key_key, @@ -551,32 +635,44 @@ impl service::users::Data for KeyValueDatabase { sender_id: &UserId, ) -> Result<()> { let mut key = target_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(key_id.as_bytes()); - let mut cross_signing_key: serde_json::Value = - serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( + let mut cross_signing_key: serde_json::Value = serde_json::from_slice( + &self.keyid_key.get(&key)?.ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Tried to sign nonexistent key.", - ))?) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + ))?, + ) + .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; let signatures = cross_signing_key .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .ok_or_else(|| { + Error::bad_database("key in keyid_key has no signatures field.") + })? .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .ok_or_else(|| { + Error::bad_database( + "key in keyid_key has invalid signatures field.", + ) + })? .entry(sender_id.to_string()) .or_insert_with(|| serde_json::Map::new().into()); signatures .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .ok_or_else(|| { + Error::bad_database( + "signatures in keyid_key for a user is invalid.", + ) + })? .insert(signature.0, signature.1.into()); self.keyid_key.insert( &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + &serde_json::to_vec(&cross_signing_key) + .expect("CrossSigningKey::to_vec always works"), )?; self.mark_device_key_update(target_id)?; @@ -591,7 +687,7 @@ impl service::users::Data for KeyValueDatabase { to: Option, ) -> Box> + 'a> { let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); let mut start = prefix.clone(); start.extend_from_slice(&(from + 1).to_be_bytes()); @@ -603,26 +699,39 @@ impl service::users::Data for KeyValueDatabase { .iter_from(&start, false) .take_while(move |(k, _)| { k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { + && if let Some(current) = + k.splitn(2, |&b| b == 0xFF).nth(1) + { if let Ok(c) = utils::u64_from_bytes(current) { c <= to } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + warn!( + "BadDatabase: Could not parse \ + keychangeid_userid bytes" + ); false } } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); + warn!( + "BadDatabase: Could not parse \ + keychangeid_userid" + ); false } }) .map(|(_, bytes)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "User ID in devicekeychangeid_userid is invalid unicode.", - ) - })?) + UserId::parse(utils::string_from_bytes(&bytes).map_err( + |_| { + Error::bad_database( + "User ID in devicekeychangeid_userid is \ + invalid unicode.", + ) + }, + )?) .map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid.") + Error::bad_database( + "User ID in devicekeychangeid_userid is invalid.", + ) }) }), ) @@ -647,14 +756,14 @@ impl service::users::Data for KeyValueDatabase { } let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(&count); self.keychangeid_userid.insert(&key, user_id.as_bytes())?; } let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(&count); self.keychangeid_userid.insert(&key, user_id.as_bytes())?; @@ -667,7 +776,7 @@ impl service::users::Data for KeyValueDatabase { device_id: &DeviceId, ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(device_id.as_bytes()); self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { @@ -683,11 +792,11 @@ impl service::users::Data for KeyValueDatabase { master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let master_key = master_key.deserialize().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key") + })?; let mut master_key_ids = master_key.keys.values(); let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( ErrorKind::InvalidParam, @@ -712,8 +821,12 @@ impl service::users::Data for KeyValueDatabase { allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + let mut cross_signing_key = serde_json::from_slice::< + serde_json::Value, + >(&bytes) + .map_err(|_| { + Error::bad_database("CrossSigningKey in db is invalid.") + })?; clean_signatures( &mut cross_signing_key, sender_user, @@ -754,16 +867,20 @@ impl service::users::Data for KeyValueDatabase { }) } - fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { + fn get_user_signing_key( + &self, + user_id: &UserId, + ) -> Result>> { + self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or( + Ok(None), + |key| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Error::bad_database("CrossSigningKey in db is invalid.") })?)) }) - }) + }, + ) } fn add_to_device_event( @@ -775,9 +892,9 @@ impl service::users::Data for KeyValueDatabase { content: serde_json::Value, ) -> Result<()> { let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); let mut json = serde_json::Map::new(); @@ -785,7 +902,8 @@ impl service::users::Data for KeyValueDatabase { json.insert("sender".to_owned(), sender.to_string().into()); json.insert("content".to_owned(), content); - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); + let value = + serde_json::to_vec(&json).expect("Map::to_vec always works"); self.todeviceid_events.insert(&key, &value)?; @@ -800,15 +918,14 @@ impl service::users::Data for KeyValueDatabase { let mut events = Vec::new(); let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, - ); + events.push(serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("Event in todeviceid_events is invalid.") + })?); } Ok(events) @@ -821,9 +938,9 @@ impl service::users::Data for KeyValueDatabase { until: u64, ) -> Result<()> { let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + prefix.push(0xFF); let mut last = prefix.clone(); last.extend_from_slice(&until.to_be_bytes()); @@ -836,8 +953,14 @@ impl service::users::Data for KeyValueDatabase { .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + utils::u64_from_bytes( + &key[key.len() - size_of::()..key.len()], + ) + .map_err(|_| { + Error::bad_database( + "ToDeviceId has invalid count bytes.", + ) + })?, )) }) .filter_map(Result::ok) @@ -856,7 +979,7 @@ impl service::users::Data for KeyValueDatabase { device: &Device, ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); assert!( @@ -864,12 +987,12 @@ impl service::users::Data for KeyValueDatabase { "this method should only be called with existing devices" ); - self.userid_devicelistversion - .increment(user_id.as_bytes())?; + self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), + &serde_json::to_vec(device) + .expect("Device::to_string always works"), )?; Ok(()) @@ -882,26 +1005,32 @@ impl service::users::Data for KeyValueDatabase { device_id: &DeviceId, ) -> Result> { let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); + userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { + self.userdeviceid_metadata.get(&userdeviceid)?.map_or( + Ok(None), + |bytes| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + Error::bad_database( + "Metadata in userdeviceid_metadata is invalid.", + ) })?)) - }) + }, + ) } fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { + self.userid_devicelistversion.get(user_id.as_bytes())?.map_or( + Ok(None), + |bytes| { utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map_err(|_| { + Error::bad_database("Invalid devicelistversion in db.") + }) .map(Some) - }) + }, + ) } fn all_devices_metadata<'a>( @@ -909,25 +1038,29 @@ impl service::users::Data for KeyValueDatabase { user_id: &UserId, ) -> Box> + 'a> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes).map_err(|_| { - Error::bad_database("Device in userdeviceid_metadata is invalid.") - }) - }), - ) + Box::new(self.userdeviceid_metadata.scan_prefix(key).map( + |(_, bytes)| { + serde_json::from_slice::(&bytes).map_err(|_| { + Error::bad_database( + "Device in userdeviceid_metadata is invalid.", + ) + }) + }, + )) } /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + fn create_filter( + &self, + user_id: &UserId, + filter: &FilterDefinition, + ) -> Result { let filter_id = utils::random_string(4); let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(filter_id.as_bytes()); self.userfilterid_filter.insert( @@ -938,9 +1071,13 @@ impl service::users::Data for KeyValueDatabase { Ok(filter_id) } - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result> { let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); + key.push(0xFF); key.extend_from_slice(filter_id.as_bytes()); let raw = self.userfilterid_filter.get(&key)?; @@ -956,9 +1093,12 @@ impl service::users::Data for KeyValueDatabase { /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. -/// If [`utils::string_from_bytes`] returns an error that username will be skipped -/// and the error will be logged. -fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { +/// If [`utils::string_from_bytes`] returns an error that username will be +/// skipped and the error will be logged. +fn get_username_with_valid_password( + username: &[u8], + password: &[u8], +) -> Option { // A valid password is not empty if password.is_empty() { None @@ -967,7 +1107,8 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option< Ok(u) => Some(u), Err(e) => { warn!( - "Failed to parse username while calling get_local_users(): {}", + "Failed to parse username while calling \ + get_local_users(): {}", e.to_string() ); None diff --git a/src/main.rs b/src/main.rs index 1b44322f..407a01b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,9 @@ use axum::{ routing::{any, get, on, MethodFilter}, Router, }; -use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; +use axum_server::{ + bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, +}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -50,16 +52,16 @@ use api::{client_server, server_server}; pub(crate) use config::Config; pub(crate) use database::KeyValueDatabase; pub(crate) use service::{pdu::PduEvent, Services}; -pub(crate) use utils::error::{Error, Result}; - #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] use tikv_jemallocator::Jemalloc; +pub(crate) use utils::error::{Error, Result}; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -pub(crate) static SERVICES: RwLock> = RwLock::new(None); +pub(crate) static SERVICES: RwLock> = + RwLock::new(None); /// Convenient access to the global [`Services`] instance pub(crate) fn services() -> &'static Services { @@ -71,9 +73,9 @@ pub(crate) fn services() -> &'static Services { /// Returns the current version of the crate with extra info if supplied /// -/// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string to -/// include it in parenthesis after the SemVer version. A common value are git -/// commit hashes. +/// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string +/// to include it in parenthesis after the SemVer version. A common value are +/// git commit hashes. fn version() -> String { let cargo_pkg_version = env!("CARGO_PKG_VERSION"); @@ -91,7 +93,8 @@ async fn main() { let raw_config = Figment::new() .merge( Toml::file(Env::var("GRAPEVINE_CONFIG").expect( - "The GRAPEVINE_CONFIG env var needs to be set. Example: /etc/grapevine.toml", + "The GRAPEVINE_CONFIG env var needs to be set. Example: \ + /etc/grapevine.toml", )) .nested(), ) @@ -100,7 +103,10 @@ async fn main() { let config = match raw_config.extract::() { Ok(s) => s, Err(e) => { - eprintln!("It looks like your config is invalid. The following error occurred: {e}"); + eprintln!( + "It looks like your config is invalid. The following error \ + occurred: {e}" + ); std::process::exit(1); } }; @@ -108,7 +114,9 @@ async fn main() { config.warn_deprecated(); if config.allow_jaeger { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); + opentelemetry::global::set_text_map_propagator( + opentelemetry_jaeger::Propagator::new(), + ); let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_auto_split_batch(true) .with_service_name("grapevine") @@ -120,7 +128,8 @@ async fn main() { Ok(s) => s, Err(e) => { eprintln!( - "It looks like your log config is invalid. The following error occurred: {e}" + "It looks like your log config is invalid. The following \ + error occurred: {e}" ); EnvFilter::try_new("warn").unwrap() } @@ -146,7 +155,10 @@ async fn main() { let filter_layer = match EnvFilter::try_new(&config.log) { Ok(s) => s, Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); + eprintln!( + "It looks like your config is invalid. The following \ + error occured while parsing it: {e}" + ); EnvFilter::try_new("warn").unwrap() } }; @@ -163,7 +175,8 @@ async fn main() { // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 #[cfg(unix)] - maximize_fd_limit().expect("should be able to increase the soft limit to the hard limit"); + maximize_fd_limit() + .expect("should be able to increase the soft limit to the hard limit"); info!("Loading database"); if let Err(error) = KeyValueDatabase::load_or_create(config).await { @@ -190,17 +203,19 @@ async fn run_server() -> io::Result<()> { let middlewares = ServiceBuilder::new() .sensitive_headers([header::AUTHORIZATION]) .layer(axum::middleware::from_fn(spawn_task)) - .layer( - TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| { - let path = if let Some(path) = request.extensions().get::() { + .layer(TraceLayer::new_for_http().make_span_with( + |request: &http::Request<_>| { + let path = if let Some(path) = + request.extensions().get::() + { path.as_str() } else { request.uri().path() }; tracing::info_span!("http_request", %path) - }), - ) + }, + )) .layer(axum::middleware::from_fn(unrecognized_method)) .layer( CorsLayer::new() @@ -235,7 +250,8 @@ async fn run_server() -> io::Result<()> { match &config.tls { Some(tls) => { - let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; + let conf = + RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; let server = bind_rustls(addr, conf).handle(handle).serve(app); #[cfg(feature = "systemd")] @@ -411,9 +427,10 @@ fn routes(config: &Config) -> Router { .ruma_route(c2s::get_relating_events_route) .ruma_route(c2s::get_hierarchy_route); - // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes - // share one Ruma request / response type pair with {get,send}_state_event_for_key_route. - // These two endpoints also allow trailing slashes. + // Ruma doesn't have support for multiple paths for a single endpoint yet, + // and these routes share one Ruma request / response type pair with + // {get,send}_state_event_for_key_route. These two endpoints also allow + // trailing slashes. let router = router .route( "/_matrix/client/r0/rooms/:room_id/state/:event_type", @@ -483,9 +500,7 @@ fn routes(config: &Config) -> Router { async fn shutdown_signal(handle: ServerHandle) { let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); + signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] @@ -554,9 +569,9 @@ impl RouterExt for Router { } pub(crate) trait RumaHandler { - // Can't transform to a handler without boxing or relying on the nightly-only - // impl-trait-in-traits feature. Moving a small amount of extra logic into the trait - // allows bypassing both. + // Can't transform to a handler without boxing or relying on the + // nightly-only impl-trait-in-traits feature. Moving a small amount of + // extra logic into the trait allows bypassing both. fn add_to_router(self, router: Router) -> Router; } diff --git a/src/service.rs b/src/service.rs index 710edaca..65548616 100644 --- a/src/service.rs +++ b/src/service.rs @@ -4,10 +4,9 @@ use std::{ }; use lru_cache::LruCache; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, Mutex, RwLock}; use crate::{Config, Result}; -use tokio::sync::RwLock; pub(crate) mod account_data; pub(crate) mod admin; @@ -58,10 +57,14 @@ impl Services { ) -> Result { Ok(Self { appservice: appservice::Service::build(db)?, - pusher: pusher::Service { db }, + pusher: pusher::Service { + db, + }, rooms: rooms::Service { alias: db, - auth_chain: rooms::auth_chain::Service { db }, + auth_chain: rooms::auth_chain::Service { + db, + }, directory: db, edus: rooms::edus::Service { read_receipt: db, @@ -78,10 +81,14 @@ impl Services { }, metadata: db, outlier: db, - pdu_metadata: rooms::pdu_metadata::Service { db }, + pdu_metadata: rooms::pdu_metadata::Service { + db, + }, search: db, short: db, - state: rooms::state::Service { db }, + state: rooms::state::Service { + db, + }, state_accessor: rooms::state_accessor::Service { db, #[allow( @@ -101,7 +108,9 @@ impl Services { (100.0 * config.cache_capacity_modifier) as usize, )), }, - state_cache: rooms::state_cache::Service { db }, + state_cache: rooms::state_cache::Service { + db, + }, state_compressor: rooms::state_compressor::Service { db, #[allow( @@ -117,14 +126,18 @@ impl Services { db, lasttimelinecount_cache: Mutex::new(HashMap::new()), }, - threads: rooms::threads::Service { db }, + threads: rooms::threads::Service { + db, + }, spaces: rooms::spaces::Service { roomid_spacechunk_cache: Mutex::new(LruCache::new(200)), }, user: db, }, transaction_ids: db, - uiaa: uiaa::Service { db }, + uiaa: uiaa::Service { + db, + }, users: users::Service { db, connections: StdMutex::new(BTreeMap::new()), @@ -132,14 +145,18 @@ impl Services { account_data: db, admin: admin::Service::build(), key_backups: db, - media: media::Service { db }, + media: media::Service { + db, + }, sending: sending::Service::build(db, &config), globals: globals::Service::load(db, config)?, }) } + async fn memory_usage(&self) -> String { - let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); + let lazy_load_waiting = + self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); let server_visibility_cache = self .rooms .state_accessor @@ -154,21 +171,12 @@ impl Services { .lock() .unwrap() .len(); - let stateinfo_cache = self - .rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .len(); - let lasttimelinecount_cache = self - .rooms - .timeline - .lasttimelinecount_cache - .lock() - .await - .len(); - let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len(); + let stateinfo_cache = + self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); + let lasttimelinecount_cache = + self.rooms.timeline.lasttimelinecount_cache.lock().await.len(); + let roomid_spacechunk_cache = + self.rooms.spaces.roomid_spacechunk_cache.lock().await.len(); format!( "\ @@ -177,18 +185,13 @@ server_visibility_cache: {server_visibility_cache} user_visibility_cache: {user_visibility_cache} stateinfo_cache: {stateinfo_cache} lasttimelinecount_cache: {lasttimelinecount_cache} -roomid_spacechunk_cache: {roomid_spacechunk_cache}\ - " +roomid_spacechunk_cache: {roomid_spacechunk_cache}" ) } + async fn clear_caches(&self, amount: u32) { if amount > 0 { - self.rooms - .lazy_loading - .lazy_load_waiting - .lock() - .await - .clear(); + self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear(); } if amount > 1 { self.rooms @@ -207,28 +210,13 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\ .clear(); } if amount > 3 { - self.rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .clear(); + self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear(); } if amount > 4 { - self.rooms - .timeline - .lasttimelinecount_cache - .lock() - .await - .clear(); + self.rooms.timeline.lasttimelinecount_cache.lock().await.clear(); } if amount > 5 { - self.rooms - .spaces - .roomid_spacechunk_cache - .lock() - .await - .clear(); + self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear(); } } } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 92a17cab..04d4bf76 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,14 +1,16 @@ use std::collections::HashMap; -use crate::Result; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; +use crate::Result; + pub(crate) trait Data: Send + Sync { - /// Places one event in the account data of the user and removes the previous entry. + /// Places one event in the account data of the user and removes the + /// previous entry. fn update( &self, room_id: Option<&RoomId>, diff --git a/src/service/admin.rs b/src/service/admin.rs index d591a514..87b645f1 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -16,7 +16,9 @@ use ruma::{ canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + history_visibility::{ + HistoryVisibility, RoomHistoryVisibilityEventContent, + }, join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, message::RoomMessageEventContent, @@ -26,19 +28,19 @@ use ruma::{ }, TimelineEventType, }, - EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, + ServerName, UserId, }; use serde_json::value::to_raw_value; use tokio::sync::{mpsc, Mutex, RwLock}; use tracing::warn; +use super::pdu::PduBuilder; use crate::{ api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, services, utils, Error, PduEvent, Result, }; -use super::pdu::PduBuilder; - #[cfg_attr(test, derive(Debug))] #[derive(Parser)] #[command(name = "@grapevine:server.name:", version = env!("CARGO_PKG_VERSION"))] @@ -46,11 +48,12 @@ enum AdminCommand { #[command(verbatim_doc_comment)] /// Register an appservice using its registration YAML /// - /// This command needs a YAML generated by an appservice (such as a bridge), - /// which must be provided in a Markdown code-block below the command. + /// This command needs a YAML generated by an appservice (such as a + /// bridge), which must be provided in a Markdown code-block below the + /// command. /// - /// Registering a new bridge using the ID of an existing bridge will replace - /// the old one. + /// Registering a new bridge using the ID of an existing bridge will + /// replace the old one. /// /// [commandbody]() /// # ``` @@ -95,8 +98,9 @@ enum AdminCommand { /// /// Users will not be removed from joined rooms by default. /// Can be overridden with --leave-rooms flag. - /// Removing a mass amount of users from a room may cause a significant amount of leave events. - /// The time to leave rooms may depend significantly on joined rooms and servers. + /// Removing a mass amount of users from a room may cause a significant + /// amount of leave events. The time to leave rooms may depend + /// significantly on joined rooms and servers. /// /// [commandbody]() /// # ``` @@ -138,11 +142,17 @@ enum AdminCommand { /// Print database memory usage statistics MemoryUsage, - /// Clears all of Grapevine's database caches with index smaller than the amount - ClearDatabaseCaches { amount: u32 }, + /// Clears all of Grapevine's database caches with index smaller than the + /// amount + ClearDatabaseCaches { + amount: u32, + }, - /// Clears all of Grapevine's service caches with index smaller than the amount - ClearServiceCaches { amount: u32 }, + /// Clears all of Grapevine's service caches with index smaller than the + /// amount + ClearServiceCaches { + amount: u32, + }, /// Show configuration values ShowConfig, @@ -162,9 +172,13 @@ enum AdminCommand { }, /// Disables incoming federation handling for a room. - DisableRoom { room_id: Box }, + DisableRoom { + room_id: Box, + }, /// Enables incoming federation handling for a room again. - EnableRoom { room_id: Box }, + EnableRoom { + room_id: Box, + }, /// Verify json signatures /// [commandbody]() @@ -267,31 +281,38 @@ impl Service { } pub(crate) fn process_message(&self, room_message: String) { - self.sender - .send(AdminRoomEvent::ProcessMessage(room_message)) - .unwrap(); + self.sender.send(AdminRoomEvent::ProcessMessage(room_message)).unwrap(); } - pub(crate) fn send_message(&self, message_content: RoomMessageEventContent) { - self.sender - .send(AdminRoomEvent::SendMessage(message_content)) - .unwrap(); + pub(crate) fn send_message( + &self, + message_content: RoomMessageEventContent, + ) { + self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap(); } // Parse and process a message from the admin room - async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { + async fn process_admin_message( + &self, + room_message: String, + ) -> RoomMessageEventContent { let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); - let command_line = lines.next().expect("each string has at least one line"); + let command_line = + lines.next().expect("each string has at least one line"); let body: Vec<_> = lines.collect(); let admin_command = match Self::parse_admin_command(command_line) { Ok(command) => command, Err(error) => { let server_name = services().globals.server_name(); - let message = error.replace("server.name", server_name.as_str()); + let message = + error.replace("server.name", server_name.as_str()); let html_message = Self::usage_to_html(&message, server_name); - return RoomMessageEventContent::text_html(message, html_message); + return RoomMessageEventContent::text_html( + message, + html_message, + ); } }; @@ -299,22 +320,28 @@ impl Service { Ok(reply_message) => reply_message, Err(error) => { let markdown_message = format!( - "Encountered an error while handling the command:\n\ - ```\n{error}\n```", + "Encountered an error while handling the \ + command:\n```\n{error}\n```", ); let html_message = format!( - "Encountered an error while handling the command:\n\ -
\n{error}\n
", + "Encountered an error while handling the \ + command:\n
\n{error}\n
", ); - RoomMessageEventContent::text_html(markdown_message, html_message) + RoomMessageEventContent::text_html( + markdown_message, + html_message, + ) } } } // Parse chat messages from the admin room into an AdminCommand object - fn parse_admin_command(command_line: &str) -> std::result::Result { - // Note: argv[0] is `@grapevine:servername:`, which is treated as the main command + fn parse_admin_command( + command_line: &str, + ) -> std::result::Result { + // Note: argv[0] is `@grapevine:servername:`, which is treated as the + // main command let mut argv: Vec<_> = command_line.split_whitespace().collect(); // Replace `help command` with `command --help` @@ -342,18 +369,26 @@ impl Service { ) -> Result { let reply_message_content = match command { AdminCommand::RegisterAppservice => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + if body.len() > 2 + && body[0].trim() == "```" + && body.last().unwrap().trim() == "```" { let appservice_config = body[1..body.len() - 1].join("\n"); - let parsed_config = serde_yaml::from_str::(&appservice_config); + let parsed_config = serde_yaml::from_str::( + &appservice_config, + ); match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml).await { - Ok(id) => RoomMessageEventContent::text_plain(format!( - "Appservice registered with ID: {id}." - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to register appservice: {e}" - )), + Ok(yaml) => match services() + .appservice + .register_appservice(yaml) + .await + { + Ok(id) => RoomMessageEventContent::text_plain( + format!("Appservice registered with ID: {id}."), + ), + Err(e) => RoomMessageEventContent::text_plain( + format!("Failed to register appservice: {e}"), + ), }, Err(e) => RoomMessageEventContent::text_plain(format!( "Could not parse appservice config: {e}" @@ -361,7 +396,8 @@ impl Service { } } else { RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", + "Expected code block in command body. Add --help for \ + details.", ) } } @@ -372,7 +408,9 @@ impl Service { .unregister_appservice(&appservice_identifier) .await { - Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), + Ok(()) => RoomMessageEventContent::text_plain( + "Appservice unregistered.", + ), Err(e) => RoomMessageEventContent::text_plain(format!( "Failed to unregister appservice: {e}" )), @@ -407,17 +445,25 @@ impl Service { ); RoomMessageEventContent::text_plain(output) } - AdminCommand::ListLocalUsers => match services().users.list_local_users() { + AdminCommand::ListLocalUsers => match services() + .users + .list_local_users() + { Ok(users) => { - let mut msg: String = format!("Found {} local user account(s):\n", users.len()); + let mut msg: String = format!( + "Found {} local user account(s):\n", + users.len() + ); msg += &users.join("\n"); RoomMessageEventContent::text_plain(&msg) } Err(e) => RoomMessageEventContent::text_plain(e.to_string()), }, AdminCommand::IncomingFederation => { - let map = services().globals.roomid_federationhandletime.read().await; - let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); + let map = + services().globals.roomid_federationhandletime.read().await; + let mut msg: String = + format!("Handling {} incoming pdus:\n", map.len()); for (r, (e, i)) in map.iter() { let elapsed = i.elapsed(); @@ -431,17 +477,26 @@ impl Service { } RoomMessageEventContent::text_plain(&msg) } - AdminCommand::GetAuthChain { event_id } => { + AdminCommand::GetAuthChain { + event_id, + } => { let event_id = Arc::::from(event_id); - if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { + if let Some(event) = + services().rooms.timeline.get_pdu_json(&event_id)? + { let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + .ok_or_else(|| { + Error::bad_database("Invalid event in database") + })?; - let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { - Error::bad_database("Invalid room id field in event in database") - })?; + let room_id = + <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database( + "Invalid room id field in event in database", + ) + })?; let start = Instant::now(); let count = services() .rooms @@ -458,29 +513,47 @@ impl Service { } } AdminCommand::ParsePdu => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + if body.len() > 2 + && body[0].trim() == "```" + && body.last().unwrap().trim() == "```" { let string = body[1..body.len() - 1].join("\n"); match serde_json::from_str(&string) { Ok(value) => { - match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { + match ruma::signatures::reference_hash( + &value, + &RoomVersionId::V6, + ) { Ok(hash) => { - let event_id = EventId::parse(format!("${hash}")); + let event_id = + EventId::parse(format!("${hash}")); match serde_json::from_value::( - serde_json::to_value(value).expect("value is json"), + serde_json::to_value(value) + .expect("value is json"), ) { - Ok(pdu) => RoomMessageEventContent::text_plain(format!( - "EventId: {event_id:?}\n{pdu:#?}" - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "EventId: {event_id:?}\nCould not parse event: {e}" - )), + Ok(pdu) => { + RoomMessageEventContent::text_plain( + format!( + "EventId: {event_id:?}\\ + n{pdu:#?}" + ), + ) + } + Err(e) => { + RoomMessageEventContent::text_plain( + format!( + "EventId: {event_id:?}\\ + nCould not parse event: \ + {e}" + ), + ) + } } } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse PDU JSON: {e:?}" - )), + Err(e) => RoomMessageEventContent::text_plain( + format!("Could not parse PDU JSON: {e:?}"), + ), } } Err(e) => RoomMessageEventContent::text_plain(format!( @@ -488,10 +561,14 @@ impl Service { )), } } else { - RoomMessageEventContent::text_plain("Expected code block in command body.") + RoomMessageEventContent::text_plain( + "Expected code block in command body.", + ) } } - AdminCommand::GetPdu { event_id } => { + AdminCommand::GetPdu { + event_id, + } => { let mut outlier = false; let mut pdu_json = services() .rooms @@ -499,7 +576,8 @@ impl Service { .get_non_outlier_pdu_json(&event_id)?; if pdu_json.is_none() { outlier = true; - pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; + pdu_json = + services().rooms.timeline.get_pdu_json(&event_id)?; } match pdu_json { Some(json) => { @@ -516,7 +594,8 @@ impl Service { json_text ), format!( - "

{}

\n
{}\n
\n", + "

{}

\n
{}\n
\n", if outlier { "PDU is outlier" } else { @@ -526,7 +605,9 @@ impl Service { ), ) } - None => RoomMessageEventContent::text_plain("PDU not found."), + None => { + RoomMessageEventContent::text_plain("PDU not found.") + } } } AdminCommand::MemoryUsage => { @@ -537,30 +618,42 @@ impl Service { "Services:\n{response1}\n\nDatabase:\n{response2}" )) } - AdminCommand::ClearDatabaseCaches { amount } => { + AdminCommand::ClearDatabaseCaches { + amount, + } => { services().globals.db.clear_caches(amount); RoomMessageEventContent::text_plain("Done.") } - AdminCommand::ClearServiceCaches { amount } => { + AdminCommand::ClearServiceCaches { + amount, + } => { services().clear_caches(amount).await; RoomMessageEventContent::text_plain("Done.") } AdminCommand::ShowConfig => { // Construct and send the response - RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) + RoomMessageEventContent::text_plain(format!( + "{}", + services().globals.config + )) } - AdminCommand::ResetPassword { username } => { + AdminCommand::ResetPassword { + username, + } => { let user_id = match UserId::parse_with_server_name( username.as_str().to_lowercase(), services().globals.server_name(), ) { Ok(id) => id, Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {e}" - ))) + return Ok(RoomMessageEventContent::text_plain( + format!( + "The supplied username is not a valid \ + username: {e}" + ), + )) } }; @@ -589,23 +682,29 @@ impl Service { )); } - let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); + let new_password = + utils::random_string(AUTO_GEN_PASSWORD_LENGTH); match services() .users .set_password(&user_id, Some(new_password.as_str())) { Ok(()) => RoomMessageEventContent::text_plain(format!( - "Successfully reset the password for user {user_id}: {new_password}" + "Successfully reset the password for user {user_id}: \ + {new_password}" )), Err(e) => RoomMessageEventContent::text_plain(format!( "Couldn't reset the password for user {user_id}: {e}" )), } } - AdminCommand::CreateUser { username, password } => { - let password = - password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); + AdminCommand::CreateUser { + username, + password, + } => { + let password = password.unwrap_or_else(|| { + utils::random_string(AUTO_GEN_PASSWORD_LENGTH) + }); // Validate user id let user_id = match UserId::parse_with_server_name( username.as_str().to_lowercase(), @@ -613,9 +712,12 @@ impl Service { ) { Ok(id) => id, Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {e}" - ))) + return Ok(RoomMessageEventContent::text_plain( + format!( + "The supplied username is not a valid \ + username: {e}" + ), + )) } }; if user_id.is_historical() { @@ -647,24 +749,32 @@ impl Service { .into(), &serde_json::to_value(PushRulesEvent { content: PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), + global: ruma::push::Ruleset::server_default( + &user_id, + ), }, }) .expect("to json value always works"), )?; - // we dont add a device since we're not the user, just the creator + // we dont add a device since we're not the user, just the + // creator // Inhibit login does not work for guests RoomMessageEventContent::text_plain(format!( - "Created user with user_id: {user_id} and password: {password}" + "Created user with user_id: {user_id} and password: \ + {password}" )) } - AdminCommand::DisableRoom { room_id } => { + AdminCommand::DisableRoom { + room_id, + } => { services().rooms.metadata.disable_room(&room_id, true)?; RoomMessageEventContent::text_plain("Room disabled.") } - AdminCommand::EnableRoom { room_id } => { + AdminCommand::EnableRoom { + room_id, + } => { services().rooms.metadata.disable_room(&room_id, false)?; RoomMessageEventContent::text_plain("Room enabled.") } @@ -677,13 +787,16 @@ impl Service { RoomMessageEventContent::text_plain(format!( "User {user_id} doesn't exist on this server" )) - } else if user_id.server_name() != services().globals.server_name() { + } else if user_id.server_name() + != services().globals.server_name() + { RoomMessageEventContent::text_plain(format!( "User {user_id} is not from this server" )) } else { RoomMessageEventContent::text_plain(format!( - "Making {user_id} leave all rooms before deactivation..." + "Making {user_id} leave all rooms before \ + deactivation..." )); services().users.deactivate_account(&user_id)?; @@ -697,10 +810,18 @@ impl Service { )) } } - AdminCommand::DeactivateAll { leave_rooms, force } => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + AdminCommand::DeactivateAll { + leave_rooms, + force, + } => { + if body.len() > 2 + && body[0].trim() == "```" + && body.last().unwrap().trim() == "```" { - let users = body.clone().drain(1..body.len() - 1).collect::>(); + let users = body + .clone() + .drain(1..body.len() - 1) + .collect::>(); let mut user_ids = Vec::new(); let mut remote_ids = Vec::new(); @@ -710,7 +831,9 @@ impl Service { for &user in &users { match <&UserId>::try_from(user) { Ok(user_id) => { - if user_id.server_name() != services().globals.server_name() { + if user_id.server_name() + != services().globals.server_name() + { remote_ids.push(user_id); } else if !services().users.exists(user_id)? { non_existant_ids.push(user_id); @@ -727,39 +850,59 @@ impl Service { let mut markdown_message = String::new(); let mut html_message = String::new(); if !invalid_users.is_empty() { - markdown_message.push_str("The following user ids are not valid:\n```\n"); - html_message.push_str("The following user ids are not valid:\n
\n");
+                        markdown_message.push_str(
+                            "The following user ids are not valid:\n```\n",
+                        );
+                        html_message.push_str(
+                            "The following user ids are not valid:\n
\n",
+                        );
                         for invalid_user in invalid_users {
                             writeln!(markdown_message, "{invalid_user}")
-                                .expect("write to in-memory buffer should succeed");
-                            writeln!(html_message, "{invalid_user}")
-                                .expect("write to in-memory buffer should succeed");
+                                .expect(
+                                    "write to in-memory buffer should succeed",
+                                );
+                            writeln!(html_message, "{invalid_user}").expect(
+                                "write to in-memory buffer should succeed",
+                            );
                         }
                         markdown_message.push_str("```\n\n");
                         html_message.push_str("
\n\n"); } if !remote_ids.is_empty() { - markdown_message - .push_str("The following users are not from this server:\n```\n"); - html_message - .push_str("The following users are not from this server:\n
\n");
+                        markdown_message.push_str(
+                            "The following users are not from this \
+                             server:\n```\n",
+                        );
+                        html_message.push_str(
+                            "The following users are not from this \
+                             server:\n
\n",
+                        );
                         for remote_id in remote_ids {
-                            writeln!(markdown_message, "{remote_id}")
-                                .expect("write to in-memory buffer should succeed");
-                            writeln!(html_message, "{remote_id}")
-                                .expect("write to in-memory buffer should succeed");
+                            writeln!(markdown_message, "{remote_id}").expect(
+                                "write to in-memory buffer should succeed",
+                            );
+                            writeln!(html_message, "{remote_id}").expect(
+                                "write to in-memory buffer should succeed",
+                            );
                         }
                         markdown_message.push_str("```\n\n");
                         html_message.push_str("
\n\n"); } if !non_existant_ids.is_empty() { - markdown_message.push_str("The following users do not exist:\n```\n"); - html_message.push_str("The following users do not exist:\n
\n");
+                        markdown_message.push_str(
+                            "The following users do not exist:\n```\n",
+                        );
+                        html_message.push_str(
+                            "The following users do not exist:\n
\n",
+                        );
                         for non_existant_id in non_existant_ids {
                             writeln!(markdown_message, "{non_existant_id}")
-                                .expect("write to in-memory buffer should succeed");
-                            writeln!(html_message, "{non_existant_id}")
-                                .expect("write to in-memory buffer should succeed");
+                                .expect(
+                                    "write to in-memory buffer should succeed",
+                                );
+                            writeln!(html_message, "{non_existant_id}").expect(
+                                "write to in-memory buffer should succeed",
+                            );
                         }
                         markdown_message.push_str("```\n\n");
                         html_message.push_str("
\n\n"); @@ -775,21 +918,24 @@ impl Service { let mut admins = Vec::new(); if !force { - user_ids.retain(|&user_id| match services().users.is_admin(user_id) { - Ok(is_admin) => { - if is_admin { - admins.push(user_id.localpart()); - false - } else { - true + user_ids.retain(|&user_id| { + match services().users.is_admin(user_id) { + Ok(is_admin) => { + if is_admin { + admins.push(user_id.localpart()); + false + } else { + true + } } + Err(_) => false, } - Err(_) => false, }); } for &user_id in &user_ids { - if services().users.deactivate_account(user_id).is_ok() { + if services().users.deactivate_account(user_id).is_ok() + { deactivation_count += 1; } } @@ -807,16 +953,25 @@ impl Service { "Deactivated {deactivation_count} accounts." )) } else { - RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) + RoomMessageEventContent::text_plain(format!( + "Deactivated {} accounts.\nSkipped admin \ + accounts: {:?}. Use --force to deactivate admin \ + accounts", + deactivation_count, + admins.join(", ") + )) } } else { RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", + "Expected code block in command body. Add --help for \ + details.", ) } } AdminCommand::SignJson => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + if body.len() > 2 + && body[0].trim() == "```" + && body.last().unwrap().trim() == "```" { let string = body[1..body.len() - 1].join("\n"); match serde_json::from_str(&string) { @@ -827,20 +982,26 @@ impl Service { &mut value, ) .expect("our request json is what ruma expects"); - let json_text = serde_json::to_string_pretty(&value) - .expect("canonical json is valid json"); + let json_text = + serde_json::to_string_pretty(&value) + .expect("canonical json is valid json"); RoomMessageEventContent::text_plain(json_text) } - Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Invalid json: {e}" + )), } } else { RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", + "Expected code block in command body. Add --help for \ + details.", ) } } AdminCommand::VerifyJson => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + if body.len() > 2 + && body[0].trim() == "```" + && body.last().unwrap().trim() == "```" { let string = body[1..body.len() - 1].join("\n"); match serde_json::from_str(&string) { @@ -850,22 +1011,35 @@ impl Service { services() .rooms .event_handler - .fetch_required_signing_keys(&value, &pub_key_map) + .fetch_required_signing_keys( + &value, + &pub_key_map, + ) .await?; let pub_key_map = pub_key_map.read().await; - match ruma::signatures::verify_json(&pub_key_map, &value) { - Ok(()) => RoomMessageEventContent::text_plain("Signature correct"), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Signature verification failed: {e}" - )), + match ruma::signatures::verify_json( + &pub_key_map, + &value, + ) { + Ok(()) => RoomMessageEventContent::text_plain( + "Signature correct", + ), + Err(e) => RoomMessageEventContent::text_plain( + format!( + "Signature verification failed: {e}" + ), + ), } } - Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Invalid json: {e}" + )), } } else { RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", + "Expected code block in command body. Add --help for \ + details.", ) } } @@ -876,7 +1050,8 @@ impl Service { // Utility to turn clap's `--help` text to HTML. fn usage_to_html(text: &str, server_name: &ServerName) -> String { - // Replace `@grapevine:servername:-subcmdname` with `@grapevine:servername: subcmdname` + // Replace `@grapevine:servername:-subcmdname` with + // `@grapevine:servername: subcmdname` let localpart = if services().globals.config.conduit_compat { "conduit" } else { @@ -892,11 +1067,13 @@ impl Service { let text = text.replace("SUBCOMMAND", "COMMAND"); let text = text.replace("subcommand", "command"); - // Escape option names (e.g. ``) since they look like HTML tags + // Escape option names (e.g. ``) since they look like HTML + // tags let text = text.replace('<', "<").replace('>', ">"); // Italicize the first line (command name and version text) - let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); + let re = + Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); let text = re.replace_all(&text, "$1\n"); // Unmerge wrapped lines @@ -911,8 +1088,8 @@ impl Service { .expect("Regex compilation should not fail"); let text = re.replace_all(&text, "$1: $4"); - // Look for a `[commandbody]()` tag. If it exists, use all lines below it that - // start with a `#` in the USAGE section. + // Look for a `[commandbody]()` tag. If it exists, use all lines below + // it that start with a `#` in the USAGE section. let mut text_lines: Vec<&str> = text.lines().collect(); let command_body = text_lines .iter() @@ -936,8 +1113,11 @@ impl Service { // This makes the usage of e.g. `register-appservice` more accurate let re = Regex::new("(?m)^USAGE:\n (.*?)\n\n") .expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n
$1[nobr]\n[commandbodyblock]
") - .replace("[commandbodyblock]", &command_body) + re.replace_all( + &text, + "USAGE:\n
$1[nobr]\n[commandbodyblock]
", + ) + .replace("[commandbodyblock]", &command_body) }; // Add HTML line-breaks @@ -949,8 +1129,9 @@ impl Service { /// Create the admin room. /// - /// Users in this room are considered admins by grapevine, and the room can be - /// used to issue admin commands by talking to the server user inside it. + /// Users in this room are considered admins by grapevine, and the room can + /// be used to issue admin commands by talking to the server user inside + /// it. #[allow(clippy::too_many_lines)] pub(crate) async fn create_admin_room(&self) -> Result<()> { let room_id = RoomId::new(services().globals.server_name()); @@ -993,7 +1174,9 @@ impl Service { | RoomVersionId::V7 | RoomVersionId::V8 | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(grapevine_user.clone()), + | RoomVersionId::V10 => { + RoomCreateEventContent::new_v1(grapevine_user.clone()) + } RoomVersionId::V11 => RoomCreateEventContent::new_v11(), _ => unreachable!("Validity of room version already checked"), }; @@ -1008,7 +1191,8 @@ impl Service { .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), + content: to_raw_value(&content) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), redacts: None, @@ -1079,8 +1263,10 @@ impl Service { .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), + content: to_raw_value(&RoomJoinRulesEventContent::new( + JoinRule::Invite, + )) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), redacts: None, @@ -1098,9 +1284,11 @@ impl Service { .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new( - HistoryVisibility::Shared, - )) + content: to_raw_value( + &RoomHistoryVisibilityEventContent::new( + HistoryVisibility::Shared, + ), + ) .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), @@ -1134,15 +1322,18 @@ impl Service { .await?; // 5. Events implied by name and topic - let room_name = format!("{} Admin Room", services().globals.server_name()); + let room_name = + format!("{} Admin Room", services().globals.server_name()); services() .rooms .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(room_name)) - .expect("event is valid, we just created it"), + content: to_raw_value(&RoomNameEventContent::new( + room_name, + )) + .expect("event is valid, we just created it"), unsigned: None, state_key: Some(String::new()), redacts: None, @@ -1160,7 +1351,10 @@ impl Service { PduBuilder { event_type: TimelineEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", services().globals.server_name()), + topic: format!( + "Manage {}", + services().globals.server_name() + ), }) .expect("event is valid, we just created it"), unsigned: None, @@ -1174,9 +1368,10 @@ impl Service { .await?; // 6. Room alias - let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); + let alias: OwnedRoomAliasId = + format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); services() .rooms @@ -1206,7 +1401,8 @@ impl Service { /// Gets the room ID of the admin room /// - /// Errors are propagated from the database, and will have None if there is no admin room + /// Errors are propagated from the database, and will have None if there is + /// no admin room // Allowed because this function uses `services()` #[allow(clippy::unused_self)] pub(crate) fn get_admin_room(&self) -> Result> { @@ -1215,10 +1411,7 @@ impl Service { .try_into() .expect("#admins:server_name is a valid alias name"); - services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias) + services().rooms.alias.resolve_local_alias(&admin_room_alias) } /// Invite the user to the grapevine admin room. @@ -1356,11 +1549,13 @@ mod test { } fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) - .unwrap_err() - .to_string(); + let error = + AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) + .unwrap_err() + .to_string(); - // Search for a handful of keywords that suggest the help printed properly + // Search for a handful of keywords that suggest the help printed + // properly assert!(error.contains("Usage:")); assert!(error.contains("Commands:")); assert!(error.contains("Options:")); diff --git a/src/service/appservice.rs b/src/service/appservice.rs index 665d5a31..8abc7735 100644 --- a/src/service/appservice.rs +++ b/src/service/appservice.rs @@ -3,7 +3,6 @@ mod data; use std::collections::BTreeMap; pub(crate) use data::Data; - use futures_util::Future; use regex::RegexSet; use ruma::{ @@ -48,6 +47,8 @@ impl NamespaceRegex { } impl TryFrom> for NamespaceRegex { + type Error = regex::Error; + fn try_from(value: Vec) -> Result { let mut exclusive = vec![]; let mut non_exclusive = vec![]; @@ -73,8 +74,6 @@ impl TryFrom> for NamespaceRegex { }, }) } - - type Error = regex::Error; } /// Appservice registration combined with its compiled regular expressions. @@ -99,6 +98,8 @@ impl RegistrationInfo { } impl TryFrom for RegistrationInfo { + type Error = regex::Error; + fn try_from(value: Registration) -> Result { Ok(RegistrationInfo { users: value.namespaces.users.clone().try_into()?, @@ -107,8 +108,6 @@ impl TryFrom for RegistrationInfo { registration: value, }) } - - type Error = regex::Error; } pub(crate) struct Service { @@ -135,8 +134,12 @@ impl Service { registration_info: RwLock::new(registration_info), }) } + /// Registers an appservice and returns the ID to the caller. - pub(crate) async fn register_appservice(&self, yaml: Registration) -> Result { + pub(crate) async fn register_appservice( + &self, + yaml: Registration, + ) -> Result { //TODO: Check for collisions between exclusive appservice namespaces self.registration_info .write() @@ -151,19 +154,27 @@ impl Service { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub(crate) async fn unregister_appservice(&self, service_name: &str) -> Result<()> { + pub(crate) async fn unregister_appservice( + &self, + service_name: &str, + ) -> Result<()> { services() .appservice .registration_info .write() .await .remove(service_name) - .ok_or_else(|| crate::Error::AdminCommand("Appservice not found"))?; + .ok_or_else(|| { + crate::Error::AdminCommand("Appservice not found") + })?; self.db.unregister_appservice(service_name) } - pub(crate) async fn get_registration(&self, id: &str) -> Option { + pub(crate) async fn get_registration( + &self, + id: &str, + ) -> Option { self.registration_info .read() .await @@ -173,15 +184,13 @@ impl Service { } pub(crate) async fn iter_ids(&self) -> Vec { - self.registration_info - .read() - .await - .keys() - .cloned() - .collect() + self.registration_info.read().await.keys().cloned().collect() } - pub(crate) async fn find_from_token(&self, token: &str) -> Option { + pub(crate) async fn find_from_token( + &self, + token: &str, + ) -> Option { self.read() .await .values() @@ -207,8 +216,12 @@ impl Service { pub(crate) fn read( &self, - ) -> impl Future>> - { + ) -> impl Future< + Output = tokio::sync::RwLockReadGuard< + '_, + BTreeMap, + >, + > { self.registration_info.read() } } diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 92da3a73..84e2bd90 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -15,7 +15,9 @@ pub(crate) trait Data: Send + Sync { fn get_registration(&self, id: &str) -> Result>; - fn iter_ids<'a>(&'a self) -> Result> + 'a>>; + fn iter_ids<'a>( + &'a self, + ) -> Result> + 'a>>; fn all(&self) -> Result>; } diff --git a/src/service/globals.rs b/src/service/globals.rs index a72a5e06..87ad973e 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -1,26 +1,4 @@ mod data; -pub(crate) use data::Data; -use ruma::{ - serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, - OwnedServerSigningKeyId, OwnedUserId, -}; - -use crate::api::server_server::FedDest; - -use crate::{services, Config, Error, Result}; -use futures_util::FutureExt; -use hyper::{ - client::connect::dns::{GaiResolver, Name}, - service::Service as HyperService, -}; -use reqwest::dns::{Addrs, Resolve, Resolving}; -use ruma::{ - api::{ - client::sync::sync_events, - federation::discovery::{ServerSigningKeys, VerifyKey}, - }, - DeviceId, RoomVersionId, ServerName, UserId, -}; use std::{ collections::{BTreeMap, HashMap}, error::Error as StdError, @@ -35,11 +13,29 @@ use std::{ }, time::{Duration, Instant}, }; + +use base64::{engine::general_purpose, Engine as _}; +pub(crate) use data::Data; +use futures_util::FutureExt; +use hyper::{ + client::connect::dns::{GaiResolver, Name}, + service::Service as HyperService, +}; +use reqwest::dns::{Addrs, Resolve, Resolving}; +use ruma::{ + api::{ + client::sync::sync_events, + federation::discovery::{ServerSigningKeys, VerifyKey}, + }, + serde::Base64, + DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, + OwnedServerSigningKeyId, OwnedUserId, RoomVersionId, ServerName, UserId, +}; use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore}; use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; -use base64::{engine::general_purpose, Engine as _}; +use crate::{api::server_server::FedDest, services, Config, Error, Result}; type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; @@ -66,27 +62,40 @@ pub(crate) struct Service { default_client: reqwest::Client, pub(crate) stable_room_versions: Vec, pub(crate) unstable_room_versions: Vec, - pub(crate) bad_event_ratelimiter: Arc>>, - pub(crate) bad_signature_ratelimiter: Arc, RateLimitState>>>, - pub(crate) bad_query_ratelimiter: Arc>>, - pub(crate) servername_ratelimiter: Arc>>>, - pub(crate) sync_receivers: RwLock>, - pub(crate) roomid_mutex_insert: RwLock>>>, + pub(crate) bad_event_ratelimiter: + Arc>>, + pub(crate) bad_signature_ratelimiter: + Arc, RateLimitState>>>, + pub(crate) bad_query_ratelimiter: + Arc>>, + pub(crate) servername_ratelimiter: + Arc>>>, + pub(crate) sync_receivers: + RwLock>, + pub(crate) roomid_mutex_insert: + RwLock>>>, pub(crate) roomid_mutex_state: RwLock>>>, // this lock will be held longer - pub(crate) roomid_mutex_federation: RwLock>>>, - pub(crate) roomid_federationhandletime: RwLock>, + pub(crate) roomid_mutex_federation: + RwLock>>>, + pub(crate) roomid_federationhandletime: + RwLock>, pub(crate) stateres_mutex: Arc>, pub(crate) rotate: RotationHandler, pub(crate) shutdown: AtomicBool, } -/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. +/// Handles "rotation" of long-polling requests. "Rotation" in this context is +/// similar to "rotation" of log files and the like. /// -/// This is utilized to have sync workers return early and release read locks on the database. -pub(crate) struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>); +/// This is utilized to have sync workers return early and release read locks on +/// the database. +pub(crate) struct RotationHandler( + broadcast::Sender<()>, + broadcast::Receiver<()>, +); impl RotationHandler { pub(crate) fn new() -> Self { @@ -136,7 +145,10 @@ impl Resolve for Resolver { .and_then(|(override_name, port)| { override_name.first().map(|first_name| { let x: Box + Send> = - Box::new(iter::once(SocketAddr::new(*first_name, *port))); + Box::new(iter::once(SocketAddr::new( + *first_name, + *port, + ))); let x: Resolving = Box::pin(future::ready(Ok(x))); x }) @@ -144,9 +156,11 @@ impl Resolve for Resolver { .unwrap_or_else(|| { let this = &mut self.inner.clone(); Box::pin(HyperService::::call(this, name).map(|result| { - result - .map(|addrs| -> Addrs { Box::new(addrs) }) - .map_err(|err| -> Box { Box::new(err) }) + result.map(|addrs| -> Addrs { Box::new(addrs) }).map_err( + |err| -> Box { + Box::new(err) + }, + ) })) }) } @@ -167,10 +181,9 @@ impl Service { let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new())); - let jwt_decoding_key = config - .jwt_secret - .as_ref() - .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); + let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| { + jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()) + }); let default_client = reqwest_client_builder(&config)?.build()?; let federation_client = reqwest_client_builder(&config)? @@ -187,20 +200,28 @@ impl Service { RoomVersionId::V11, ]; // Experimental, partially supported room versions - let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; + let unstable_room_versions = + vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; let mut s = Self { db, config, keypair: Arc::new(keypair), - dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { - error!( - "Failed to set up trust dns resolver with system config: {}", - e - ); - Error::bad_config("Failed to set up trust dns resolver with system config.") - })?, - actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), + dns_resolver: TokioAsyncResolver::tokio_from_system_conf() + .map_err(|e| { + error!( + "Failed to set up trust dns resolver with system \ + config: {}", + e + ); + Error::bad_config( + "Failed to set up trust dns resolver with system \ + config.", + ) + })?, + actual_destination_cache: Arc::new( + RwLock::new(WellKnownMap::new()), + ), tls_name_override, federation_client, default_client, @@ -223,12 +244,11 @@ impl Service { fs::create_dir_all(s.get_media_folder())?; - if !s - .supported_room_versions() - .contains(&s.config.default_room_version) + if !s.supported_room_versions().contains(&s.config.default_room_version) { error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); - s.config.default_room_version = crate::config::default_default_room_version(); + s.config.default_room_version = + crate::config::default_default_room_version(); }; Ok(s) @@ -261,7 +281,11 @@ impl Service { self.db.current_count() } - pub(crate) async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub(crate) async fn watch( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<()> { self.db.watch(user_id, device_id).await } @@ -313,7 +337,9 @@ impl Service { &self.dns_resolver } - pub(crate) fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { + pub(crate) fn jwt_decoding_key( + &self, + ) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } @@ -353,7 +379,8 @@ impl Service { /// TODO: the key valid until timestamp is only honored in room version > 4 /// Remove the outdated keys and insert the new ones. /// - /// This doesn't actually check that the keys provided are newer than the old set. + /// This doesn't actually check that the keys provided are newer than the + /// old set. pub(crate) fn add_signing_key( &self, origin: &ServerName, @@ -362,7 +389,8 @@ impl Service { self.db.add_signing_key(origin, new_keys) } - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. pub(crate) fn signing_keys_for( &self, origin: &ServerName, diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 9d863ca9..e416530d 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -13,7 +13,8 @@ use crate::Result; pub(crate) trait Data: Send + Sync { fn next_count(&self) -> Result; fn current_count(&self) -> Result; - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) + -> Result<()>; fn cleanup(&self) -> Result<()>; fn memory_usage(&self) -> String; fn clear_caches(&self, amount: u32); @@ -25,7 +26,8 @@ pub(crate) trait Data: Send + Sync { new_keys: ServerSigningKeys, ) -> Result>; - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. fn signing_keys_for( &self, origin: &ServerName, diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index c99fe1d6..0b667192 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,12 +1,13 @@ use std::collections::BTreeMap; -use crate::Result; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn create_backup( &self, @@ -23,12 +24,21 @@ pub(crate) trait Data: Send + Sync { backup_metadata: &Raw, ) -> Result; - fn get_latest_backup_version(&self, user_id: &UserId) -> Result>; + fn get_latest_backup_version( + &self, + user_id: &UserId, + ) -> Result>; - fn get_latest_backup(&self, user_id: &UserId) - -> Result)>>; + fn get_latest_backup( + &self, + user_id: &UserId, + ) -> Result)>>; - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>>; + fn get_backup( + &self, + user_id: &UserId, + version: &str, + ) -> Result>>; fn add_key( &self, @@ -66,7 +76,12 @@ pub(crate) trait Data: Send + Sync { fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; + fn delete_room_keys( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<()>; fn delete_room_key( &self, diff --git a/src/service/media.rs b/src/service/media.rs index 76184374..899320fe 100644 --- a/src/service/media.rs +++ b/src/service/media.rs @@ -2,15 +2,14 @@ mod data; use std::io::Cursor; pub(crate) use data::Data; - -use crate::{services, Result}; use image::imageops::FilterType; - use tokio::{ fs::File, io::{AsyncReadExt, AsyncWriteExt, BufReader}, }; +use crate::{services, Result}; + pub(crate) struct FileMeta { pub(crate) content_disposition: Option, pub(crate) content_type: Option, @@ -31,9 +30,13 @@ impl Service { file: &[u8], ) -> Result<()> { // Width, Height = 0 if it's not a thumbnail - let key = self - .db - .create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; + let key = self.db.create_file_metadata( + mxc, + 0, + 0, + content_disposition, + content_type, + )?; let path = services().globals.get_media_file(&key); let mut f = File::create(path).await?; @@ -52,9 +55,13 @@ impl Service { height: u32, file: &[u8], ) -> Result<()> { - let key = - self.db - .create_file_metadata(mxc, width, height, content_disposition, content_type)?; + let key = self.db.create_file_metadata( + mxc, + width, + height, + content_disposition, + content_type, + )?; let path = services().globals.get_media_file(&key); let mut f = File::create(path).await?; @@ -84,9 +91,12 @@ impl Service { } } - /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when - /// the server should send the original file. - fn thumbnail_properties(width: u32, height: u32) -> Option<(u32, u32, bool)> { + /// Returns width, height of the thumbnail and whether it should be cropped. + /// Returns None when the server should send the original file. + fn thumbnail_properties( + width: u32, + height: u32, + ) -> Option<(u32, u32, bool)> { match (width, height) { (0..=32, 0..=32) => Some((32, 32, true)), (0..=96, 0..=96) => Some((96, 96, true)), @@ -102,11 +112,15 @@ impl Service { /// Here's an example on how it works: /// /// - Client requests an image with width=567, height=567 - /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails - /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) + /// - Server rounds that up to (800, 600), so it doesn't have to save too + /// many thumbnails + /// - Server rounds that up again to (958, 600) to fix the aspect ratio + /// (only for width,height>96) /// - Server creates the thumbnail and sends it to the user /// - /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. + /// For width,height <= 96 the server uses another thumbnailing algorithm + /// which crops the image afterwards. + #[allow(clippy::too_many_lines)] pub(crate) async fn get_thumbnail( &self, mxc: String, @@ -154,7 +168,8 @@ impl Service { } else { let (exact_width, exact_height) = { // Copied from image::dynimage::resize_dimensions - let use_width = (u64::from(width) * u64::from(original_height)) + let use_width = (u64::from(width) + * u64::from(original_height)) <= (u64::from(original_width) * u64::from(height)); let intermediate = if use_width { u64::from(original_height) * u64::from(width) @@ -165,21 +180,31 @@ impl Service { }; if use_width { if intermediate <= u64::from(::std::u32::MAX) { - (width, intermediate.try_into().unwrap_or(u32::MAX)) + ( + width, + intermediate.try_into().unwrap_or(u32::MAX), + ) } else { ( - (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) + (u64::from(width) + * u64::from(::std::u32::MAX) + / intermediate) .try_into() .unwrap_or(u32::MAX), ::std::u32::MAX, ) } } else if intermediate <= u64::from(::std::u32::MAX) { - (intermediate.try_into().unwrap_or(u32::MAX), height) + ( + intermediate.try_into().unwrap_or(u32::MAX), + height, + ) } else { ( ::std::u32::MAX, - (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) + (u64::from(height) + * u64::from(::std::u32::MAX) + / intermediate) .try_into() .unwrap_or(u32::MAX), ) @@ -195,7 +220,8 @@ impl Service { image::ImageOutputFormat::Png, )?; - // Save thumbnail in database so we don't have to generate it again next time + // Save thumbnail in database so we don't have to generate it + // again next time let thumbnail_key = self.db.create_file_metadata( mxc, width, diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 1de3d852..a5296ceb 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -1,24 +1,31 @@ -use crate::Error; +use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; + use ruma::{ canonical_json::redact_content_in_place, events::{ - room::{member::RoomMemberEventContent, redaction::RoomRedactionEventContent}, + room::{ + member::RoomMemberEventContent, + redaction::RoomRedactionEventContent, + }, space::child::HierarchySpaceChildEvent, - AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, - AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, + AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, + AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, + AnyTimelineEvent, StateEvent, TimelineEventType, }, serde::Raw, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, + MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, + RoomVersionId, UInt, UserId, }; use serde::{Deserialize, Serialize}; use serde_json::{ json, value::{to_raw_value, RawValue as RawJsonValue}, }; -use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; use tracing::warn; +use crate::Error; + /// Content hashes of a PDU. #[derive(Clone, Debug, Deserialize, Serialize)] pub(crate) struct EventHash { @@ -61,10 +68,18 @@ impl PduEvent { ) -> crate::Result<()> { self.unsigned = None; - let mut content = serde_json::from_str(self.content.get()) - .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; - redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) - .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; + let mut content = + serde_json::from_str(self.content.get()).map_err(|_| { + Error::bad_database("PDU in db has invalid content.") + })?; + redact_content_in_place( + &mut content, + &room_version_id, + self.kind.to_string(), + ) + .map_err(|e| { + Error::Redaction(self.sender.server_name().to_owned(), e) + })?; self.unsigned = Some(to_raw_value(&json!({ "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") @@ -78,10 +93,12 @@ impl PduEvent { pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> { if let Some(unsigned) = &self.unsigned { let mut unsigned: BTreeMap> = - serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + serde_json::from_str(unsigned.get()).map_err(|_| { + Error::bad_database("Invalid unsigned in pdu event") + })?; unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + self.unsigned = + Some(to_raw_value(&unsigned).expect("unsigned is valid")); } Ok(()) @@ -91,31 +108,45 @@ impl PduEvent { let mut unsigned: BTreeMap> = self .unsigned .as_ref() - .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + .map_or_else( + || Ok(BTreeMap::new()), + |u| serde_json::from_str(u.get()), + ) + .map_err(|_| { + Error::bad_database("Invalid unsigned in pdu event") + })?; unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + self.unsigned = + Some(to_raw_value(&unsigned).expect("unsigned is valid")); Ok(()) } - /// Copies the `redacts` property of the event to the `content` dict and vice-versa. + /// Copies the `redacts` property of the event to the `content` dict and + /// vice-versa. /// /// This follows the specification's /// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): /// - /// > For backwards-compatibility with older clients, servers should add a redacts - /// > property to the top level of m.room.redaction events in when serving such events + /// > For backwards-compatibility with older clients, servers should add a + /// > redacts + /// > property to the top level of m.room.redaction events in when serving + /// > such events /// > over the Client-Server API. /// > - /// > For improved compatibility with newer clients, servers should add a redacts property - /// > to the content of m.room.redaction events in older room versions when serving + /// > For improved compatibility with newer clients, servers should add a + /// > redacts property + /// > to the content of m.room.redaction events in older room versions when + /// > serving /// > such events over the Client-Server API. - pub(crate) fn copy_redacts(&self) -> (Option>, Box) { + pub(crate) fn copy_redacts( + &self, + ) -> (Option>, Box) { if self.kind == TimelineEventType::RoomRedaction { - if let Ok(mut content) = - serde_json::from_str::(self.content.get()) + if let Ok(mut content) = serde_json::from_str::< + RoomRedactionEventContent, + >(self.content.get()) { if let Some(redacts) = content.redacts { return (Some(redacts.into()), self.content.clone()); @@ -123,7 +154,9 @@ impl PduEvent { content.redacts = Some(redacts.into()); return ( self.redacts.clone(), - to_raw_value(&content).expect("Must be valid, we only added redacts field"), + to_raw_value(&content).expect( + "Must be valid, we only added redacts field", + ), ); } } @@ -281,7 +314,9 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_stripped_spacechild_state_event(&self) -> Raw { + pub(crate) fn to_stripped_spacechild_state_event( + &self, + ) -> Raw { let json = json!({ "content": self.content, "type": self.kind, @@ -294,7 +329,9 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_member_event(&self) -> Raw> { + pub(crate) fn to_member_event( + &self, + ) -> Raw> { let mut json = json!({ "content": self.content, "type": self.kind, @@ -318,16 +355,16 @@ impl PduEvent { pub(crate) fn convert_to_outgoing_federation_event( mut pdu_json: CanonicalJsonObject, ) -> Box { - if let Some(unsigned) = pdu_json - .get_mut("unsigned") - .and_then(|val| val.as_object_mut()) + if let Some(unsigned) = + pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) { unsigned.remove("transaction_id"); } pdu_json.remove("event_id"); - to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") + to_raw_value(&pdu_json) + .expect("CanonicalJson is valid serde_json::Value") } pub(crate) fn from_id_val( @@ -374,11 +411,15 @@ impl state_res::Event for PduEvent { self.state_key.as_deref() } - fn prev_events(&self) -> Box + '_> { + fn prev_events( + &self, + ) -> Box + '_> { Box::new(self.prev_events.iter()) } - fn auth_events(&self) -> Box + '_> { + fn auth_events( + &self, + ) -> Box + '_> { Box::new(self.auth_events.iter()) } @@ -408,15 +449,17 @@ impl Ord for PduEvent { /// Generates a correct eventId for the incoming pdu. /// -/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. +/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. pub(crate) fn gen_event_id_canonical_json( pdu: &RawJsonValue, room_version_id: &RoomVersionId, ) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let value: CanonicalJsonObject = + serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; let event_id = format!( "${}", diff --git a/src/service/pusher.rs b/src/service/pusher.rs index 90a3b90d..903f1a73 100644 --- a/src/service/pusher.rs +++ b/src/service/pusher.rs @@ -1,27 +1,34 @@ mod data; -pub(crate) use data::Data; -use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx}; +use std::{fmt::Debug, mem}; -use crate::{services, Error, PduEvent, Result}; use bytes::BytesMut; +pub(crate) use data::Data; use ruma::{ api::{ client::push::{set_pusher, Pusher, PusherKind}, push_gateway::send_event_notification::{ self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, + v1::{ + Device, Notification, NotificationCounts, NotificationPriority, + }, }, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, }, - events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType}, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + events::{ + room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, + StateEventType, TimelineEventType, + }, + push::{ + Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, + Ruleset, Tweak, + }, serde::Raw, uint, RoomId, UInt, UserId, }; - -use std::{fmt::Debug, mem}; use tracing::{info, warn}; +use crate::{services, Error, PduEvent, Result}; + pub(crate) struct Service { pub(crate) db: &'static dyn Data, } @@ -35,7 +42,11 @@ impl Service { self.db.set_pusher(sender, pusher) } - pub(crate) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + pub(crate) fn get_pusher( + &self, + sender: &UserId, + pushkey: &str, + ) -> Result> { self.db.get_pusher(sender, pushkey) } @@ -43,7 +54,10 @@ impl Service { self.db.get_pushers(sender) } - pub(crate) fn get_pushkeys(&self, sender: &UserId) -> Box>> { + pub(crate) fn get_pushkeys( + &self, + sender: &UserId, + ) -> Box>> { self.db.get_pushkeys(sender) } @@ -73,11 +87,8 @@ impl Service { let reqwest_request = reqwest::Request::try_from(http_request)?; let url = reqwest_request.url().clone(); - let response = services() - .globals - .default_client() - .execute(reqwest_request) - .await; + let response = + services().globals.default_client().execute(reqwest_request).await; match response { Ok(mut response) => { @@ -119,11 +130,16 @@ impl Service { "Push gateway returned invalid response bytes {}\n{}", destination, url ); - Error::BadServerResponse("Push gateway returned bad response.") + Error::BadServerResponse( + "Push gateway returned bad response.", + ) }) } Err(e) => { - warn!("Could not send request to pusher {}: {}", destination, e); + warn!( + "Could not send request to pusher {}: {}", + destination, e + ); Err(e.into()) } } @@ -146,8 +162,9 @@ impl Service { .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + serde_json::from_str(ev.content.get()).map_err(|_| { + Error::bad_database("invalid m.room.power_levels event") + }) }) .transpose()? .unwrap_or_default(); @@ -228,11 +245,16 @@ impl Service { PusherKind::Http(http) => { // TODO: // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 1. if "event_id_only" is the only format kind it seems we + // should never add more info // 2. can pusher/devices have conflicting formats - let event_id_only = http.format == Some(PushFormat::EventIdOnly); + let event_id_only = + http.format == Some(PushFormat::EventIdOnly); - let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); + let mut device = Device::new( + pusher.ids.app_id.clone(), + pusher.ids.pushkey.clone(), + ); device.data.default_payload = http.default_payload.clone(); device.data.format = http.format.clone(); @@ -251,32 +273,43 @@ impl Service { notifi.counts = NotificationCounts::new(unread, uint!(0)); if event.kind == TimelineEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + || tweaks.iter().any(|t| { + matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)) + }) { notifi.prio = NotificationPriority::High; } if event_id_only { - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) - .await?; + self.send_request( + &http.url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; } else { notifi.sender = Some(event.sender.clone()); notifi.event_type = Some(event.kind.clone()); - notifi.content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = + serde_json::value::to_raw_value(&event.content).ok(); if event.kind == TimelineEventType::RoomMember { - notifi.user_is_target = - event.state_key.as_deref() == Some(event.sender.as_str()); + notifi.user_is_target = event.state_key.as_deref() + == Some(event.sender.as_str()); } - notifi.sender_display_name = services().users.displayname(&event.sender)?; + notifi.sender_display_name = + services().users.displayname(&event.sender)?; - notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; + notifi.room_name = services() + .rooms + .state_accessor + .get_name(&event.room_id)?; - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) - .await?; + self.send_request( + &http.url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; } Ok(()) diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index b13fcc3b..9fcba933 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,16 +1,27 @@ -use crate::Result; use ruma::{ api::client::push::{set_pusher, Pusher}, UserId, }; -pub(crate) trait Data: Send + Sync { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; +use crate::Result; - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; +pub(crate) trait Data: Send + Sync { + fn set_pusher( + &self, + sender: &UserId, + pusher: set_pusher::v3::PusherAction, + ) -> Result<()>; + + fn get_pusher( + &self, + sender: &UserId, + pushkey: &str, + ) -> Result>; fn get_pushers(&self, sender: &UserId) -> Result>; - fn get_pushkeys<'a>(&'a self, sender: &UserId) - -> Box> + 'a>; + fn get_pushkeys<'a>( + &'a self, + sender: &UserId, + ) -> Box> + 'a>; } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index f0f442d0..0e0e6f7c 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,6 +1,7 @@ -use crate::Result; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { /// Creates or updates the alias to the given room id. fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; @@ -9,7 +10,10 @@ pub(crate) trait Data: Send + Sync { fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; /// Looks up the roomid for the given alias. - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>; + fn resolve_local_alias( + &self, + alias: &RoomAliasId, + ) -> Result>; /// Returns all local aliases that point to the given room fn local_aliases_for_room<'a>( diff --git a/src/service/rooms/auth_chain.rs b/src/service/rooms/auth_chain.rs index 40f49713..2a601830 100644 --- a/src/service/rooms/auth_chain.rs +++ b/src/service/rooms/auth_chain.rs @@ -43,7 +43,8 @@ impl Service { let mut i = 0; for id in starting_events { - let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let short = + services().rooms.short.get_or_create_shorteventid(&id)?; // I'm afraid to change this in case there is accidental reliance on // the truncation #[allow(clippy::as_conversions, clippy::cast_possible_truncation)] @@ -64,7 +65,8 @@ impl Service { continue; } - let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); + let chunk_key: Vec = + chunk.iter().map(|(short, _)| short).copied().collect(); if let Some(cached) = services() .rooms .auth_chain @@ -90,11 +92,13 @@ impl Service { chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); - services() - .rooms - .auth_chain - .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + let auth_chain = Arc::new( + self.get_auth_chain_inner(room_id, &event_id)?, + ); + services().rooms.auth_chain.cache_auth_chain( + vec![sevent_id], + Arc::clone(&auth_chain), + )?; debug!( event_id = ?event_id, chain_length = ?auth_chain.len(), @@ -129,13 +133,17 @@ impl Service { "Auth chain stats", ); - Ok(full_auth_chain - .into_iter() - .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + Ok(full_auth_chain.into_iter().filter_map(move |sid| { + services().rooms.short.get_eventid_from_short(sid).ok() + })) } #[tracing::instrument(skip(self, event_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + fn get_auth_chain_inner( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); @@ -143,7 +151,10 @@ impl Service { match services().rooms.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { - return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Evil event in db", + )); } for auth_event in &pdu.auth_events { let sauthevent = services() @@ -158,10 +169,17 @@ impl Service { } } Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); + warn!( + ?event_id, + "Could not find pdu mentioned in auth events" + ); } Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); + error!( + ?event_id, + ?error, + "Could not load event in auth chain" + ); } } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 7a865368..5ecaee34 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,11 +1,15 @@ -use crate::Result; use std::{collections::HashSet, sync::Arc}; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn get_cached_eventid_authchain( &self, shorteventid: &[u64], ) -> Result>>>; - fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc>) - -> Result<()>; + fn cache_auth_chain( + &self, + shorteventid: Vec, + auth_chain: Arc>, + ) -> Result<()>; } diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 9aefaf6c..f0bc3bf5 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,6 +1,7 @@ -use crate::Result; use ruma::{OwnedRoomId, RoomId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { /// Adds the room to the public room directory fn set_public(&self, room_id: &RoomId) -> Result<()>; @@ -12,5 +13,7 @@ pub(crate) trait Data: Send + Sync { fn is_public_room(&self, room_id: &RoomId) -> Result; /// Returns the unsorted public room directory - fn public_rooms<'a>(&'a self) -> Box> + 'a>; + fn public_rooms<'a>( + &'a self, + ) -> Box> + 'a>; } diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 366e946c..9e468400 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,5 +1,8 @@ +use ruma::{ + events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId, +}; + use crate::Result; -use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; pub(crate) trait Data: Send + Sync { /// Replaces the previous read receipt. @@ -10,7 +13,8 @@ pub(crate) trait Data: Send + Sync { event: ReceiptEvent, ) -> Result<()>; - /// Returns an iterator over the most recent read receipts in a room that happened after the event with id `since`. + /// Returns an iterator over the most recent read receipts in a room that + /// happened after the event with id `since`. #[allow(clippy::type_complexity)] fn readreceipts_since<'a>( &'a self, @@ -27,11 +31,24 @@ pub(crate) trait Data: Send + Sync { >; /// Sets a private read marker at `count`. - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; + fn private_read_set( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<()>; /// Returns the private read marker. - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn private_read_get( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result>; /// Returns the count of the last typing update in this room. - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn last_privateread_update( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result; } diff --git a/src/service/rooms/edus/typing.rs b/src/service/rooms/edus/typing.rs index b3838d8a..52761d69 100644 --- a/src/service/rooms/edus/typing.rs +++ b/src/service/rooms/edus/typing.rs @@ -1,8 +1,9 @@ +use std::collections::BTreeMap; + use ruma::{ events::{typing::TypingEventContent, SyncEphemeralRoomEvent}, OwnedRoomId, OwnedUserId, RoomId, UserId, }; -use std::collections::BTreeMap; use tokio::sync::{broadcast, RwLock}; use tracing::trace; @@ -10,15 +11,16 @@ use crate::{services, utils, Result}; pub(crate) struct Service { // u64 is unix timestamp of timeout - pub(crate) typing: RwLock>>, + pub(crate) typing: + RwLock>>, // timestamp of the last change to typing users pub(crate) last_typing_update: RwLock>, pub(crate) typing_update_sender: broadcast::Sender, } impl Service { - /// Sets a user as typing until the timeout timestamp is reached or `roomtyping_remove` is - /// called. + /// Sets a user as typing until the timeout timestamp is reached or + /// `roomtyping_remove` is called. pub(crate) async fn typing_add( &self, user_id: &UserId, @@ -36,13 +38,20 @@ impl Service { .await .insert(room_id.to_owned(), services().globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { - trace!("receiver found what it was looking for and is no longer interested"); + trace!( + "receiver found what it was looking for and is no longer \ + interested" + ); } Ok(()) } /// Removes a user from typing before the timeout is reached. - pub(crate) async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(crate) async fn typing_remove( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()> { self.typing .write() .await @@ -54,7 +63,10 @@ impl Service { .await .insert(room_id.to_owned(), services().globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { - trace!("receiver found what it was looking for and is no longer interested"); + trace!( + "receiver found what it was looking for and is no longer \ + interested" + ); } Ok(()) } @@ -97,14 +109,20 @@ impl Service { .await .insert(room_id.to_owned(), services().globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { - trace!("receiver found what it was looking for and is no longer interested"); + trace!( + "receiver found what it was looking for and is no longer \ + interested" + ); } } Ok(()) } /// Returns the count of the last typing update in this room. - pub(crate) async fn last_typing_update(&self, room_id: &RoomId) -> Result { + pub(crate) async fn last_typing_update( + &self, + room_id: &RoomId, + ) -> Result { self.typings_maintain(room_id).await?; Ok(self .last_typing_update diff --git a/src/service/rooms/event_handler.rs b/src/service/rooms/event_handler.rs index 8ff88088..48c40be1 100644 --- a/src/service/rooms/event_handler.rs +++ b/src/service/rooms/event_handler.rs @@ -24,7 +24,8 @@ use ruma::{ }, events::{ room::{ - create::RoomCreateEventContent, redaction::RoomRedactionEventContent, + create::RoomCreateEventContent, + redaction::RoomRedactionEventContent, server_acl::RoomServerAclEventContent, }, StateEventType, TimelineEventType, @@ -32,16 +33,16 @@ use ruma::{ int, serde::Base64, state_res::{self, RoomVersion, StateMap}, - uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedServerName, OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName, + uint, CanonicalJsonObject, CanonicalJsonValue, EventId, + MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId, + RoomId, RoomVersionId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore}; use tracing::{debug, error, info, trace, warn}; -use crate::{service::pdu, services, Error, PduEvent, Result}; - use super::state_compressor::CompressedStateEvent; +use crate::{service::pdu, services, Error, PduEvent, Result}; pub(crate) struct Service; @@ -52,24 +53,29 @@ impl Service { /// 1.1. Remove unsigned field /// 2. Check signatures, otherwise drop /// 3. Check content hash, redact if doesn't match - /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not - /// timeline events - /// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are - /// also rejected "due to auth events" - /// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + /// 4. Fetch any missing auth events doing all checks listed here starting + /// at 1. These are not timeline events + /// 5. Reject "due to auth events" if can't get all the auth events or some + /// of the auth events are also rejected "due to auth events" + /// 6. Reject "due to auth events" if the event doesn't pass auth based on + /// the auth events /// 7. Persist this event as an outlier /// 8. If not timeline event: stop - /// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline - /// events - /// 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - /// doing all the checks in this list starting at 1. These are not timeline events + /// 9. Fetch any missing prev events doing all checks listed here starting + /// at 1. These are timeline events + /// 10. Fetch missing state and auth chain events by calling /state_ids at + /// backwards extremities doing all the checks in this list starting at + /// 1. These are not timeline events /// 11. Check the auth of the event passes based on the state of the event - /// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by - /// doing state res where one of the inputs was a previously trusted set of state, don't just - /// trust a set of state we got from a remote) + /// 12. Ensure that the state is derived from the previous current state + /// (i.e. we calculated by doing state res where one of the inputs was a + /// previously trusted set of state, don't just trust a set of state we + /// got from a remote) /// 13. Use state resolution to find new room state - /// 14. Check if the event passes auth based on the "current state" of the room, if not soft fail it - // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively + /// 14. Check if the event passes auth based on the "current state" of the + /// room, if not soft fail it + // We use some AsyncRecursiveType hacks here so we can call this async + // funtion recursively #[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))] pub(crate) async fn handle_incoming_pdu<'a>( &self, @@ -106,7 +112,9 @@ impl Service { .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + .ok_or_else(|| { + Error::bad_database("Failed to find create event in db.") + })?; let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()).map_err(|e| { @@ -115,11 +123,10 @@ impl Service { })?; let room_version_id = &create_event_content.room_version; - let first_pdu_in_room = services() - .rooms - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = + services().rooms.timeline.first_pdu_in_room(room_id)?.ok_or_else( + || Error::bad_database("Failed to find first pdu in db."), + )?; let (incoming_pdu, val) = self .handle_outlier_pdu( @@ -144,7 +151,8 @@ impl Service { return Ok(None); } - // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + // 9. Fetch any missing prev events doing all checks listed here + // starting at 1. These are timeline events let (sorted_prev_events, mut eventid_info) = self .fetch_unknown_prev_events( origin, @@ -163,7 +171,8 @@ impl Service { if services().rooms.metadata.is_disabled(room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, - "Federation of this room is currently disabled on this server.", + "Federation of this room is currently disabled on this \ + server.", )); } @@ -175,7 +184,8 @@ impl Service { .get(&*prev_id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } @@ -217,7 +227,10 @@ impl Service { .roomid_federationhandletime .write() .await - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + .insert( + room_id.to_owned(), + ((*prev_id).to_owned(), start_time), + ); if let Err(e) = self .upgrade_outlier_to_timeline_pdu( @@ -305,32 +318,40 @@ impl Service { mut value: BTreeMap, auth_events_known: bool, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>> { + ) -> AsyncRecursiveType< + 'a, + Result<(Arc, BTreeMap)>, + > { Box::pin(async move { // 1.1. Remove unsigned field value.remove("unsigned"); // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - self.fetch_required_signing_keys(&value, pub_key_map) - .await?; + // We go through all the signatures we see on the value and fetch + // the corresponding signing keys + self.fetch_required_signing_keys(&value, pub_key_map).await?; // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; + serde_json::from_str(create_event.content.get()).map_err( + |e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + }, + )?; let room_version_id = &create_event_content.room_version; - let room_version = - RoomVersion::new(room_version_id).expect("room version is supported"); + let room_version = RoomVersion::new(room_version_id) + .expect("room version is supported"); let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) { + let mut val = match ruma::signatures::verify_event( + &guard, + &value, + room_version_id, + ) { Err(e) => { // Drop warn!("Dropping bad event {}: {}", event_id, e,); @@ -342,15 +363,25 @@ impl Service { Ok(ruma::signatures::Verified::Signatures) => { // Redact warn!("Calculated hash does not match: {}", event_id); - let Ok(obj) = ruma::canonical_json::redact(value, room_version_id, None) else { + let Ok(obj) = ruma::canonical_json::redact( + value, + room_version_id, + None, + ) else { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Redaction failed", )); }; - // Skip the PDU if it is redacted and we already have it as an outlier event - if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { + // Skip the PDU if it is redacted and we already have it as + // an outlier event + if services() + .rooms + .timeline + .get_pdu_json(event_id)? + .is_some() + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Event was redacted and we already knew about it", @@ -364,23 +395,28 @@ impl Service { drop(guard); - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type + // Now that we have checked the signature and hashes we can add the + // eventID and convert to our PduEvent type val.insert( "event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()), ); let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + serde_json::to_value(&val) + .expect("CanonicalJsonObj is a valid JsonValue"), ) .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; Self::check_room_id(room_id, &incoming_pdu)?; if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often + // 4. fetch any missing auth events doing all checks listed here + // starting at 1. These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth + // events or some of the auth events are also rejected "due + // to auth events" + // NOTE: Step 5 is not applied anymore because it failed too + // often debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); self.fetch_and_handle_outliers( origin, @@ -397,7 +433,8 @@ impl Service { .await; } - // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + // 6. Reject "due to auth events" if the event doesn't pass auth + // based on the auth events debug!( "Auth check for {} based on auth events", incoming_pdu.event_id @@ -406,7 +443,8 @@ impl Service { // Build map of auth events let mut auth_events = HashMap::new(); for id in &incoming_pdu.auth_events { - let Some(auth_event) = services().rooms.timeline.get_pdu(id)? else { + let Some(auth_event) = services().rooms.timeline.get_pdu(id)? + else { warn!("Could not find auth event {}", id); continue; }; @@ -426,7 +464,8 @@ impl Service { hash_map::Entry::Occupied(_) => { return Err(Error::BadRequest( ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", + "Auth event's type and state_key combination \ + exists multiple times.", )); } } @@ -450,8 +489,9 @@ impl Service { None::, |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? - { + .map_err(|_e| { + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })? { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Auth check failed", @@ -472,7 +512,13 @@ impl Service { }) } - #[tracing::instrument(skip(self, incoming_pdu, val, create_event, pub_key_map))] + #[tracing::instrument(skip( + self, + incoming_pdu, + val, + create_event, + pub_key_map + ))] pub(crate) async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, @@ -483,7 +529,9 @@ impl Service { pub_key_map: &RwLock>>, ) -> Result>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(Some(pduid)) = + services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) + { return Ok(Some(pduid)); } @@ -507,13 +555,16 @@ impl Service { })?; let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + let room_version = RoomVersion::new(room_version_id) + .expect("room version is supported"); - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. + // 10. Fetch missing state and auth chain events by calling /state_ids + // at backwards extremities doing all the checks in this list + // starting at 1. These are not timeline events. - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event + // TODO: if we know the prev_events of the incoming event we can avoid + // the request and build the state from a known point and + // resolve if > 1 prev_event debug!("Requesting state at event"); let mut state_at_incoming_event = None; @@ -546,14 +597,17 @@ impl Service { .ok() .flatten() .ok_or_else(|| { - Error::bad_database("Could not find prev event, but we know the state.") + Error::bad_database( + "Could not find prev event, but we know the state.", + ) })?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &prev_pdu.kind.to_string().into(), + state_key, + )?; state.insert(shortstatekey, Arc::from(prev_event)); // Now it's the state after the pdu @@ -567,7 +621,9 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = services().rooms.timeline.get_pdu(prev_eventid) else { + let Ok(Some(prev_event)) = + services().rooms.timeline.get_pdu(prev_eventid) + else { okay = false; break; }; @@ -585,8 +641,10 @@ impl Service { } if okay { - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + let mut fork_states = + Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = + Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { let mut leaf_state: HashMap<_, _> = services() @@ -596,23 +654,34 @@ impl Service { .await?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), - state_key, - )?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &prev_event.kind.to_string().into(), + state_key, + )?; + leaf_state.insert( + shortstatekey, + Arc::from(&*prev_event.event_id), + ); // Now it's the state after the pdu } let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); + let mut starting_events = + Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) + if let Ok((ty, st_key)) = + services().rooms.short.get_statekey_from_short(k) { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); + state.insert( + (ty.to_string().into(), st_key), + id.clone(), + ); } else { warn!("Failed to get_statekey_from_short."); } @@ -633,14 +702,18 @@ impl Service { let lock = services().globals.stateres_mutex.lock(); - let result = - state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let result = state_res::resolve( + room_version_id, + &fork_states, + auth_chain_sets, + |id| { let res = services().rooms.timeline.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } res.ok().flatten() - }); + }, + ); drop(lock); state_at_incoming_event = match result { @@ -648,8 +721,10 @@ impl Service { new_state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey( + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( &event_type.to_string().into(), &state_key, )?; @@ -658,7 +733,12 @@ impl Service { .collect::>()?, ), Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); + warn!( + "State resolution on prev events failed, either \ + an event could not be found or deserialization: \ + {}", + e + ); None } } @@ -667,8 +747,9 @@ impl Service { if state_at_incoming_event.is_none() { debug!("Calling /state_ids"); - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events + // Call /state_ids to find out what the state at this pdu is. We + // trust the server's response to some extend, but we + // still do a lot of checks on the events match services() .sending .send_federation_request( @@ -700,22 +781,31 @@ impl Service { let mut state: HashMap<_, Arc> = HashMap::new(); for (pdu, _) in state_vec { - let state_key = pdu.state_key.clone().ok_or_else(|| { - Error::bad_database("Found non-state pdu in state events.") - })?; + let state_key = + pdu.state_key.clone().ok_or_else(|| { + Error::bad_database( + "Found non-state pdu in state events.", + ) + })?; - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &pdu.kind.to_string().into(), + &state_key, + )?; match state.entry(shortstatekey) { hash_map::Entry::Vacant(v) => { v.insert(Arc::from(&*pdu.event_id)); } - hash_map::Entry::Occupied(_) => return Err( - Error::bad_database("State event's type and state_key combination exists multiple times."), - ), + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key \ + combination exists multiple times.", + )) + } } } @@ -726,7 +816,9 @@ impl Service { .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); - if state.get(&create_shortstatekey) != Some(&create_event.event_id) { + if state.get(&create_shortstatekey) + != Some(&create_event.event_id) + { return Err(Error::bad_database( "Incoming event refers to wrong create event.", )); @@ -745,7 +837,8 @@ impl Service { state_at_incoming_event.expect("we always set this to some above"); debug!("Starting auth check"); - // 11. Check the auth of the event passes based on the state of the event + // 11. Check the auth of the event passes based on the state of the + // event let check_result = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -758,11 +851,22 @@ impl Service { .get_shortstatekey(&k.to_string().into(), s) .ok() .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) + .and_then(|shortstatekey| { + state_at_incoming_event.get(&shortstatekey) + }) + .and_then(|event_id| { + services() + .rooms + .timeline + .get_pdu(event_id) + .ok() + .flatten() + }) }, ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; + .map_err(|_e| { + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed.") + })?; if !check_result { return Err(Error::bad_database( @@ -786,51 +890,57 @@ impl Service { None::, |k, s| auth_events.get(&(k.clone(), s.to_owned())), ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))? - || incoming_pdu.kind == TimelineEventType::RoomRedaction - && match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - if let Some(redact_id) = &incoming_pdu.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } + .map_err(|_e| { + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed.") + })? || incoming_pdu.kind + == TimelineEventType::RoomRedaction + && match room_version_id { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => { + if let Some(redact_id) = &incoming_pdu.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false } - RoomVersionId::V11 => { - let content = serde_json::from_str::( - incoming_pdu.content.get(), - ) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + } + RoomVersionId::V11 => { + let content = serde_json::from_str::< + RoomRedactionEventContent, + >( + incoming_pdu.content.get() + ) + .map_err(|_| { + Error::bad_database("Invalid content in redaction pdu.") + })?; - if let Some(redact_id) = &content.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } + if let Some(redact_id) = &content.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false } - _ => { - unreachable!("Validity of room version already checked") - } - }; + } + _ => { + unreachable!("Validity of room version already checked") + } + }; // 13. Use state resolution to find new room state @@ -846,12 +956,15 @@ impl Service { ); let state_lock = mutex_state.lock().await; - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) + // Now we calculate the set of extremities this room has after the + // incoming event has been applied. We start with the previous + // extremities (aka leaves) debug!("Calculating extremities"); - let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; + let mut extremities = + services().rooms.state.get_forward_extremities(room_id)?; - // Remove any forward extremities that are referenced by this incoming event's prev_events + // Remove any forward extremities that are referenced by this incoming + // event's prev_events for prev_event in &incoming_pdu.prev_events { if extremities.contains(prev_event) { extremities.remove(prev_event); @@ -861,10 +974,7 @@ impl Service { // Only keep those extremities were not referenced yet extremities.retain(|id| { !matches!( - services() - .rooms - .pdu_metadata - .is_event_referenced(room_id, id), + services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true) ) }); @@ -888,12 +998,14 @@ impl Service { // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &incoming_pdu.kind.to_string().into(), + state_key, + )?; - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + state_after + .insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); } let new_room_state = self @@ -915,7 +1027,8 @@ impl Service { .await?; } - // 14. Check if the event passes auth based on the "current state" of the room, if not soft fail it + // 14. Check if the event passes auth based on the "current state" of + // the room, if not soft fail it debug!("Starting soft fail auth check"); if soft_fail { @@ -932,7 +1045,8 @@ impl Service { ) .await?; - // Soft fail, we keep the event as an outlier but don't add it to the timeline + // Soft fail, we keep the event as an outlier but don't add it to + // the timeline warn!("Event was soft failed: {:?}", incoming_pdu); services() .rooms @@ -998,7 +1112,10 @@ impl Service { services() .rooms .auth_chain - .get_auth_chain(room_id, state.iter().map(|(_, id)| id.clone()).collect()) + .get_auth_chain( + room_id, + state.iter().map(|(_, id)| id.clone()).collect(), + ) .await? .collect(), ); @@ -1015,7 +1132,9 @@ impl Service { .rooms .short .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) + .map(|(ty, st_key)| { + ((ty.to_string().into(), st_key), id) + }) .ok() }) .collect::>() @@ -1033,11 +1152,15 @@ impl Service { }; let lock = services().globals.stateres_mutex.lock(); - let Ok(state) = - state_res::resolve(room_version_id, &fork_states, auth_chain_sets, fetch_event) - else { + let Ok(state) = state_res::resolve( + room_version_id, + &fork_states, + auth_chain_sets, + fetch_event, + ) else { return Err(Error::bad_database( - "State resolution failed, either an event could not be found or deserialization", + "State resolution failed, either an event could not be found \ + or deserialization", )); }; @@ -1048,10 +1171,11 @@ impl Service { let new_room_state = state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + )?; services() .rooms .state_compressor @@ -1081,8 +1205,10 @@ impl Service { room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> - { + ) -> AsyncRecursiveType< + 'a, + Vec<(Arc, Option>)>, + > { Box::pin(async move { let back_off = |id| async move { match services() @@ -1106,15 +1232,17 @@ impl Service { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + if let Ok(Some(local_pdu)) = + services().rooms.timeline.get_pdu(id) + { trace!("Found {} in db", id); pdus.push((local_pdu, None)); continue; } // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. + // We also handle its auth chain here so we don't get a stack + // overflow in handle_outlier_pdu. let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::new(); let mut events_all = HashSet::new(); @@ -1130,8 +1258,11 @@ impl Service { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + if min_elapsed_duration + > Duration::from_secs(60 * 60 * 24) + { + min_elapsed_duration = + Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { @@ -1149,7 +1280,9 @@ impl Service { tokio::task::yield_now().await; } - if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + if let Ok(Some(_)) = + services().rooms.timeline.get_pdu(&next_id) + { trace!("Found {} in db", next_id); continue; } @@ -1167,24 +1300,30 @@ impl Service { { info!("Got {} over federation", next_id); let Ok((calculated_event_id, value)) = - pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) + pdu::gen_event_id_canonical_json( + &res.pdu, + room_version_id, + ) else { back_off((*next_id).to_owned()).await; continue; }; if calculated_event_id != *next_id { - warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu); + warn!( + "Server didn't return event id we requested: \ + requested: {}, we got {}. Event: {:?}", + next_id, calculated_event_id, &res.pdu + ); } if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { for auth_event in auth_events { - if let Ok(auth_event) = - serde_json::from_value(auth_event.clone().into()) - { + if let Ok(auth_event) = serde_json::from_value( + auth_event.clone().into(), + ) { let a: Arc = auth_event; todo_auth_events.push(a); } else { @@ -1214,8 +1353,11 @@ impl Service { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + if min_elapsed_duration + > Duration::from_secs(60 * 60 * 24) + { + min_elapsed_duration = + Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { @@ -1242,7 +1384,10 @@ impl Service { } } Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); + warn!( + "Authentication of event {} failed: {:?}", + next_id, e + ); back_off((**next_id).to_owned()).await; } } @@ -1262,17 +1407,19 @@ impl Service { initial_set: Vec>, ) -> Result<( Vec>, - HashMap, (Arc, BTreeMap)>, + HashMap< + Arc, + (Arc, BTreeMap), + >, )> { let mut graph: HashMap, _> = HashMap::new(); let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; - let first_pdu_in_room = services() - .rooms - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = + services().rooms.timeline.first_pdu_in_room(room_id)?.ok_or_else( + || Error::bad_database("Failed to find first pdu in db."), + )?; let mut amount = 0; @@ -1306,7 +1453,8 @@ impl Service { .ok() .flatten() }) { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts + { amount += 1; for prev_prev in &pdu.prev_events { if !graph.contains_key(prev_prev) { @@ -1334,20 +1482,22 @@ impl Service { } } - let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. - Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|_| Error::bad_database("Error sorting prev events"))?; + let sorted = + state_res::lexicographical_topological_sort(&graph, |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + Ok(( + int!(0), + MilliSecondsSinceUnixEpoch( + eventid_info.get(event_id).map_or_else( + || uint!(0), + |info| info.0.origin_server_ts, + ), + ), + )) + }) + .map_err(|_| Error::bad_database("Error sorting prev events"))?; Ok((sorted, eventid_info)) } @@ -1368,20 +1518,23 @@ impl Service { "Invalid signatures object in server response pdu.", ))?; - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys + // We go through all the signatures we see on the value and fetch the + // corresponding signing keys for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; + let signature_object = + signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; - let signature_ids = signature_object.keys().cloned().collect::>(); + let signature_ids = + signature_object.keys().cloned().collect::>(); let fetch_res = self .fetch_signing_keys( signature_server.as_str().try_into().map_err(|_| { Error::BadServerResponse( - "Invalid servername in signatures of server response pdu.", + "Invalid servername in signatures of server \ + response pdu.", ) })?, signature_ids, @@ -1389,32 +1542,40 @@ impl Service { .await; let Ok(keys) = fetch_res else { - warn!("Signature verification failed: Could not fetch signing key.",); + warn!( + "Signature verification failed: Could not fetch signing \ + key.", + ); continue; }; - pub_key_map - .write() - .await - .insert(signature_server.clone(), keys); + pub_key_map.write().await.insert(signature_server.clone(), keys); } Ok(()) } - // Gets a list of servers for which we don't have the signing key yet. We go over - // the PDUs and either cache the key or add it to the list that needs to be retrieved. + // Gets a list of servers for which we don't have the signing key yet. We go + // over the PDUs and either cache the key or add it to the list that + // needs to be retrieved. async fn get_server_keys_from_cache( &self, pdu: &RawJsonValue, - servers: &mut BTreeMap>, + servers: &mut BTreeMap< + OwnedServerName, + BTreeMap, + >, room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, + pub_key_map: &mut RwLockWriteGuard< + '_, + BTreeMap>, + >, ) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) + .map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; let event_id = format!( "${}", @@ -1424,22 +1585,21 @@ impl Service { let event_id = <&EventId>::try_from(event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(event_id) + if let Some((time, tries)) = + services().globals.bad_event_ratelimiter.read().await.get(event_id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); + return Err(Error::BadServerResponse( + "bad event, still backing off", + )); } } @@ -1454,21 +1614,29 @@ impl Service { ))?; for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; + let signature_object = + signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; - let signature_ids = signature_object.keys().cloned().collect::>(); + let signature_ids = + signature_object.keys().cloned().collect::>(); let contains_all_ids = |keys: &BTreeMap| { signature_ids.iter().all(|id| keys.contains_key(id)) }; - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|_| { - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?; + let origin = <&ServerName>::try_from(signature_server.as_str()) + .map_err(|_| { + Error::BadServerResponse( + "Invalid servername in signatures of server response \ + pdu.", + ) + })?; - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { + if servers.contains_key(origin) + || pub_key_map.contains_key(origin.as_str()) + { continue; } @@ -1492,6 +1660,7 @@ impl Service { Ok(()) } + #[allow(clippy::too_many_lines)] pub(crate) async fn fetch_join_signing_keys( &self, event: &create_join_event::v2::Response, @@ -1515,7 +1684,12 @@ impl Service { .chain(&event.room_state.auth_chain) { if let Err(error) = self - .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) + .get_server_keys_from_cache( + pdu, + &mut servers, + room_version, + &mut pkm, + ) .await { debug!(%error, "failed to get server keys from cache"); @@ -1549,7 +1723,8 @@ impl Service { Ok(key) => key, Err(e) => { warn!( - "Received error {} while fetching keys from trusted server {}", + "Received error {} while fetching keys from \ + trusted server {}", e, server ); warn!("{}", k.into_json()); @@ -1584,7 +1759,10 @@ impl Service { ( services() .sending - .send_federation_request(&server, get_server_keys::v2::Request::new()) + .send_federation_request( + &server, + get_server_keys::v2::Request::new(), + ) .await, server, ) @@ -1602,7 +1780,10 @@ impl Service { .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); - pub_key_map.write().await.insert(origin.to_string(), result); + pub_key_map + .write() + .await + .insert(origin.to_string(), result); } } info!("Done handling result"); @@ -1616,7 +1797,11 @@ impl Service { /// Returns Ok if the acl allows the server // Allowed because this function uses `services()` #[allow(clippy::unused_self)] - pub(crate) fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + pub(crate) fn acl_check( + &self, + server_name: &ServerName, + room_id: &RoomId, + ) -> Result<()> { let Some(acl_event) = services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomServerAcl, @@ -1626,9 +1811,9 @@ impl Service { return Ok(()); }; - let Ok(acl_event_content) = - serde_json::from_str::(acl_event.content.get()) - else { + let Ok(acl_event_content) = serde_json::from_str::< + RoomServerAclEventContent, + >(acl_event.content.get()) else { warn!("Invalid ACL event"); return Ok(()); }; @@ -1652,16 +1837,17 @@ impl Service { } } - /// Search the DB for the signing keys of the given server, if we don't have them - /// fetch them from the server and save to our DB. + /// Search the DB for the signing keys of the given server, if we don't have + /// them fetch them from the server and save to our DB. #[tracing::instrument(skip_all)] pub(crate) async fn fetch_signing_keys( &self, origin: &ServerName, signature_ids: Vec, ) -> Result> { - let contains_all_ids = - |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + let contains_all_ids = |keys: &BTreeMap| { + signature_ids.iter().all(|id| keys.contains_key(id)) + }; let permit = services() .globals @@ -1674,7 +1860,8 @@ impl Service { let permit = if let Some(p) = permit { p } else { - let mut write = services().globals.servername_ratelimiter.write().await; + let mut write = + services().globals.servername_ratelimiter.write().await; let s = Arc::clone( write .entry(origin.to_owned()) @@ -1696,7 +1883,9 @@ impl Service { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + } } }; @@ -1708,14 +1897,17 @@ impl Service { .get(&signature_ids) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { debug!("Backing off from {:?}", signature_ids); - return Err(Error::BadServerResponse("bad signature, still backing off")); + return Err(Error::BadServerResponse( + "bad signature, still backing off", + )); } } @@ -1736,14 +1928,15 @@ impl Service { if let Some(server_key) = services() .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) + .send_federation_request( + origin, + get_server_keys::v2::Request::new(), + ) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services() - .globals - .add_signing_key(origin, server_key.clone())?; + services().globals.add_signing_key(origin, server_key.clone())?; result.extend( server_key @@ -1814,9 +2007,7 @@ impl Service { back_off(signature_ids).await; warn!("Failed to find public key for server: {}", origin); - Err(Error::BadServerResponse( - "Failed to find public key for server", - )) + Err(Error::BadServerResponse("Failed to find public key for server")) } fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { diff --git a/src/service/rooms/lazy_loading.rs b/src/service/rooms/lazy_loading.rs index 54a10480..10e0fff1 100644 --- a/src/service/rooms/lazy_loading.rs +++ b/src/service/rooms/lazy_loading.rs @@ -5,16 +5,19 @@ pub(crate) use data::Data; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; -use crate::Result; - use super::timeline::PduCount; +use crate::Result; pub(crate) struct Service { pub(crate) db: &'static dyn Data, #[allow(clippy::type_complexity)] - pub(crate) lazy_load_waiting: - Mutex>>, + pub(crate) lazy_load_waiting: Mutex< + HashMap< + (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), + HashSet, + >, + >, } impl Service { @@ -26,8 +29,7 @@ impl Service { room_id: &RoomId, ll_user: &UserId, ) -> Result { - self.db - .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) + self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) } #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 92f694a3..95bf83d8 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,6 +1,7 @@ -use crate::Result; use ruma::{DeviceId, RoomId, UserId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn lazy_load_was_sent_before( &self, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 48317458..81dd44ff 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,10 +1,13 @@ -use crate::Result; use ruma::{OwnedRoomId, RoomId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { /// Checks if a room exists. fn exists(&self, room_id: &RoomId) -> Result; - fn iter_ids<'a>(&'a self) -> Box> + 'a>; + fn iter_ids<'a>( + &'a self, + ) -> Box> + 'a>; fn is_disabled(&self, room_id: &RoomId) -> Result; fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index 8956b491..3357c5c9 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -4,8 +4,15 @@ use crate::{PduEvent, Result}; pub(crate) trait Data: Send + Sync { /// Returns the pdu from the outlier tree. - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; /// Append the PDU as an outlier. - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; + fn add_pdu_outlier( + &self, + event_id: &EventId, + pdu: &CanonicalJsonObject, + ) -> Result<()>; } diff --git a/src/service/rooms/pdu_metadata.rs b/src/service/rooms/pdu_metadata.rs index 5a778f6b..d8acc698 100644 --- a/src/service/rooms/pdu_metadata.rs +++ b/src/service/rooms/pdu_metadata.rs @@ -9,9 +9,8 @@ use ruma::{ }; use serde::Deserialize; -use crate::{services, PduEvent, Result}; - use super::timeline::PduCount; +use crate::{services, PduEvent, Result}; pub(crate) struct Service { pub(crate) db: &'static dyn Data, @@ -29,9 +28,15 @@ struct ExtractRelatesToEventId { impl Service { #[tracing::instrument(skip(self, from, to))] - pub(crate) fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub(crate) fn add_relation( + &self, + from: PduCount, + to: PduCount, + ) -> Result<()> { match (from, to) { - (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), + (PduCount::Normal(f), PduCount::Normal(t)) => { + self.db.add_relation(f, t) + } _ => { // TODO: Relations with backfilled pdus @@ -42,6 +47,7 @@ impl Service { #[allow( clippy::too_many_arguments, + clippy::too_many_lines, // Allowed because this function uses `services()` clippy::unused_self, )] @@ -68,15 +74,17 @@ impl Service { .relations_until(sender_user, room_id, target, from)? .filter(|r| { r.as_ref().map_or(true, |(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::( - pdu.content.get(), - ) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &&content.relates_to.rel_type == r) + filter_event_type + .as_ref() + .map_or(true, |t| &&pdu.kind == t) + && if let Ok(content) = serde_json::from_str::< + ExtractRelatesToEventId, + >( + pdu.content.get() + ) { + filter_rel_type.as_ref().map_or(true, |r| { + &&content.relates_to.rel_type == r + }) } else { false } @@ -88,13 +96,18 @@ impl Service { services() .rooms .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) + .user_can_see_event( + sender_user, + room_id, + &pdu.event_id, + ) .unwrap_or(false) }) .take_while(|&(k, _)| Some(k) != to) .collect(); - next_token = events_after.last().map(|(count, _)| count).copied(); + next_token = + events_after.last().map(|(count, _)| count).copied(); // Reversed because relations are always most recent first let events_after: Vec<_> = events_after @@ -116,15 +129,17 @@ impl Service { .relations_until(sender_user, room_id, target, from)? .filter(|r| { r.as_ref().map_or(true, |(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::( - pdu.content.get(), - ) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &&content.relates_to.rel_type == r) + filter_event_type + .as_ref() + .map_or(true, |t| &&pdu.kind == t) + && if let Ok(content) = serde_json::from_str::< + ExtractRelatesToEventId, + >( + pdu.content.get() + ) { + filter_rel_type.as_ref().map_or(true, |r| { + &&content.relates_to.rel_type == r + }) } else { false } @@ -136,13 +151,18 @@ impl Service { services() .rooms .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) + .user_can_see_event( + sender_user, + room_id, + &pdu.event_id, + ) .unwrap_or(false) }) .take_while(|&(k, _)| Some(k) != to) .collect(); - next_token = events_before.last().map(|(count, _)| count).copied(); + next_token = + events_before.last().map(|(count, _)| count).copied(); let events_before: Vec<_> = events_before .into_iter() @@ -165,7 +185,8 @@ impl Service { target: &'a EventId, until: PduCount, ) -> Result> + 'a> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let room_id = + services().rooms.short.get_or_create_shortroomid(room_id)?; let target = match services().rooms.timeline.get_pdu_count(target)? { Some(PduCount::Normal(c)) => c, // TODO: Support backfilled relations @@ -185,17 +206,27 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + pub(crate) fn is_event_referenced( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result { self.db.is_event_referenced(room_id, event_id) } #[tracing::instrument(skip(self))] - pub(crate) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + pub(crate) fn mark_event_soft_failed( + &self, + event_id: &EventId, + ) -> Result<()> { self.db.mark_event_soft_failed(event_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { + pub(crate) fn is_event_soft_failed( + &self, + event_id: &EventId, + ) -> Result { self.db.is_event_soft_failed(event_id) } } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 4e28474a..9aace01f 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use crate::{service::rooms::timeline::PduCount, PduEvent, Result}; use ruma::{EventId, RoomId, UserId}; +use crate::{service::rooms::timeline::PduCount, PduEvent, Result}; + pub(crate) trait Data: Send + Sync { fn add_relation(&self, from: u64, to: u64) -> Result<()>; #[allow(clippy::type_complexity)] @@ -13,8 +14,16 @@ pub(crate) trait Data: Send + Sync { target: u64, until: PduCount, ) -> Result> + 'a>>; - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn mark_as_referenced( + &self, + room_id: &RoomId, + event_ids: &[Arc], + ) -> Result<()>; + fn is_event_referenced( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result; fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; fn is_event_soft_failed(&self, event_id: &EventId) -> Result; } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 9e68fe62..903a85e6 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,8 +1,14 @@ -use crate::Result; use ruma::RoomId; +use crate::Result; + pub(crate) trait Data: Send + Sync { - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; + fn index_pdu( + &self, + shortroomid: u64, + pdu_id: &[u8], + message_body: &str, + ) -> Result<()>; #[allow(clippy::type_complexity)] fn search_pdus<'a>( diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 295f11a2..dcde51c3 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use crate::Result; use ruma::{events::StateEventType, EventId, RoomId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; @@ -18,12 +19,19 @@ pub(crate) trait Data: Send + Sync { state_key: &str, ) -> Result; - fn get_eventid_from_short(&self, shorteventid: u64) -> Result>; + fn get_eventid_from_short(&self, shorteventid: u64) + -> Result>; - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; + fn get_statekey_from_short( + &self, + shortstatekey: u64, + ) -> Result<(StateEventType, String)>; /// Returns `(shortstatehash, already_existed)` - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; + fn get_or_create_shortstatehash( + &self, + state_hash: &[u8], + ) -> Result<(u64, bool)>; fn get_shortroomid(&self, room_id: &RoomId) -> Result>; diff --git a/src/service/rooms/spaces.rs b/src/service/rooms/spaces.rs index 7e97dbd4..854d6d44 100644 --- a/src/service/rooms/spaces.rs +++ b/src/service/rooms/spaces.rs @@ -15,8 +15,12 @@ use ruma::{ canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{self, AllowRule, JoinRule, RoomJoinRulesEventContent}, + history_visibility::{ + HistoryVisibility, RoomHistoryVisibilityEventContent, + }, + join_rules::{ + self, AllowRule, JoinRule, RoomJoinRulesEventContent, + }, topic::RoomTopicEventContent, }, space::child::SpaceChildEventContent, @@ -26,7 +30,6 @@ use ruma::{ OwnedRoomId, RoomId, UserId, }; use tokio::sync::Mutex; - use tracing::{debug, error, warn}; use crate::{services, Error, PduEvent, Result}; @@ -42,7 +45,8 @@ pub(crate) struct CachedSpaceChunk { } pub(crate) struct Service { - pub(crate) roomid_spacechunk_cache: Mutex>>, + pub(crate) roomid_spacechunk_cache: + Mutex>>, } impl Service { @@ -86,9 +90,11 @@ impl Service { { if let Some(cached) = cached { let allowed = match &cached.join_rule { - CachedJoinRule::Full(f) => { - self.handle_join_rule(f, sender_user, ¤t_room)? - } + CachedJoinRule::Full(f) => self.handle_join_rule( + f, + sender_user, + ¤t_room, + )?, }; if allowed { if left_to_skip > 0 { @@ -104,10 +110,8 @@ impl Service { continue; } - if let Some(current_shortstatehash) = services() - .rooms - .state - .get_room_shortstatehash(¤t_room)? + if let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(¤t_room)? { let state = services() .rooms @@ -124,16 +128,21 @@ impl Service { continue; } - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + let pdu = + services().rooms.timeline.get_pdu(&id)?.ok_or_else( + || { + Error::bad_database( + "Event in space state not found", + ) + }, + )?; - if serde_json::from_str::(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) + if serde_json::from_str::( + pdu.content.get(), + ) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) { continue; } @@ -147,7 +156,11 @@ impl Service { // TODO: Sort children children_ids.reverse(); - let chunk = self.get_room_chunk(sender_user, ¤t_room, children_pdus); + let chunk = self.get_room_chunk( + sender_user, + ¤t_room, + children_pdus, + ); if let Ok(chunk) = chunk { if left_to_skip > 0 { left_to_skip -= 1; @@ -157,13 +170,24 @@ impl Service { let join_rule = services() .rooms .state_accessor - .room_state_get(¤t_room, &StateEventType::RoomJoinRules, "")? + .room_state_get( + ¤t_room, + &StateEventType::RoomJoinRules, + "", + )? .map(|s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") + error!( + "Invalid room join rule event in \ + database: {}", + e + ); + Error::BadDatabase( + "Invalid room join rule event in \ + database.", + ) }) }) .transpose()? @@ -205,7 +229,10 @@ impl Service { ) .await { - warn!("Got response from {server} for /hierarchy\n{response:?}"); + warn!( + "Got response from {server} for \ + /hierarchy\n{response:?}" + ); let chunk = SpaceHierarchyRoomsChunk { canonical_alias: response.room.canonical_alias, name: response.room.name, @@ -250,9 +277,17 @@ impl Service { }) } SpaceRoomJoinRule::Public => JoinRule::Public, - _ => return Err(Error::BadServerResponse("Unknown join rule")), + _ => { + return Err(Error::BadServerResponse( + "Unknown join rule", + )) + } }; - if self.handle_join_rule(&join_rule, sender_user, ¤t_room)? { + if self.handle_join_rule( + &join_rule, + sender_user, + ¤t_room, + )? { if left_to_skip > 0 { left_to_skip -= 1; } else { @@ -301,12 +336,18 @@ impl Service { canonical_alias: services() .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? + .room_state_get( + room_id, + &StateEventType::RoomCanonicalAlias, + "", + )? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomCanonicalAliasEventContent| c.alias) .map_err(|_| { - Error::bad_database("Invalid canonical alias event in database.") + Error::bad_database( + "Invalid canonical alias event in database.", + ) }) })?, name: services().rooms.state_accessor.get_name(room_id)?, @@ -329,22 +370,34 @@ impl Service { serde_json::from_str(s.content.get()) .map(|c: RoomTopicEventContent| Some(c.topic)) .map_err(|_| { - error!("Invalid room topic event in database for room {}", room_id); - Error::bad_database("Invalid room topic event in database.") + error!( + "Invalid room topic event in database for \ + room {}", + room_id + ); + Error::bad_database( + "Invalid room topic event in database.", + ) }) })?, world_readable: services() .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .room_state_get( + room_id, + &StateEventType::RoomHistoryVisibility, + "", + )? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable + c.history_visibility + == HistoryVisibility::WorldReadable }) .map_err(|_| { Error::bad_database( - "Invalid room history visibility event in database.", + "Invalid room history visibility event in \ + database.", ) }) })?, @@ -358,7 +411,9 @@ impl Service { c.guest_access == GuestAccess::CanJoin }) .map_err(|_| { - Error::bad_database("Invalid room guest access event in database.") + Error::bad_database( + "Invalid room guest access event in database.", + ) }) })?, avatar_url: services() @@ -368,7 +423,11 @@ impl Service { .map(|s| { serde_json::from_str(s.content.get()) .map(|c: RoomAvatarEventContent| c.url) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + .map_err(|_| { + Error::bad_database( + "Invalid room avatar event in database.", + ) + }) }) .transpose()? .flatten(), @@ -376,13 +435,23 @@ impl Service { let join_rule = services() .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? + .room_state_get( + room_id, + &StateEventType::RoomJoinRules, + "", + )? .map(|s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") + error!( + "Invalid room join rule event in \ + database: {}", + e + ); + Error::BadDatabase( + "Invalid room join rule event in database.", + ) }) }) .transpose()? @@ -404,9 +473,14 @@ impl Service { .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .map(|s| { - serde_json::from_str::(s.content.get()).map_err(|e| { + serde_json::from_str::( + s.content.get(), + ) + .map_err(|e| { error!("Invalid room create event in database: {}", e); - Error::BadDatabase("Invalid room create event in database.") + Error::BadDatabase( + "Invalid room create event in database.", + ) }) }) .transpose()? @@ -424,7 +498,9 @@ impl Service { JoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), JoinRule::Private => Ok(SpaceRoomJoinRule::Private), JoinRule::Restricted(_) => Ok(SpaceRoomJoinRule::Restricted), - JoinRule::KnockRestricted(_) => Ok(SpaceRoomJoinRule::KnockRestricted), + JoinRule::KnockRestricted(_) => { + Ok(SpaceRoomJoinRule::KnockRestricted) + } JoinRule::Public => Ok(SpaceRoomJoinRule::Public), _ => Err(Error::BadServerResponse("Unknown join rule")), } @@ -440,10 +516,9 @@ impl Service { ) -> Result { let allowed = match join_rule { SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::Public => true, - SpaceRoomJoinRule::Invite => services() - .rooms - .state_cache - .is_joined(sender_user, room_id)?, + SpaceRoomJoinRule::Invite => { + services().rooms.state_cache.is_joined(sender_user, room_id)? + } _ => false, }; diff --git a/src/service/rooms/state.rs b/src/service/rooms/state.rs index 0d4f91a0..07e5fc5e 100644 --- a/src/service/rooms/state.rs +++ b/src/service/rooms/state.rs @@ -19,9 +19,8 @@ use serde::Deserialize; use tokio::sync::MutexGuard; use tracing::warn; -use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; - use super::state_compressor::CompressedStateEvent; +use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; pub(crate) struct Service { pub(crate) db: &'static dyn Data, @@ -46,12 +45,15 @@ impl Service { .ok() .map(|(_, id)| id) }) { - let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else { + let Some(pdu) = + services().rooms.timeline.get_pdu_json(&event_id)? + else { continue; }; let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), + &serde_json::to_string(&pdu) + .expect("CanonicalJsonObj can be serialized to JSON"), ) { Ok(pdu) => pdu, Err(_) => continue, @@ -65,7 +67,9 @@ impl Service { } let membership = - match serde_json::from_str::(pdu.content.get()) { + match serde_json::from_str::( + pdu.content.get(), + ) { Ok(e) => e.membership, Err(_) => continue, }; @@ -102,8 +106,7 @@ impl Service { services().rooms.state_cache.update_joined_count(room_id)?; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash, state_lock)?; Ok(()) } @@ -119,24 +122,18 @@ impl Service { room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(event_id)?; + let shorteventid = + services().rooms.short.get_or_create_shorteventid(event_id)?; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = + self.db.get_room_shortstatehash(room_id)?; let state_hash = calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::>(), + &state_ids_compressed.iter().map(|s| &s[..]).collect::>(), ); - let (shortstatehash, already_existed) = services() - .rooms - .short - .get_or_create_shortstatehash(&state_hash)?; + let (shortstatehash, already_existed) = + services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( @@ -192,7 +189,8 @@ impl Service { .short .get_or_create_shorteventid(&new_pdu.event_id)?; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = + self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.db.set_event_state(shorteventid, p)?; @@ -209,10 +207,11 @@ impl Service { }, )?; - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &new_pdu.kind.to_string().into(), + state_key, + )?; let new = services() .rooms @@ -222,9 +221,9 @@ impl Service { let replaces = states_parents .last() .map(|info| { - info.1 - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + info.1.iter().find(|bytes| { + bytes.starts_with(&shortstatekey.to_be_bytes()) + }) }) .unwrap_or_default(); @@ -253,7 +252,8 @@ impl Service { Ok(shortstatehash) } else { - Ok(previous_shortstatehash.expect("first event in room must be a state event")) + Ok(previous_shortstatehash + .expect("first event in room must be a state event")) } } @@ -325,7 +325,10 @@ impl Service { /// Returns the room's version. #[tracing::instrument(skip(self))] - pub(crate) fn get_room_version(&self, room_id: &RoomId) -> Result { + pub(crate) fn get_room_version( + &self, + room_id: &RoomId, + ) -> Result { let create_event = services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomCreate, @@ -341,12 +344,20 @@ impl Service { }) }) .transpose()? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; + .ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "No create event found", + ) + })?; Ok(create_event_content.room_version) } - pub(crate) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + pub(crate) fn get_room_shortstatehash( + &self, + room_id: &RoomId, + ) -> Result> { self.db.get_room_shortstatehash(room_id) } @@ -364,8 +375,7 @@ impl Service { // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, ) -> Result<()> { - self.db - .set_forward_extremities(room_id, event_ids, state_lock) + self.db.set_forward_extremities(room_id, event_ids, state_lock) } /// This fetches auth events from the current state. @@ -378,12 +388,15 @@ impl Service { state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { - let Some(shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { + let Some(shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + else { return Ok(HashMap::new()); }; - let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) - .expect("content is a valid JSON object"); + let auth_events = + state_res::auth_types_for_event(kind, sender, state_key, content) + .expect("content is a valid JSON object"); let mut sauthevents = auth_events .into_iter() @@ -391,7 +404,10 @@ impl Service { services() .rooms .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) + .get_shortstatekey( + &event_type.to_string().into(), + &state_key, + ) .ok() .flatten() .map(|s| (s, (event_type, state_key))) diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 8e267760..2ed50def 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,8 +1,10 @@ -use crate::Result; -use ruma::{EventId, OwnedEventId, RoomId}; use std::{collections::HashSet, sync::Arc}; + +use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; +use crate::Result; + pub(crate) trait Data: Send + Sync { /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; @@ -17,10 +19,17 @@ pub(crate) trait Data: Send + Sync { ) -> Result<()>; /// Associates a state with an event. - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; + fn set_event_state( + &self, + shorteventid: u64, + shortstatehash: u64, + ) -> Result<()>; /// Returns all events we would send as the `prev_events` of the next event. - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>>; + fn get_forward_extremities( + &self, + room_id: &RoomId, + ) -> Result>>; /// Replace the forward extremities of the room. fn set_forward_extremities( diff --git a/src/service/rooms/state_accessor.rs b/src/service/rooms/state_accessor.rs index 420ad76f..d980b3bf 100644 --- a/src/service/rooms/state_accessor.rs +++ b/src/service/rooms/state_accessor.rs @@ -10,7 +10,9 @@ use ruma::{ events::{ room::{ avatar::RoomAvatarEventContent, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + history_visibility::{ + HistoryVisibility, RoomHistoryVisibilityEventContent, + }, member::{MembershipState, RoomMemberEventContent}, name::RoomNameEventContent, power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, @@ -18,7 +20,8 @@ use ruma::{ StateEventType, }, state_res::Event, - EventId, JsOption, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + EventId, JsOption, OwnedServerName, OwnedUserId, RoomId, ServerName, + UserId, }; use serde_json::value::to_raw_value; use tokio::sync::MutexGuard; @@ -28,7 +31,8 @@ use crate::{service::pdu::PduBuilder, services, Error, PduEvent, Result}; pub(crate) struct Service { pub(crate) db: &'static dyn Data, - pub(crate) server_visibility_cache: Mutex>, + pub(crate) server_visibility_cache: + Mutex>, pub(crate) user_visibility_cache: Mutex>, } @@ -50,7 +54,8 @@ impl Service { self.db.state_full(shortstatehash).await } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). #[tracing::instrument(skip(self))] pub(crate) fn state_get_id( &self, @@ -61,7 +66,8 @@ impl Service { self.db.state_get_id(shortstatehash, event_type, state_key) } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). pub(crate) fn state_get( &self, shortstatehash: u64, @@ -72,7 +78,11 @@ impl Service { } /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { + fn user_membership( + &self, + shortstatehash: u64, + user_id: &UserId, + ) -> Result { self.state_get( shortstatehash, &StateEventType::RoomMember, @@ -81,7 +91,11 @@ impl Service { .map_or(Ok(MembershipState::Leave), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomMemberEventContent| c.membership) - .map_err(|_| Error::bad_database("Invalid room membership event in database.")) + .map_err(|_| { + Error::bad_database( + "Invalid room membership event in database.", + ) + }) }) } @@ -123,12 +137,20 @@ impl Service { } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? + .state_get( + shortstatehash, + &StateEventType::RoomHistoryVisibility, + "", + )? .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }) .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") + Error::bad_database( + "Invalid history visibility event in database.", + ) }) })?; @@ -140,14 +162,20 @@ impl Service { .filter(|member| member.server_name() == origin); let visibility = match history_visibility { - HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, + HistoryVisibility::WorldReadable | HistoryVisibility::Shared => { + true + } HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) + // Allow if any member on requesting server was AT LEAST + // invited, else deny + current_server_members.any(|member| { + self.user_was_invited(shortstatehash, &member) + }) } HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_joined(shortstatehash, &member)) } _ => { error!("Unknown history visibility {history_visibility}"); @@ -185,15 +213,24 @@ impl Service { return Ok(*visibility); } - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = + services().rooms.state_cache.is_joined(user_id, room_id)?; let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? + .state_get( + shortstatehash, + &StateEventType::RoomHistoryVisibility, + "", + )? .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }) .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") + Error::bad_database( + "Invalid history visibility event in database.", + ) }) })?; @@ -201,7 +238,8 @@ impl Service { HistoryVisibility::WorldReadable => true, HistoryVisibility::Shared => currently_member, HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny + // Allow if any member on requesting server was AT LEAST + // invited, else deny self.user_was_invited(shortstatehash, user_id) } HistoryVisibility::Joined => { @@ -230,23 +268,36 @@ impl Service { user_id: &UserId, room_id: &RoomId, ) -> Result { - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = + services().rooms.state_cache.is_joined(user_id, room_id)?; let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .room_state_get( + room_id, + &StateEventType::RoomHistoryVisibility, + "", + )? .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }) .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") + Error::bad_database( + "Invalid history visibility event in database.", + ) }) })?; - Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) + Ok(currently_member + || history_visibility == HistoryVisibility::WorldReadable) } /// Returns the state hash for this pdu. - pub(crate) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + pub(crate) fn pdu_shortstatehash( + &self, + event_id: &EventId, + ) -> Result> { self.db.pdu_shortstatehash(event_id) } @@ -259,7 +310,8 @@ impl Service { self.db.room_state_full(room_id).await } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). #[tracing::instrument(skip(self))] pub(crate) fn room_state_get_id( &self, @@ -270,7 +322,8 @@ impl Service { self.db.room_state_get_id(room_id, event_type, state_key) } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). #[tracing::instrument(skip(self))] pub(crate) fn room_state_get( &self, @@ -282,26 +335,39 @@ impl Service { } pub(crate) fn get_name(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomName, "")? - .map_or(Ok(None), |s| { + self.room_state_get(room_id, &StateEventType::RoomName, "")?.map_or( + Ok(None), + |s| { serde_json::from_str(s.content.get()) .map(|c: RoomNameEventContent| Some(c.name)) .map_err(|e| { error!( - "Invalid room name event in database for room {}. {}", + "Invalid room name event in database for room {}. \ + {}", room_id, e ); - Error::bad_database("Invalid room name event in database.") + Error::bad_database( + "Invalid room name event in database.", + ) }) - }) + }, + ) } - pub(crate) fn get_avatar(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(JsOption::Undefined), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }) + pub(crate) fn get_avatar( + &self, + room_id: &RoomId, + ) -> Result> { + self.room_state_get(room_id, &StateEventType::RoomAvatar, "")?.map_or( + Ok(JsOption::Undefined), + |s| { + serde_json::from_str(s.content.get()).map_err(|_| { + Error::bad_database( + "Invalid room avatar event in database.", + ) + }) + }, + ) } // Allowed because this function uses `services()` @@ -313,8 +379,9 @@ impl Service { target_user: &UserId, state_lock: &MutexGuard<'_, ()>, ) -> bool { - let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) - .expect("Event content always serializes"); + let content = + to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) + .expect("Event content always serializes"); let new_event = PduBuilder { event_type: ruma::events::TimelineEventType::RoomMember, @@ -336,18 +403,23 @@ impl Service { room_id: &RoomId, user_id: &UserId, ) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) + self.room_state_get( + room_id, + &StateEventType::RoomMember, + user_id.as_str(), + )? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()).map_err(|_| { + Error::bad_database("Invalid room member event in database.") }) + }) } /// Checks if a given user can redact a given event /// - /// If `federation` is `true`, it allows redaction events from any user of the same server - /// as the original event sender, [as required by room versions >= - /// v3](https://spec.matrix.org/v1.10/rooms/v11/#handling-redactions) + /// If `federation` is `true`, it allows redaction events from any user of + /// the same server as the original event sender, [as required by room + /// versions >= v3](https://spec.matrix.org/v1.10/rooms/v11/#handling-redactions) pub(crate) fn user_can_redact( &self, redacts: &EventId, @@ -359,18 +431,23 @@ impl Service { .map_or_else( // Falling back on m.room.create to judge power levels || { - if let Some(pdu) = - self.room_state_get(room_id, &StateEventType::RoomCreate, "")? - { + if let Some(pdu) = self.room_state_get( + room_id, + &StateEventType::RoomCreate, + "", + )? { Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { + || if let Ok(Some(pdu)) = + services().rooms.timeline.get_pdu(redacts) + { pdu.sender == sender } else { false }) } else { Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", + "No m.room.power_levels or m.room.create events \ + in database for room", )) } }, @@ -380,11 +457,14 @@ impl Service { .map(|e: RoomPowerLevels| { e.user_can_redact_event_of_other(sender) || e.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = - services().rooms.timeline.get_pdu(redacts) + && if let Ok(Some(pdu)) = services() + .rooms + .timeline + .get_pdu(redacts) { if federation { - pdu.sender().server_name() == sender.server_name() + pdu.sender().server_name() + == sender.server_name() } else { pdu.sender == sender } @@ -393,7 +473,9 @@ impl Service { } }) .map_err(|_| { - Error::bad_database("Invalid m.room.power_levels event in database") + Error::bad_database( + "Invalid m.room.power_levels event in database", + ) }) }, ) diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index a57cdb9d..68214f1d 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -9,14 +9,18 @@ use crate::{PduEvent, Result}; pub(crate) trait Data: Send + Sync { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; + async fn state_full_ids( + &self, + shortstatehash: u64, + ) -> Result>>; async fn state_full( &self, shortstatehash: u64, ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn state_get_id( &self, shortstatehash: u64, @@ -24,7 +28,8 @@ pub(crate) trait Data: Send + Sync { state_key: &str, ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn state_get( &self, shortstatehash: u64, @@ -41,7 +46,8 @@ pub(crate) trait Data: Send + Sync { room_id: &RoomId, ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn room_state_get_id( &self, room_id: &RoomId, @@ -49,7 +55,8 @@ pub(crate) trait Data: Send + Sync { state_key: &str, ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). fn room_state_get( &self, room_id: &RoomId, diff --git a/src/service/rooms/state_cache.rs b/src/service/rooms/state_cache.rs index 9b15cd88..bf4ce632 100644 --- a/src/service/rooms/state_cache.rs +++ b/src/service/rooms/state_cache.rs @@ -2,7 +2,6 @@ mod data; use std::{collections::HashSet, sync::Arc}; pub(crate) use data::Data; - use ruma::{ events::{ direct::DirectEvent, @@ -34,7 +33,8 @@ impl Service { last_state: Option>>, update_joined_count: bool, ) -> Result<()> { - // Keep track what remote users exist by adding them as "deactivated" users + // Keep track what remote users exist by adding them as "deactivated" + // users if user_id.server_name() != services().globals.server_name() { services().users.create(user_id, None)?; // TODO: displayname, avatar url @@ -51,17 +51,26 @@ impl Service { if let Some(predecessor) = services() .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) + .room_state_get( + room_id, + &StateEventType::RoomCreate, + "", + )? + .and_then(|create| { + serde_json::from_str(create.content.get()).ok() + }) + .and_then(|content: RoomCreateEventContent| { + content.predecessor + }) { - // Copy user settings from predecessor to the current room: + // Copy user settings from predecessor to the current + // room: // - Push rules // // TODO: finish this once push rules are implemented. // - // let mut push_rules_event_content: PushRulesEvent = account_data - // .get( + // let mut push_rules_event_content: PushRulesEvent = + // account_data .get( // None, // user_id, // EventType::PushRules, @@ -90,8 +99,13 @@ impl Service { )? .map(|event| { serde_json::from_str(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") + warn!( + "Invalid account data event in db: \ + {e:?}" + ); + Error::BadDatabase( + "Invalid account data event in db.", + ) }) }) { @@ -112,20 +126,32 @@ impl Service { .get( None, user_id, - GlobalAccountDataEventType::Direct.to_string().into(), + GlobalAccountDataEventType::Direct + .to_string() + .into(), )? .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str::(event.get()) + .map_err(|e| { + warn!( + "Invalid account data event in \ + db: {e:?}" + ); + Error::BadDatabase( + "Invalid account data event in db.", + ) + }) }) { let mut direct_event = direct_event?; let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { - if room_ids.iter().any(|r| r == &predecessor.room_id) { + for room_ids in direct_event.content.0.values_mut() + { + if room_ids + .iter() + .any(|r| r == &predecessor.room_id) + { room_ids.push(room_id.to_owned()); room_ids_updated = true; } @@ -135,7 +161,9 @@ impl Service { services().account_data.update( None, user_id, - GlobalAccountDataEventType::Direct.to_string().into(), + GlobalAccountDataEventType::Direct + .to_string() + .into(), &serde_json::to_value(&direct_event) .expect("to json always works"), )?; @@ -160,9 +188,14 @@ impl Service { .into(), )? .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { + serde_json::from_str::( + event.get(), + ) + .map_err(|e| { warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") + Error::BadDatabase( + "Invalid account data event in db.", + ) }) }) .transpose()? @@ -199,7 +232,10 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - pub(crate) fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { + pub(crate) fn get_our_real_users( + &self, + room_id: &RoomId, + ) -> Result>> { self.db.get_our_real_users(room_id) } @@ -214,7 +250,11 @@ impl Service { /// Makes a user forget a room. #[tracing::instrument(skip(self))] - pub(crate) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub(crate) fn forget( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()> { self.db.forget(room_id, user_id) } @@ -228,11 +268,16 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + pub(crate) fn server_in_room( + &self, + server: &ServerName, + room_id: &RoomId, + ) -> Result { self.db.server_in_room(server, room_id) } - /// Returns an iterator of all rooms a server participates in (as far as we know). + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). #[tracing::instrument(skip(self))] pub(crate) fn server_rooms<'a>( &'a self, @@ -251,12 +296,18 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn room_joined_count(&self, room_id: &RoomId) -> Result> { + pub(crate) fn room_joined_count( + &self, + room_id: &RoomId, + ) -> Result> { self.db.room_joined_count(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn room_invited_count(&self, room_id: &RoomId) -> Result> { + pub(crate) fn room_invited_count( + &self, + room_id: &RoomId, + ) -> Result> { self.db.room_invited_count(room_id) } @@ -288,7 +339,11 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub(crate) fn get_left_count( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result> { self.db.get_left_count(room_id, user_id) } @@ -306,7 +361,9 @@ impl Service { pub(crate) fn rooms_invited<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator>)>> + 'a { + ) -> impl Iterator< + Item = Result<(OwnedRoomId, Vec>)>, + > + 'a { self.db.rooms_invited(user_id) } @@ -333,27 +390,44 @@ impl Service { pub(crate) fn rooms_left<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator>)>> + 'a { + ) -> impl Iterator>)>> + 'a + { self.db.rooms_left(user_id) } #[tracing::instrument(skip(self))] - pub(crate) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(crate) fn once_joined( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { self.db.once_joined(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(crate) fn is_joined( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { self.db.is_joined(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(crate) fn is_invited( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { self.db.is_invited(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(crate) fn is_left( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result { self.db.is_left(user_id, room_id) } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index c92ec848..94791c85 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,14 +1,19 @@ use std::{collections::HashSet, sync::Arc}; -use crate::{service::appservice::RegistrationInfo, Result}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +use crate::{service::appservice::RegistrationInfo, Result}; + pub(crate) trait Data: Send + Sync { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_once_joined( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_invited( &self, @@ -20,9 +25,16 @@ pub(crate) trait Data: Send + Sync { fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; - fn get_our_real_users(&self, room_id: &RoomId) -> Result>>; + fn get_our_real_users( + &self, + room_id: &RoomId, + ) -> Result>>; - fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result; + fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &RegistrationInfo, + ) -> Result; /// Makes a user forget a room. fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; @@ -33,9 +45,14 @@ pub(crate) trait Data: Send + Sync { room_id: &RoomId, ) -> Box> + 'a>; - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result; + fn server_in_room( + &self, + server: &ServerName, + room_id: &RoomId, + ) -> Result; - /// Returns an iterator of all rooms a server participates in (as far as we know). + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). fn server_rooms<'a>( &'a self, server: &ServerName, @@ -63,9 +80,17 @@ pub(crate) trait Data: Send + Sync { room_id: &RoomId, ) -> Box> + 'a>; - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_invite_count( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result>; - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_left_count( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result>; /// Returns an iterator over all rooms this user joined. fn rooms_joined<'a>( @@ -78,7 +103,11 @@ pub(crate) trait Data: Send + Sync { fn rooms_invited<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a>; + ) -> Box< + dyn Iterator< + Item = Result<(OwnedRoomId, Vec>)>, + > + 'a, + >; fn invite_state( &self, @@ -97,7 +126,10 @@ pub(crate) trait Data: Send + Sync { fn rooms_left<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a>; + ) -> Box< + dyn Iterator>)>> + + 'a, + >; fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/state_compressor.rs b/src/service/rooms/state_compressor.rs index 13f0912a..6e5fedf0 100644 --- a/src/service/rooms/state_compressor.rs +++ b/src/service/rooms/state_compressor.rs @@ -9,9 +9,8 @@ pub(crate) use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use crate::{services, utils, Result}; - use self::data::StateDiff; +use crate::{services, utils, Result}; pub(crate) struct Service { pub(crate) db: &'static dyn Data, @@ -37,7 +36,8 @@ pub(crate) struct Service { pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; impl Service { - /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + /// Returns a stack with info on shortstatehash, full state, added diff and + /// removed diff for the selected shortstatehash and each parent layer. #[allow(clippy::type_complexity)] #[tracing::instrument(skip(self))] pub(crate) fn load_shortstatehash_info( @@ -55,11 +55,8 @@ impl Service { Arc>, )>, > { - if let Some(r) = self - .stateinfo_cache - .lock() - .unwrap() - .get_mut(&shortstatehash) + if let Some(r) = + self.stateinfo_cache.lock().unwrap().get_mut(&shortstatehash) { return Ok(r.clone()); } @@ -79,7 +76,12 @@ impl Service { state.remove(r); } - response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); + response.push(( + shortstatehash, + Arc::new(state), + added, + Arc::new(removed), + )); self.stateinfo_cache .lock() @@ -88,7 +90,8 @@ impl Service { Ok(response) } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; + let response = + vec![(shortstatehash, added.clone(), added, removed)]; self.stateinfo_cache .lock() .unwrap() @@ -132,19 +135,24 @@ impl Service { )) } - /// Creates a new shortstatehash that often is just a diff to an already existing - /// shortstatehash and therefore very efficient. + /// Creates a new shortstatehash that often is just a diff to an already + /// existing shortstatehash and therefore very efficient. /// - /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer - /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 - /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's - /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. + /// There are multiple layers of diffs. The bottom layer 0 always contains + /// the full state. Layer 1 contains diffs to states of layer 0, layer 2 + /// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be + /// combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively + /// fix above layers too. /// /// * `shortstatehash` - Shortstatehash of this state /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid - /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid - /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer - /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + /// * `statediffremoved` - Removed from base. Each vec is + /// shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time + /// for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, + /// added diff and removed diff for each parent layer #[allow(clippy::type_complexity)] #[tracing::instrument(skip( self, @@ -185,7 +193,8 @@ impl Service { // It was not added in the parent and we removed it parent_removed.insert(*removed); } - // Else it was added in the parent and we removed it again. We can forget this change + // Else it was added in the parent and we removed it again. We + // can forget this change } for new in statediffnew.iter() { @@ -193,7 +202,8 @@ impl Service { // It was not touched in the parent and we added it parent_new.insert(*new); } - // Else it was removed in the parent and we added it again. We can forget this change + // Else it was removed in the parent and we added it again. We + // can forget this change } self.save_state_from_diff( @@ -238,7 +248,8 @@ impl Service { // It was not added in the parent and we removed it parent_removed.insert(*removed); } - // Else it was added in the parent and we removed it again. We can forget this change + // Else it was added in the parent and we removed it again. We + // can forget this change } for new in statediffnew.iter() { @@ -246,7 +257,8 @@ impl Service { // It was not touched in the parent and we added it parent_new.insert(*new); } - // Else it was removed in the parent and we added it again. We can forget this change + // Else it was removed in the parent and we added it again. We + // can forget this change } self.save_state_from_diff( @@ -271,7 +283,8 @@ impl Service { Ok(()) } - /// Returns the new shortstatehash, and the state diff from the previous room state + /// Returns the new shortstatehash, and the state diff from the previous + /// room state #[allow(clippy::type_complexity)] pub(crate) fn save_state( &self, @@ -282,7 +295,8 @@ impl Service { Arc>, Arc>, )> { - let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = + services().rooms.state.get_room_shortstatehash(room_id)?; let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -291,10 +305,8 @@ impl Service { .collect::>(), ); - let (new_shortstatehash, already_existed) = services() - .rooms - .short - .get_or_create_shortstatehash(&state_hash)?; + let (new_shortstatehash, already_existed) = + services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok(( @@ -304,26 +316,28 @@ impl Service { )); } - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = previous_shortstatehash.map_or_else( + || Ok(Vec::new()), + |p| self.load_shortstatehash_info(p), + )?; - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() - { - let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: HashSet<_> = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&new_state_ids_compressed) - .copied() - .collect(); + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .copied() + .collect(); - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (new_state_ids_compressed, Arc::new(HashSet::new())) - }; + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (new_state_ids_compressed, Arc::new(HashSet::new())) + }; if !already_existed { self.save_state_from_diff( diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 3000b4ad..3d9ffc19 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -11,5 +11,9 @@ pub(crate) struct StateDiff { pub(crate) trait Data: Send + Sync { fn get_statediff(&self, shortstatehash: u64) -> Result; - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; + fn save_statediff( + &self, + shortstatehash: u64, + diff: StateDiff, + ) -> Result<()>; } diff --git a/src/service/rooms/threads.rs b/src/service/rooms/threads.rs index 99f28c8d..c2e9ed11 100644 --- a/src/service/rooms/threads.rs +++ b/src/service/rooms/threads.rs @@ -6,7 +6,6 @@ use ruma::{ events::relation::BundledThread, uint, CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, UserId, }; - use serde_json::json; use crate::{services, Error, PduEvent, Result}; @@ -26,51 +25,64 @@ impl Service { self.db.threads_until(user_id, room_id, until, include) } - pub(crate) fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { - let root_id = &services() - .rooms - .timeline - .get_pdu_id(root_event_id)? - .ok_or_else(|| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid event id in thread message", - ) - })?; + pub(crate) fn add_to_thread( + &self, + root_event_id: &EventId, + pdu: &PduEvent, + ) -> Result<()> { + let root_id = + &services().rooms.timeline.get_pdu_id(root_event_id)?.ok_or_else( + || { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid event id in thread message", + ) + }, + )?; - let root_pdu = services() - .rooms - .timeline - .get_pdu_from_id(root_id)? - .ok_or_else(|| { - Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found") - })?; + let root_pdu = + services().rooms.timeline.get_pdu_from_id(root_id)?.ok_or_else( + || { + Error::BadRequest( + ErrorKind::InvalidParam, + "Thread root pdu not found", + ) + }, + )?; let mut root_pdu_json = services() .rooms .timeline .get_pdu_json_from_id(root_id)? .ok_or_else(|| { - Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found") + Error::BadRequest( + ErrorKind::InvalidParam, + "Thread root pdu not found", + ) })?; - if let CanonicalJsonValue::Object(unsigned) = root_pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(CanonicalJsonObject::default())) + if let CanonicalJsonValue::Object(unsigned) = + root_pdu_json.entry("unsigned".to_owned()).or_insert_with(|| { + CanonicalJsonValue::Object(CanonicalJsonObject::default()) + }) { if let Some(mut relations) = unsigned .get("m.relations") .and_then(|r| r.as_object()) .and_then(|r| r.get("m.thread")) .and_then(|relations| { - serde_json::from_value::(relations.clone().into()).ok() + serde_json::from_value::( + relations.clone().into(), + ) + .ok() }) { // Thread already existed relations.count += uint!(1); relations.latest_event = pdu.to_message_like_event(); - let content = serde_json::to_value(relations).expect("to_value always works"); + let content = serde_json::to_value(relations) + .expect("to_value always works"); unsigned.insert( "m.relations".to_owned(), @@ -86,7 +98,8 @@ impl Service { current_user_participated: true, }; - let content = serde_json::to_value(relations).expect("to_value always works"); + let content = serde_json::to_value(relations) + .expect("to_value always works"); unsigned.insert( "m.relations".to_owned(), @@ -96,10 +109,11 @@ impl Service { ); } - services() - .rooms - .timeline - .replace_pdu(root_id, &root_pdu_json, &root_pdu)?; + services().rooms.timeline.replace_pdu( + root_id, + &root_pdu_json, + &root_pdu, + )?; } let mut users = Vec::new(); diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 42596e9b..8a1607db 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,5 +1,9 @@ +use ruma::{ + api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, + UserId, +}; + use crate::{PduEvent, Result}; -use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; pub(crate) trait Data: Send + Sync { #[allow(clippy::type_complexity)] @@ -11,6 +15,13 @@ pub(crate) trait Data: Send + Sync { include: &'a IncludeThreads, ) -> Result> + 'a>>; - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; - fn get_participants(&self, root_id: &[u8]) -> Result>>; + fn update_participants( + &self, + root_id: &[u8], + participants: &[OwnedUserId], + ) -> Result<()>; + fn get_participants( + &self, + root_id: &[u8], + ) -> Result>>; } diff --git a/src/service/rooms/timeline.rs b/src/service/rooms/timeline.rs index 42b802bc..1cc479fc 100644 --- a/src/service/rooms/timeline.rs +++ b/src/service/rooms/timeline.rs @@ -7,29 +7,31 @@ use std::{ }; pub(crate) use data::Data; - use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, events::{ push_rules::PushRulesEvent, room::{ - create::RoomCreateEventContent, encrypted::Relation, member::MembershipState, - power_levels::RoomPowerLevelsEventContent, redaction::RoomRedactionEventContent, + create::RoomCreateEventContent, encrypted::Relation, + member::MembershipState, power_levels::RoomPowerLevelsEventContent, + redaction::RoomRedactionEventContent, }, GlobalAccountDataEventType, StateEventType, TimelineEventType, }, push::{Action, Ruleset, Tweak}, serde::Base64, state_res::{self, Event, RoomVersion}, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, - OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, + uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, + OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, + ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::{Mutex, MutexGuard, RwLock}; use tracing::{error, info, warn}; +use super::state_compressor::CompressedStateEvent; use crate::{ api::server_server, service::{ @@ -39,8 +41,6 @@ use crate::{ services, utils, Error, PduEvent, Result, }; -use super::state_compressor::CompressedStateEvent; - #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] pub(crate) enum PduCount { Backfilled(u64), @@ -48,8 +48,8 @@ pub(crate) enum PduCount { } impl PduCount { - pub(crate) const MIN: Self = Self::Backfilled(u64::MAX); pub(crate) const MAX: Self = Self::Normal(u64::MAX); + pub(crate) const MIN: Self = Self::Backfilled(u64::MAX); pub(crate) fn try_from_string(token: &str) -> Result { if let Some(stripped) = token.strip_prefix('-') { @@ -57,7 +57,12 @@ impl PduCount { } else { token.parse().map(PduCount::Normal) } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid pagination token.", + ) + }) } pub(crate) fn stringify(&self) -> String { @@ -93,7 +98,10 @@ pub(crate) struct Service { impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + pub(crate) fn first_pdu_in_room( + &self, + room_id: &RoomId, + ) -> Result>> { self.all_pdus(user_id!("@doesntmatter:grapevine"), room_id)? .next() .map(|o| o.map(|(_, p)| Arc::new(p))) @@ -110,12 +118,18 @@ impl Service { } /// Returns the `count` of this pdu's id. - pub(crate) fn get_pdu_count(&self, event_id: &EventId) -> Result> { + pub(crate) fn get_pdu_count( + &self, + event_id: &EventId, + ) -> Result> { self.db.get_pdu_count(event_id) } /// Returns the json of a pdu. - pub(crate) fn get_pdu_json(&self, event_id: &EventId) -> Result> { + pub(crate) fn get_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { self.db.get_pdu_json(event_id) } @@ -128,21 +142,30 @@ impl Service { } /// Returns the pdu's id. - pub(crate) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + pub(crate) fn get_pdu_id( + &self, + event_id: &EventId, + ) -> Result>> { self.db.get_pdu_id(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(crate) fn get_pdu(&self, event_id: &EventId) -> Result>> { + pub(crate) fn get_pdu( + &self, + event_id: &EventId, + ) -> Result>> { self.db.get_pdu(event_id) } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(crate) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + pub(crate) fn get_pdu_from_id( + &self, + pdu_id: &[u8], + ) -> Result> { self.db.get_pdu_from_id(pdu_id) } @@ -167,8 +190,8 @@ impl Service { /// Creates a new persisted data unit and adds it to a room. /// - /// By this point the incoming event should be fully authenticated, no auth happens - /// in `append_pdu`. + /// By this point the incoming event should be fully authenticated, no auth + /// happens in `append_pdu`. /// /// Returns pdu id #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] @@ -186,13 +209,15 @@ impl Service { .get_shortroomid(&pdu.room_id)? .expect("room exists"); - // Make unsigned fields correct. This is not properly documented in the spec, but state - // events need to have previous content in the unsigned field, so clients can easily - // interpret things like membership changes + // Make unsigned fields correct. This is not properly documented in the + // spec, but state events need to have previous content in the + // unsigned field, so clients can easily interpret things like + // membership changes if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(CanonicalJsonObject::default())) + if let CanonicalJsonValue::Object(unsigned) = + pdu_json.entry("unsigned".to_owned()).or_insert_with(|| { + CanonicalJsonValue::Object(CanonicalJsonObject::default()) + }) { if let Some(shortstatehash) = services() .rooms @@ -203,14 +228,20 @@ impl Service { if let Some(prev_state) = services() .rooms .state_accessor - .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) + .state_get( + shortstatehash, + &pdu.kind.to_string().into(), + state_key, + ) .unwrap() { unsigned.insert( "prev_content".to_owned(), CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content.clone()) - .expect("event is valid, we just created it"), + utils::to_canonical_object( + prev_state.content.clone(), + ) + .expect("event is valid, we just created it"), ), ); } @@ -225,10 +256,11 @@ impl Service { .rooms .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms - .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + services().rooms.state.set_forward_extremities( + &pdu.room_id, + leaves, + state_lock, + )?; let mutex_insert = Arc::clone( services() @@ -242,13 +274,13 @@ impl Service { let insert_lock = mutex_insert.lock().await; let count1 = services().globals.next_count()?; - // Mark as read first so the sending client doesn't get a notification even if appending - // fails - services() - .rooms - .edus - .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + // Mark as read first so the sending client doesn't get a notification + // even if appending fails + services().rooms.edus.read_receipt.private_read_set( + &pdu.room_id, + &pdu.sender, + count1, + )?; services() .rooms .user @@ -269,8 +301,9 @@ impl Service { .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + serde_json::from_str(ev.content.get()).map_err(|_| { + Error::bad_database("invalid m.room.power_levels event") + }) }) .transpose()? .unwrap_or_default(); @@ -280,10 +313,8 @@ impl Service { let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = services() - .rooms - .state_cache - .get_our_real_users(&pdu.room_id)?; + let mut push_target = + services().rooms.state_cache.get_our_real_users(&pdu.room_id)?; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { @@ -312,8 +343,13 @@ impl Service { GlobalAccountDataEventType::PushRules.to_string().into(), )? .map(|event| { - serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid push rules event in db.")) + serde_json::from_str::(event.get()).map_err( + |_| { + Error::bad_database( + "Invalid push rules event in db.", + ) + }, + ) }) .transpose()? .map_or_else( @@ -353,12 +389,16 @@ impl Service { } } - self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + self.db.increment_notification_counts( + &pdu.room_id, + notifies, + highlights, + )?; match pdu.kind { TimelineEventType::RoomRedaction => { - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let room_version_id = + services().rooms.state.get_room_version(&pdu.room_id)?; match room_version_id { RoomVersionId::V1 | RoomVersionId::V2 @@ -383,10 +423,14 @@ impl Service { } RoomVersionId::V11 => { let content = - serde_json::from_str::(pdu.content.get()) - .map_err(|_| { - Error::bad_database("Invalid content in redaction pdu.") - })?; + serde_json::from_str::( + pdu.content.get(), + ) + .map_err(|_| { + Error::bad_database( + "Invalid content in redaction pdu.", + ) + })?; if let Some(redact_id) = &content.redacts { if services().rooms.state_accessor.user_can_redact( redact_id, @@ -398,7 +442,9 @@ impl Service { } } } - _ => unreachable!("Validity of room version already checked"), + _ => { + unreachable!("Validity of room version already checked") + } }; } TimelineEventType::SpaceChild => { @@ -423,19 +469,27 @@ impl Service { let target_user_id = UserId::parse(state_key.clone()) .expect("This state_key was previously validated"); - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + let content = serde_json::from_str::( + pdu.content.get(), + ) + .map_err(|_| { + Error::bad_database("Invalid content in pdu.") + })?; let invite_state = match content.membership { MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; + let state = services() + .rooms + .state + .calculate_invite_state(pdu)?; Some(state) } _ => None, }; - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth + // Update our membership info, we do this here incase a user + // is invited and immediately leaves we + // need the DB to record the invite event for auth services().rooms.state_cache.update_membership( &pdu.room_id, &target_user_id, @@ -452,14 +506,18 @@ impl Service { body: Option, } - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + let content = + serde_json::from_str::(pdu.content.get()) + .map_err(|_| { + Error::bad_database("Invalid content in pdu.") + })?; if let Some(body) = content.body { - services() - .rooms - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + services().rooms.search.index_pdu( + shortroomid, + &pdu_id, + &body, + )?; let server_user = format!( "@{}:{}", @@ -471,18 +529,25 @@ impl Service { services().globals.server_name() ); - let to_grapevine = body.starts_with(&format!("{server_user}: ")) + let to_grapevine = body + .starts_with(&format!("{server_user}: ")) || body.starts_with(&format!("{server_user} ")) || body == format!("{server_user}:") || body == server_user; - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as grapevine + // This will evaluate to false if the emergency password is + // set up so that the administrator can + // execute commands as grapevine let from_grapevine = pdu.sender == server_user && services().globals.emergency_password().is_none(); - if let Some(admin_room) = services().admin.get_admin_room()? { - if to_grapevine && !from_grapevine && admin_room == pdu.room_id { + if let Some(admin_room) = + services().admin.get_admin_room()? + { + if to_grapevine + && !from_grapevine + && admin_room == pdu.room_id + { services().admin.process_message(body); } } @@ -493,7 +558,9 @@ impl Service { // Update Relationships - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Ok(content) = + serde_json::from_str::(pdu.content.get()) + { if let Some(related_pducount) = services() .rooms .timeline @@ -506,9 +573,13 @@ impl Service { } } - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Ok(content) = + serde_json::from_str::(pdu.content.get()) + { match content.relates_to { - Relation::Reply { in_reply_to } => { + Relation::Reply { + in_reply_to, + } => { // We need to do it again here, because replies don't have // event_id as a top level field if let Some(related_pducount) = services() @@ -516,10 +587,10 @@ impl Service { .timeline .get_pdu_count(&in_reply_to.event_id)? { - services() - .rooms - .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + services().rooms.pdu_metadata.add_relation( + PduCount::Normal(count2), + related_pducount, + )?; } } Relation::Thread(thread) => { @@ -539,21 +610,24 @@ impl Service { .state_cache .appservice_in_room(&pdu.room_id, appservice)? { - services() - .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + services().sending.send_pdu_appservice( + appservice.registration.id.clone(), + pdu_id.clone(), + )?; continue; } - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. + // If the RoomMember event has a non-empty state_key, it is targeted + // at someone. If it is our appservice user, we send + // this PDU to it. if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + if let Some(state_key_uid) = + &pdu.state_key.as_ref().and_then(|state_key| { + UserId::parse(state_key.as_str()).ok() + }) { - let appservice_uid = appservice.registration.sender_localpart.as_str(); + let appservice_uid = + appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { services().sending.send_pdu_appservice( appservice.registration.id.clone(), @@ -567,10 +641,9 @@ impl Service { let matching_users = |users: &NamespaceRegex| { appservice.users.is_match(pdu.sender.as_str()) || pdu.kind == TimelineEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) + && pdu.state_key.as_ref().map_or(false, |state_key| { + users.is_match(state_key) + }) }; let matching_aliases = |aliases: &NamespaceRegex| { services() @@ -585,9 +658,10 @@ impl Service { || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { - services() - .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + services().sending.send_pdu_appservice( + appservice.registration.id.clone(), + pdu_id.clone(), + )?; } } @@ -620,13 +694,13 @@ impl Service { .collect(); // If there was no create event yet, assume we are creating a room - let room_version_id = services() - .rooms - .state - .get_room_version(room_id) - .or_else(|_| { + let room_version_id = + services().rooms.state.get_room_version(room_id).or_else(|_| { if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) + let content = + serde_json::from_str::( + content.get(), + ) .expect("Invalid content in RoomCreate pdu."); Ok(content.room_version) } else { @@ -637,7 +711,8 @@ impl Service { } })?; - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); + let room_version = RoomVersion::new(&room_version_id) + .expect("room version is supported"); let auth_events = services().rooms.state.get_auth_events( room_id, @@ -658,18 +733,22 @@ impl Service { let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { - if let Some(prev_pdu) = services().rooms.state_accessor.room_state_get( - room_id, - &event_type.to_string().into(), - state_key, - )? { + if let Some(prev_pdu) = + services().rooms.state_accessor.room_state_get( + room_id, + &event_type.to_string().into(), + state_key, + )? + { unsigned.insert( "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), + serde_json::from_str(prev_pdu.content.get()) + .expect("string is valid json"), ); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender) + .expect("UserId::to_value always works"), ); } } @@ -694,7 +773,9 @@ impl Service { unsigned: if unsigned.is_empty() { None } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) + Some( + to_raw_value(&unsigned).expect("to_raw_value always works"), + ) }, hashes: EventHash { sha256: "aaa".to_owned(), @@ -722,8 +803,8 @@ impl Service { } // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + let mut pdu_json = utils::to_canonical_object(&pdu) + .expect("event is valid, we just created it"); pdu_json.remove("event_id"); @@ -769,16 +850,15 @@ impl Service { ); // Generate short event id - let _shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(&pdu.event_id)?; + let _shorteventid = + services().rooms.short.get_or_create_shorteventid(&pdu.event_id)?; Ok((pdu, pdu_json)) } - /// Creates a new persisted data unit and adds it to a room. This function takes a - /// roomid_mutex_state, meaning that only this function is able to mutate the room state. + /// Creates a new persisted data unit and adds it to a room. This function + /// takes a roomid_mutex_state, meaning that only this function is able + /// to mutate the room state. #[tracing::instrument(skip(self, state_lock))] pub(crate) async fn build_and_append_pdu( &self, @@ -788,8 +868,12 @@ impl Service { // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, ) -> Result> { - let (pdu, pdu_json) = - self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; + let (pdu, pdu_json) = self.create_hash_and_sign_event( + pdu_builder, + sender, + room_id, + state_lock, + )?; if let Some(admin_room) = services().admin.get_admin_room()? { if admin_room == room_id { @@ -820,15 +904,24 @@ impl Service { "grapevine" }, ); - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + let content = + serde_json::from_str::( + pdu.content.get(), + ) + .map_err(|_| { + Error::bad_database("Invalid content in pdu.") + })?; if content.membership == MembershipState::Leave { if target == server_user { - warn!("Grapevine user cannot leave from admins room"); + warn!( + "Grapevine user cannot leave from admins \ + room" + ); return Err(Error::BadRequest( ErrorKind::Forbidden, - "Grapevine user cannot leave from admins room.", + "Grapevine user cannot leave from admins \ + room.", )); } @@ -841,7 +934,9 @@ impl Service { .filter(|m| m != target) .count(); if count < 2 { - warn!("Last admin cannot leave from admins room"); + warn!( + "Last admin cannot leave from admins room" + ); return Err(Error::BadRequest( ErrorKind::Forbidden, "Last admin cannot leave from admins room.", @@ -849,12 +944,18 @@ impl Service { } } - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if content.membership == MembershipState::Ban + && pdu.state_key().is_some() + { if target == server_user { - warn!("Grapevine user cannot be banned in admins room"); + warn!( + "Grapevine user cannot be banned in \ + admins room" + ); return Err(Error::BadRequest( ErrorKind::Forbidden, - "Grapevine user cannot be banned in admins room.", + "Grapevine user cannot be banned in \ + admins room.", )); } @@ -867,10 +968,14 @@ impl Service { .filter(|m| m != target) .count(); if count < 2 { - warn!("Last admin cannot be banned in admins room"); + warn!( + "Last admin cannot be banned in admins \ + room" + ); return Err(Error::BadRequest( ErrorKind::Forbidden, - "Last admin cannot be banned in admins room.", + "Last admin cannot be banned in admins \ + room.", )); } } @@ -880,7 +985,8 @@ impl Service { } } - // If redaction event is not authorized, do not append it to the timeline + // If redaction event is not authorized, do not append it to the + // timeline if pdu.kind == TimelineEventType::RoomRedaction { match services().rooms.state.get_room_version(&pdu.room_id)? { RoomVersionId::V1 @@ -908,11 +1014,12 @@ impl Service { }; } RoomVersionId::V11 => { - let content = - serde_json::from_str::(pdu.content.get()) - .map_err(|_| { - Error::bad_database("Invalid content in redaction pdu.") - })?; + let content = serde_json::from_str::< + RoomRedactionEventContent, + >(pdu.content.get()) + .map_err(|_| { + Error::bad_database("Invalid content in redaction pdu.") + })?; if let Some(redact_id) = &content.redacts { if !services().rooms.state_accessor.user_can_redact( @@ -937,27 +1044,30 @@ impl Service { } } - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. + // We append to state before appending the pdu, so we don't have a + // moment in time with the pdu without it's state. This is okay + // because append_pdu can't fail. let statehashid = services().rooms.state.append_to_state(&pdu)?; let pdu_id = self .append_pdu( &pdu, pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room + // Since this PDU references all pdu_leaves we can update the + // leaves of the room vec![(*pdu.event_id).to_owned()], state_lock, ) .await?; - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehashid, state_lock)?; + // We set the room state after inserting the pdu, so that we never have + // a moment in time where events in the current room state do + // not exist + services().rooms.state.set_room_state( + room_id, + statehashid, + state_lock, + )?; let mut servers: HashSet = services() .rooms @@ -966,7 +1076,8 @@ impl Service { .filter_map(Result::ok) .collect(); - // In case we are kicking or banning a user, we need to inform their server of the change + // In case we are kicking or banning a user, we need to inform their + // server of the change if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key_uid) = &pdu .state_key @@ -977,7 +1088,8 @@ impl Service { } } - // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above + // Remove our server from the server list since it will be added to it + // by room_servers() and/or the if statement above servers.remove(services().globals.server_name()); services().sending.send_pdu(servers.into_iter(), &pdu_id)?; @@ -985,8 +1097,8 @@ impl Service { Ok(pdu.event_id) } - /// Append the incoming event setting the state snapshot to the state from the - /// server that sent the event. + /// Append the incoming event setting the state snapshot to the state from + /// the server that sent the event. #[tracing::instrument(skip_all)] pub(crate) async fn append_incoming_pdu( &self, @@ -998,8 +1110,9 @@ impl Service { // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, ) -> Result>> { - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. + // We append to state before appending the pdu, so we don't have a + // moment in time with the pdu without it's state. This is okay + // because append_pdu can't fail. services().rooms.state.set_event_state( &pdu.event_id, &pdu.room_id, @@ -1037,8 +1150,9 @@ impl Service { self.pdus_after(user_id, room_id, PduCount::MIN) } - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. #[tracing::instrument(skip(self))] pub(crate) fn pdus_until<'a>( &'a self, @@ -1049,8 +1163,8 @@ impl Service { self.db.pdus_until(user_id, room_id, until) } - /// Returns an iterator over all events and their token in a room that happened after the event - /// with id `from` in chronological order. + /// Returns an iterator over all events and their token in a room that + /// happened after the event with id `from` in chronological order. #[tracing::instrument(skip(self))] pub(crate) fn pdus_after<'a>( &'a self, @@ -1063,13 +1177,18 @@ impl Service { /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub(crate) fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + pub(crate) fn redact_pdu( + &self, + event_id: &EventId, + reason: &PduEvent, + ) -> Result<()> { // TODO: Don't reserialize, keep original json if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let mut pdu = self.get_pdu_from_id(&pdu_id)?.ok_or_else(|| { + Error::bad_database("PDU ID points to invalid PDU.") + })?; + let room_version_id = + services().rooms.state.get_room_version(&pdu.room_id)?; pdu.redact(room_version_id, reason)?; self.replace_pdu( &pdu_id, @@ -1102,8 +1221,9 @@ impl Service { .state_accessor .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + serde_json::from_str(ev.content.get()).map_err(|_| { + Error::bad_database("invalid m.room.power_levels event") + }) }) .transpose()? .unwrap_or_default(); @@ -1133,7 +1253,9 @@ impl Service { Ok(response) => { let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await + if let Err(e) = self + .backfill_pdu(backfill_server, pdu, &pub_key_map) + .await { warn!("Failed to add backfilled pdu: {e}"); } @@ -1157,7 +1279,8 @@ impl Service { pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = + server_server::parse_incoming_pdu(&pdu)?; // Lock so we cannot backfill the same pdu twice at the same time let mutex = Arc::clone( @@ -1180,7 +1303,14 @@ impl Service { services() .rooms .event_handler - .handle_incoming_pdu(origin, &event_id, &room_id, value, false, pub_key_map) + .handle_incoming_pdu( + origin, + &event_id, + &room_id, + value, + false, + pub_key_map, + ) .await?; let value = self.get_pdu_json(&event_id)?.expect("We just created it"); @@ -1219,14 +1349,18 @@ impl Service { body: Option, } - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + let content = + serde_json::from_str::(pdu.content.get()) + .map_err(|_| { + Error::bad_database("Invalid content in pdu.") + })?; if let Some(body) = content.body { - services() - .rooms - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + services().rooms.search.index_pdu( + shortroomid, + &pdu_id, + &body, + )?; } } drop(mutex_lock); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 025e3ef0..5ea5ae03 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -2,21 +2,30 @@ use std::sync::Arc; use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; +use super::PduCount; use crate::{PduEvent, Result}; -use super::PduCount; - pub(crate) trait Data: Send + Sync { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; + fn last_timeline_count( + &self, + sender_user: &UserId, + room_id: &RoomId, + ) -> Result; /// Returns the `count` of this pdu's id. fn get_pdu_count(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_pdu_json( + &self, + event_id: &EventId, + ) -> Result>; /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result>; /// Returns the pdu's id. fn get_pdu_id(&self, event_id: &EventId) -> Result>>; @@ -24,7 +33,10 @@ pub(crate) trait Data: Send + Sync { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; + fn get_non_outlier_pdu( + &self, + event_id: &EventId, + ) -> Result>; /// Returns the pdu. /// @@ -37,7 +49,10 @@ pub(crate) trait Data: Send + Sync { fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; + fn get_pdu_json_from_id( + &self, + pdu_id: &[u8], + ) -> Result>; /// Adds a new pdu to the timeline fn append_pdu( @@ -64,8 +79,9 @@ pub(crate) trait Data: Send + Sync { pdu: &PduEvent, ) -> Result<()>; - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. #[allow(clippy::type_complexity)] fn pdus_until<'a>( &'a self, @@ -74,8 +90,8 @@ pub(crate) trait Data: Send + Sync { until: PduCount, ) -> Result> + 'a>>; - /// Returns an iterator over all events in a room that happened after the event with id `from` - /// in chronological order. + /// Returns an iterator over all events in a room that happened after the + /// event with id `from` in chronological order. #[allow(clippy::type_complexity)] fn pdus_after<'a>( &'a self, diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 6b5f1cde..bfcb8f70 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,15 +1,32 @@ -use crate::Result; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn reset_notification_counts( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()>; - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn notification_count( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result; - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn highlight_count( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result; // Returns the count at which the last reset_notification_counts was called - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn last_notification_read( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result; fn associate_token_shortstatehash( &self, @@ -18,7 +35,11 @@ pub(crate) trait Data: Send + Sync { shortstatehash: u64, ) -> Result<()>; - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; + fn get_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + ) -> Result>; fn get_shared_rooms<'a>( &'a self, diff --git a/src/service/sending.rs b/src/service/sending.rs index 22216faf..1cef2ab2 100644 --- a/src/service/sending.rs +++ b/src/service/sending.rs @@ -1,7 +1,5 @@ mod data; -pub(crate) use data::Data; - use std::{ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, @@ -9,34 +7,29 @@ use std::{ time::{Duration, Instant}, }; -use crate::{ - api::{appservice_server, server_server}, - services, - utils::calculate_hash, - Config, Error, PduEvent, Result, -}; +use base64::{engine::general_purpose, Engine as _}; +pub(crate) use data::Data; use federation::transactions::send_transaction_message; use futures_util::{stream::FuturesUnordered, StreamExt}; - -use base64::{engine::general_purpose, Engine as _}; - use ruma::{ api::{ appservice::{self, Registration}, federation::{ self, transactions::edu::{ - DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, ReceiptMap, + DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, + ReceiptMap, }, }, OutgoingRequest, }, device_id, events::{ - push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, - GlobalAccountDataEventType, + push_rules::PushRulesEvent, receipt::ReceiptType, + AnySyncEphemeralRoomEvent, GlobalAccountDataEventType, }, - push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName, UInt, UserId, + push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, + ServerName, UInt, UserId, }; use tokio::{ select, @@ -44,6 +37,13 @@ use tokio::{ }; use tracing::{debug, error, warn}; +use crate::{ + api::{appservice_server, server_server}, + services, + utils::calculate_hash, + Config, Error, PduEvent, Result, +}; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) enum OutgoingKind { Appservice(String), @@ -64,7 +64,7 @@ impl OutgoingKind { OutgoingKind::Push(user, pushkey) => { let mut p = b"$".to_vec(); p.extend_from_slice(user.as_bytes()); - p.push(0xff); + p.push(0xFF); p.extend_from_slice(pushkey.as_bytes()); p } @@ -74,7 +74,7 @@ impl OutgoingKind { p } }; - prefix.push(0xff); + prefix.push(0xFF); prefix } @@ -93,8 +93,11 @@ pub(crate) struct Service { /// The state for a given state hash. pub(super) maximum_requests: Arc, - pub(crate) sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, - receiver: Mutex)>>, + pub(crate) sender: + mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, + receiver: Mutex< + mpsc::UnboundedReceiver<(OutgoingKind, SendingEventType, Vec)>, + >, } enum TransactionStatus { @@ -112,7 +115,9 @@ impl Service { db, sender, receiver: Mutex::new(receiver), - maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests.into())), + maximum_requests: Arc::new(Semaphore::new( + config.max_concurrent_requests.into(), + )), }) } @@ -129,15 +134,18 @@ impl Service { let mut futures = FuturesUnordered::new(); - let mut current_transaction_status = HashMap::::new(); + let mut current_transaction_status = + HashMap::::new(); // Retry requests we could not finish yet - let mut initial_transactions = HashMap::>::new(); + let mut initial_transactions = + HashMap::>::new(); - for (key, outgoing_kind, event) in self.db.active_requests().filter_map(Result::ok) { - let entry = initial_transactions - .entry(outgoing_kind.clone()) - .or_default(); + for (key, outgoing_kind, event) in + self.db.active_requests().filter_map(Result::ok) + { + let entry = + initial_transactions.entry(outgoing_kind.clone()).or_default(); if entry.len() > 30 { warn!( @@ -152,77 +160,90 @@ impl Service { } for (outgoing_kind, events) in initial_transactions { - current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running); + current_transaction_status + .insert(outgoing_kind.clone(), TransactionStatus::Running); futures.push(Self::handle_events(outgoing_kind.clone(), events)); } - let handle_futures = |response, - current_transaction_status: &mut HashMap<_, _>, - futures: &mut FuturesUnordered<_>| { - match response { - Ok(outgoing_kind) => { - self.db.delete_all_active_requests_for(&outgoing_kind)?; + let handle_futures = + |response, + current_transaction_status: &mut HashMap<_, _>, + futures: &mut FuturesUnordered<_>| { + match response { + Ok(outgoing_kind) => { + self.db + .delete_all_active_requests_for(&outgoing_kind)?; - // Find events that have been added since starting the - // last request - let new_events = self - .db - .queued_requests(&outgoing_kind) - .filter_map(Result::ok) - .take(30) - .collect::>(); + // Find events that have been added since starting the + // last request + let new_events = self + .db + .queued_requests(&outgoing_kind) + .filter_map(Result::ok) + .take(30) + .collect::>(); - if new_events.is_empty() { - current_transaction_status.remove(&outgoing_kind); - } else { - // Insert pdus we found - self.db.mark_as_active(&new_events)?; + if new_events.is_empty() { + current_transaction_status.remove(&outgoing_kind); + } else { + // Insert pdus we found + self.db.mark_as_active(&new_events)?; - futures.push(Self::handle_events( - outgoing_kind.clone(), - new_events.into_iter().map(|(event, _)| event).collect(), - )); + futures.push(Self::handle_events( + outgoing_kind.clone(), + new_events + .into_iter() + .map(|(event, _)| event) + .collect(), + )); + } } - } - Err((outgoing_kind, _)) => { - current_transaction_status - .entry(outgoing_kind) - .and_modify(|e| { - *e = match e { - TransactionStatus::Running => { - TransactionStatus::Failed(1, Instant::now()) - } - TransactionStatus::Retrying(n) => { - TransactionStatus::Failed(*n + 1, Instant::now()) - } - TransactionStatus::Failed(..) => { - error!( - "Request that was not even \ + Err((outgoing_kind, _)) => { + current_transaction_status + .entry(outgoing_kind) + .and_modify(|e| { + *e = match e { + TransactionStatus::Running => { + TransactionStatus::Failed( + 1, + Instant::now(), + ) + } + TransactionStatus::Retrying(n) => { + TransactionStatus::Failed( + *n + 1, + Instant::now(), + ) + } + TransactionStatus::Failed(..) => { + error!( + "Request that was not even \ running failed?!" - ); - return; + ); + return; + } } - } - }); - } + }); + } + }; + + Result::<_>::Ok(()) }; - Result::<_>::Ok(()) - }; - - let handle_receiver = |outgoing_kind, - event, - key, - current_transaction_status: &mut HashMap<_, _>, - futures: &mut FuturesUnordered<_>| { - if let Ok(Some(events)) = self.select_events( - &outgoing_kind, - vec![(event, key)], - current_transaction_status, - ) { - futures.push(Self::handle_events(outgoing_kind, events)); - } - }; + let handle_receiver = + |outgoing_kind, + event, + key, + current_transaction_status: &mut HashMap<_, _>, + futures: &mut FuturesUnordered<_>| { + if let Ok(Some(events)) = self.select_events( + &outgoing_kind, + vec![(event, key)], + current_transaction_status, + ) { + futures.push(Self::handle_events(outgoing_kind, events)); + } + }; loop { select! { @@ -244,13 +265,21 @@ impl Service { } } - #[tracing::instrument(skip(self, outgoing_kind, new_events, current_transaction_status))] + #[tracing::instrument(skip( + self, + outgoing_kind, + new_events, + current_transaction_status + ))] fn select_events( &self, outgoing_kind: &OutgoingKind, // Events we want to send: event and full key new_events: Vec<(SendingEventType, Vec)>, - current_transaction_status: &mut HashMap, + current_transaction_status: &mut HashMap< + OutgoingKind, + TransactionStatus, + >, ) -> Result>> { let mut retry = false; let mut allow = true; @@ -264,10 +293,14 @@ impl Service { allow = false; } TransactionStatus::Failed(tries, time) => { - // Fail if a request has failed recently (exponential backoff) - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + // Fail if a request has failed recently (exponential + // backoff) + let mut min_elapsed_duration = + Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) + { + min_elapsed_duration = + Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { @@ -302,8 +335,12 @@ impl Service { } if let OutgoingKind::Normal(server_name) = outgoing_kind { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { - events.extend(select_edus.into_iter().map(SendingEventType::Edu)); + if let Ok((select_edus, last_count)) = + self.select_edus(server_name) + { + events.extend( + select_edus.into_iter().map(SendingEventType::Edu), + ); self.db.set_latest_educount(server_name, last_count)?; } @@ -314,14 +351,19 @@ impl Service { } #[tracing::instrument(skip(self, server_name))] - pub(crate) fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + pub(crate) fn select_edus( + &self, + server_name: &ServerName, + ) -> Result<(Vec>, u64)> { // u64: count of last edu let since = self.db.get_latest_educount(server_name)?; let mut events = Vec::new(); let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - 'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { + 'outer: for room_id in + services().rooms.state_cache.server_rooms(server_name) + { let room_id = room_id?; // Look for device list updates in this room device_list_changes.extend( @@ -329,7 +371,10 @@ impl Service { .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok) - .filter(|user_id| user_id.server_name() == services().globals.server_name()), + .filter(|user_id| { + user_id.server_name() + == services().globals.server_name() + }), ); // Look for read receipts in this room @@ -349,44 +394,57 @@ impl Service { continue; } - let event: AnySyncEphemeralRoomEvent = - serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { - let mut read = BTreeMap::new(); + let event: AnySyncEphemeralRoomEvent = serde_json::from_str( + read_receipt.json().get(), + ) + .map_err(|_| { + Error::bad_database("Invalid edu event in read_receipts.") + })?; + let federation_event = + if let AnySyncEphemeralRoomEvent::Receipt(r) = event { + let mut read = BTreeMap::new(); - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); + let (event_id, mut receipt) = + r.content.0.into_iter().next().expect( + "we only use one event per read receipt", + ); + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect( + "our read receipts always have the user here", + ); - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); - let receipt_map = ReceiptMap { read }; + let receipt_map = ReceiptMap { + read, + }; - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.clone(), receipt_map); + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.clone(), receipt_map); - Edu::Receipt(ReceiptContent { receipts }) - } else { - Error::bad_database("Invalid event type in read_receipts"); - continue; - }; + Edu::Receipt(ReceiptContent { + receipts, + }) + } else { + Error::bad_database( + "Invalid event type in read_receipts", + ); + continue; + }; - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + events.push( + serde_json::to_vec(&federation_event) + .expect("json can be serialized"), + ); if events.len() >= 20 { break 'outer; @@ -407,7 +465,9 @@ impl Service { keys: None, }); - events.push(serde_json::to_vec(&edu).expect("json can be serialized")); + events.push( + serde_json::to_vec(&edu).expect("json can be serialized"), + ); } Ok((events, max_edu_count)) @@ -422,7 +482,8 @@ impl Service { ) -> Result<()> { let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); let event = SendingEventType::Pdu(pdu_id.to_owned()); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = + self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -446,15 +507,10 @@ impl Service { }) .collect::>(); let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), + &requests.iter().map(|(o, e)| (o, e.clone())).collect::>(), )?; for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { - self.sender - .send((outgoing_kind.clone(), event, key)) - .unwrap(); + self.sender.send((outgoing_kind.clone(), event, key)).unwrap(); } Ok(()) @@ -469,7 +525,8 @@ impl Service { ) -> Result<()> { let outgoing_kind = OutgoingKind::Normal(server.to_owned()); let event = SendingEventType::Edu(serialized); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = + self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -478,10 +535,15 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + pub(crate) fn send_pdu_appservice( + &self, + appservice_id: String, + pdu_id: Vec, + ) -> Result<()> { let outgoing_kind = OutgoingKind::Appservice(appservice_id); let event = SendingEventType::Pdu(pdu_id); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = + self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -493,8 +555,9 @@ impl Service { /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self))] pub(crate) fn cleanup_events(&self, appservice_id: String) -> Result<()> { - self.db - .delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; + self.db.delete_all_requests_for(&OutgoingKind::Appservice( + appservice_id, + ))?; Ok(()) } @@ -511,18 +574,24 @@ impl Service { for event in &events { match event { SendingEventType::Pdu(pdu_id) => { - pdu_jsons.push(services().rooms.timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Appservice] Event in servernameevent_data not found in db.", - ), - ) - })? - .to_room_event()); + pdu_jsons.push( + services() + .rooms + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Appservice] Event in \ + servernameevent_data not \ + found in db.", + ), + ) + })? + .to_room_event(), + ); } SendingEventType::Edu(_) => { // Appservices don't need EDUs (?) @@ -530,7 +599,8 @@ impl Service { } } - let permit = services().sending.maximum_requests.acquire().await; + let permit = + services().sending.maximum_requests.acquire().await; let response = match appservice_server::send_request( services() @@ -541,20 +611,24 @@ impl Service { ( kind.clone(), Error::bad_database( - "[Appservice] Could not load registration from db.", + "[Appservice] Could not load registration \ + from db.", ), ) })?, appservice::event::push_events::v1::Request { events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, - }) - .collect::>(), - ))) + txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode( + calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEventType::Edu(b) + | SendingEventType::Pdu(b) => &**b, + }) + .collect::>(), + ), + )) .into(), }, ) @@ -575,7 +649,8 @@ impl Service { match event { SendingEventType::Pdu(pdu_id) => { pdus.push( - services().rooms + services() + .rooms .timeline .get_pdu_from_id(pdu_id) .map_err(|e| (kind.clone(), e))? @@ -583,7 +658,9 @@ impl Service { ( kind.clone(), Error::bad_database( - "[Push] Event in servernamevent_datas not found in db.", + "[Push] Event in \ + servernamevent_datas not \ + found in db.", ), ) })?, @@ -595,10 +672,13 @@ impl Service { } for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) + // Redacted events are not notification targets (we don't + // send push for them) if let Some(unsigned) = &pdu.unsigned { if let Ok(unsigned) = - serde_json::from_str::(unsigned.get()) + serde_json::from_str::( + unsigned.get(), + ) { if unsigned.get("redacted_because").is_some() { continue; @@ -609,7 +689,15 @@ impl Service { let Some(pusher) = services() .pusher .get_pusher(userid, pushkey) - .map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? + .map_err(|e| { + ( + OutgoingKind::Push( + userid.clone(), + pushkey.clone(), + ), + e, + ) + })? else { continue; }; @@ -619,10 +707,15 @@ impl Service { .get( None, userid, - GlobalAccountDataEventType::PushRules.to_string().into(), + GlobalAccountDataEventType::PushRules + .to_string() + .into(), ) .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) + .and_then(|event| { + serde_json::from_str::(event.get()) + .ok() + }) .map_or_else( || push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global, @@ -636,11 +729,18 @@ impl Service { .try_into() .expect("notification count can't go that high"); - let permit = services().sending.maximum_requests.acquire().await; + let permit = + services().sending.maximum_requests.acquire().await; let _response = services() .pusher - .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) + .send_push_notice( + userid, + unread, + &pusher, + rules_for_user, + &pdu, + ) .await .map(|_response| kind.clone()) .map_err(|e| (kind.clone(), e)); @@ -656,22 +756,39 @@ impl Service { for event in &events { match event { SendingEventType::Pdu(pdu_id) => { - // TODO: check room version and remove event_id if needed - let raw = PduEvent::convert_to_outgoing_federation_event( - services().rooms - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? - .ok_or_else(|| { - error!("event not found: {server} {pdu_id:?}"); - ( - OutgoingKind::Normal(server.clone()), - Error::bad_database( - "[Normal] Event in servernamevent_datas not found in db.", - ), - ) - })?, - ); + // TODO: check room version and remove event_id if + // needed + let raw = + PduEvent::convert_to_outgoing_federation_event( + services() + .rooms + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| { + ( + OutgoingKind::Normal( + server.clone(), + ), + e, + ) + })? + .ok_or_else(|| { + error!( + "event not found: {server} \ + {pdu_id:?}" + ); + ( + OutgoingKind::Normal( + server.clone(), + ), + Error::bad_database( + "[Normal] Event in \ + servernamevent_datas not \ + found in db.", + ), + ) + })?, + ); pdu_jsons.push(raw); } SendingEventType::Edu(edu) => { @@ -682,7 +799,8 @@ impl Service { } } - let permit = services().sending.maximum_requests.acquire().await; + let permit = + services().sending.maximum_requests.acquire().await; let response = server_server::send_request( server, @@ -691,16 +809,16 @@ impl Service { pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode( - calculate_hash( + transaction_id: (&*general_purpose::URL_SAFE_NO_PAD + .encode(calculate_hash( &events .iter() .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + SendingEventType::Edu(b) + | SendingEventType::Pdu(b) => &**b, }) .collect::>(), - ), - )) + ))) .into(), }, ) @@ -750,7 +868,8 @@ impl Service { /// Sends a request to an appservice /// - /// Only returns None if there is no url specified in the appservice registration file + /// Only returns None if there is no url specified in the appservice + /// registration file #[tracing::instrument(skip(self, registration, request))] pub(crate) async fn send_appservice_request( &self, @@ -761,7 +880,8 @@ impl Service { T: Debug, { let permit = self.maximum_requests.acquire().await; - let response = appservice_server::send_request(registration, request).await; + let response = + appservice_server::send_request(registration, request).await; drop(permit); response diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index f2351534..b250c07c 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,21 +1,29 @@ use ruma::ServerName; -use crate::Result; - use super::{OutgoingKind, SendingEventType}; +use crate::Result; pub(crate) trait Data: Send + Sync { #[allow(clippy::type_complexity)] fn active_requests<'a>( &'a self, - ) -> Box, OutgoingKind, SendingEventType)>> + 'a>; + ) -> Box< + dyn Iterator, OutgoingKind, SendingEventType)>> + + 'a, + >; fn active_requests_for<'a>( &'a self, outgoing_kind: &OutgoingKind, ) -> Box, SendingEventType)>> + 'a>; fn delete_active_request(&self, key: Vec) -> Result<()>; - fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; - fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn delete_all_active_requests_for( + &self, + outgoing_kind: &OutgoingKind, + ) -> Result<()>; + fn delete_all_requests_for( + &self, + outgoing_kind: &OutgoingKind, + ) -> Result<()>; fn queue_requests( &self, requests: &[(&OutgoingKind, SendingEventType)], @@ -24,7 +32,14 @@ pub(crate) trait Data: Send + Sync { &'a self, outgoing_kind: &OutgoingKind, ) -> Box)>> + 'a>; - fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()>; - fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; + fn mark_as_active( + &self, + events: &[(SendingEventType, Vec)], + ) -> Result<()>; + fn set_latest_educount( + &self, + server_name: &ServerName, + educount: u64, + ) -> Result<()>; fn get_latest_educount(&self, server_name: &ServerName) -> Result; } diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 6964abcd..e7012956 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,6 +1,7 @@ -use crate::Result; use ruma::{DeviceId, TransactionId, UserId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn add_txnid( &self, diff --git a/src/service/uiaa.rs b/src/service/uiaa.rs index 1cb026d9..61e4e252 100644 --- a/src/service/uiaa.rs +++ b/src/service/uiaa.rs @@ -1,7 +1,6 @@ mod data; pub(crate) use data::Data; - use ruma::{ api::client::{ error::ErrorKind, @@ -11,7 +10,9 @@ use ruma::{ }; use tracing::error; -use crate::{api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result}; +use crate::{ + api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result, +}; pub(crate) struct Service { pub(crate) db: &'static dyn Data, @@ -29,7 +30,8 @@ impl Service { self.db.set_uiaa_request( user_id, device_id, - // TODO: better session error handling (why is it optional in ruma?) + // TODO: better session error handling (why is it optional in + // ruma?) uiaainfo.session.as_ref().expect("session should be set"), json_body, )?; @@ -64,7 +66,8 @@ impl Service { password, .. }) => { - let UserIdentifier::UserIdOrLocalpart(username) = identifier else { + let UserIdentifier::UserIdOrLocalpart(username) = identifier + else { return Err(Error::BadRequest( ErrorKind::Unrecognized, "Identifier type not recognized.", @@ -75,18 +78,26 @@ impl Service { username.clone(), services().globals.server_name(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "User ID is invalid.", + ) + })?; // Check if password is correct if let Some(hash) = services().users.password_hash(&user_id)? { let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); + argon2::verify_encoded(&hash, password.as_bytes()) + .unwrap_or(false); if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); + uiaainfo.auth_error = + Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password." + .to_owned(), + }); return Ok((false, uiaainfo)); } } @@ -95,13 +106,16 @@ impl Service { uiaainfo.completed.push(AuthType::Password); } AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { + if Some(t.token.trim()) + == services().globals.config.registration_token.as_deref() + { uiaainfo.completed.push(AuthType::RegistrationToken); } else { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid registration token.".to_owned(), - }); + uiaainfo.auth_error = + Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid registration token.".to_owned(), + }); return Ok((false, uiaainfo)); } } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index 9a63fbb8..c8f41793 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,6 +1,7 @@ -use crate::Result; use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId}; +use crate::Result; + pub(crate) trait Data: Send + Sync { fn set_uiaa_request( &self, diff --git a/src/service/users.rs b/src/service/users.rs index b69c98b0..5fd8cfbd 100644 --- a/src/service/users.rs +++ b/src/service/users.rs @@ -19,8 +19,8 @@ use ruma::{ encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::AnyToDeviceEvent, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, - OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId, }; use crate::{services, Error, Result}; @@ -36,8 +36,12 @@ pub(crate) struct SlidingSyncCache { pub(crate) struct Service { pub(crate) db: &'static dyn Data, #[allow(clippy::type_complexity)] - pub(crate) connections: - Mutex>>>, + pub(crate) connections: Mutex< + BTreeMap< + (OwnedUserId, OwnedDeviceId, String), + Arc>, + >, + >, } impl Service { @@ -52,10 +56,7 @@ impl Service { device_id: OwnedDeviceId, conn_id: String, ) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); + self.connections.lock().unwrap().remove(&(user_id, device_id, conn_id)); } #[allow(clippy::too_many_lines)] @@ -71,16 +72,14 @@ impl Service { let mut cache = self.connections.lock().unwrap(); let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), + cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), ); let cached = &mut cached.lock().unwrap(); drop(cache); @@ -104,19 +103,22 @@ impl Service { .or(cached_list.include_old_rooms.clone()); match (&mut list.filters, cached_list.filters.clone()) { (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + list_filters.is_dm = + list_filters.is_dm.or(cached_filters.is_dm); if list_filters.spaces.is_empty() { list_filters.spaces = cached_filters.spaces; } - list_filters.is_encrypted = - list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_encrypted = list_filters + .is_encrypted + .or(cached_filters.is_encrypted); list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); if list_filters.room_types.is_empty() { list_filters.room_types = cached_filters.room_types; } if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; + list_filters.not_room_types = + cached_filters.not_room_types; } list_filters.room_name_like = list_filters .room_name_like @@ -129,12 +131,17 @@ impl Service { list_filters.not_tags = cached_filters.not_tags; } } - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (_, _) => {} + (_, Some(cached_filters)) => { + list.filters = Some(cached_filters); + } + (Some(list_filters), _) => { + list.filters = Some(list_filters.clone()); + } + (..) => {} } if list.bump_event_types.is_empty() { - list.bump_event_types = cached_list.bump_event_types.clone(); + list.bump_event_types = + cached_list.bump_event_types.clone(); }; } cached.lists.insert(list_id.clone(), list.clone()); @@ -147,17 +154,11 @@ impl Service { .map(|(k, v)| (k.clone(), v.clone())), ); request.room_subscriptions.extend( - cached - .subscriptions - .iter() - .map(|(k, v)| (k.clone(), v.clone())), + cached.subscriptions.iter().map(|(k, v)| (k.clone(), v.clone())), ); - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); + request.extensions.e2ee.enabled = + request.extensions.e2ee.enabled.or(cached.extensions.e2ee.enabled); request.extensions.to_device.enabled = request .extensions @@ -197,16 +198,14 @@ impl Service { ) { let mut cache = self.connections.lock().unwrap(); let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), + cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), ); let cached = &mut cached.lock().unwrap(); drop(cache); @@ -225,25 +224,20 @@ impl Service { ) { let mut cache = self.connections.lock().unwrap(); let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), + cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), ); let cached = &mut cached.lock().unwrap(); drop(cache); - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() + for (roomid, lastsince) in + cached.known_rooms.entry(list_id.clone()).or_default().iter_mut() { if !new_cached_rooms.contains(roomid) { *lastsince = 0; @@ -264,23 +258,28 @@ impl Service { // Allowed because this function uses `services()` #[allow(clippy::unused_self)] pub(crate) fn is_admin(&self, user_id: &UserId) -> Result { - let admin_room_alias_id = - RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_alias_id = RoomAliasId::parse(format!( + "#admins:{}", + services().globals.server_name() + )) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") + })?; let admin_room_id = services() .rooms .alias .resolve_local_alias(&admin_room_alias_id)? .unwrap(); - services() - .rooms - .state_cache - .is_joined(user_id, &admin_room_id) + services().rooms.state_cache.is_joined(user_id, &admin_room_id) } /// Create a new user account on this homeserver. - pub(crate) fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub(crate) fn create( + &self, + user_id: &UserId, + password: Option<&str>, + ) -> Result<()> { self.db.set_password(user_id, password)?; Ok(()) } @@ -291,38 +290,55 @@ impl Service { } /// Find out which user an access token belongs to. - pub(crate) fn find_from_token(&self, token: &str) -> Result> { + pub(crate) fn find_from_token( + &self, + token: &str, + ) -> Result> { self.db.find_from_token(token) } /// Returns an iterator over all users on this homeserver. - pub(crate) fn iter(&self) -> impl Iterator> + '_ { + pub(crate) fn iter( + &self, + ) -> impl Iterator> + '_ { self.db.iter() } /// Returns a list of local users as list of usernames. /// - /// A user account is considered `local` if the length of it's password is greater then zero. + /// A user account is considered `local` if the length of it's password is + /// greater then zero. pub(crate) fn list_local_users(&self) -> Result> { self.db.list_local_users() } /// Returns the password hash for the given user. - pub(crate) fn password_hash(&self, user_id: &UserId) -> Result> { + pub(crate) fn password_hash( + &self, + user_id: &UserId, + ) -> Result> { self.db.password_hash(user_id) } /// Hash and set the user's password to the Argon2 hash - pub(crate) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub(crate) fn set_password( + &self, + user_id: &UserId, + password: Option<&str>, + ) -> Result<()> { self.db.set_password(user_id, password) } /// Returns the displayname of a user on this homeserver. - pub(crate) fn displayname(&self, user_id: &UserId) -> Result> { + pub(crate) fn displayname( + &self, + user_id: &UserId, + ) -> Result> { self.db.displayname(user_id) } - /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + /// Sets a new displayname or removes it if displayname is None. You still + /// need to nofify all rooms of this change. pub(crate) fn set_displayname( &self, user_id: &UserId, @@ -332,7 +348,10 @@ impl Service { } /// Get the `avatar_url` of a user. - pub(crate) fn avatar_url(&self, user_id: &UserId) -> Result> { + pub(crate) fn avatar_url( + &self, + user_id: &UserId, + ) -> Result> { self.db.avatar_url(user_id) } @@ -351,7 +370,11 @@ impl Service { } /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. - pub(crate) fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + pub(crate) fn set_blurhash( + &self, + user_id: &UserId, + blurhash: Option, + ) -> Result<()> { self.db.set_blurhash(user_id, blurhash) } @@ -363,12 +386,20 @@ impl Service { token: &str, initial_device_display_name: Option, ) -> Result<()> { - self.db - .create_device(user_id, device_id, token, initial_device_display_name) + self.db.create_device( + user_id, + device_id, + token, + initial_device_display_name, + ) } /// Removes a device from a user. - pub(crate) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub(crate) fn remove_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<()> { self.db.remove_device(user_id, device_id) } @@ -397,8 +428,12 @@ impl Service { one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, ) -> Result<()> { - self.db - .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + self.db.add_one_time_key( + user_id, + device_id, + one_time_key_key, + one_time_key_value, + ) } pub(crate) fn take_one_time_key( @@ -463,7 +498,10 @@ impl Service { self.db.keys_changed(user_or_room_id, from, to) } - pub(crate) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + pub(crate) fn mark_device_key_update( + &self, + user_id: &UserId, + ) -> Result<()> { self.db.mark_device_key_update(user_id) } @@ -490,8 +528,7 @@ impl Service { user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { - self.db - .get_key(key, sender_user, user_id, allowed_signatures) + self.db.get_key(key, sender_user, user_id, allowed_signatures) } pub(crate) fn get_master_key( @@ -500,8 +537,7 @@ impl Service { user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { - self.db - .get_master_key(sender_user, user_id, allowed_signatures) + self.db.get_master_key(sender_user, user_id, allowed_signatures) } pub(crate) fn get_self_signing_key( @@ -510,8 +546,7 @@ impl Service { user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { - self.db - .get_self_signing_key(sender_user, user_id, allowed_signatures) + self.db.get_self_signing_key(sender_user, user_id, allowed_signatures) } pub(crate) fn get_user_signing_key( @@ -573,7 +608,10 @@ impl Service { self.db.get_device_metadata(user_id, device_id) } - pub(crate) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + pub(crate) fn get_devicelist_version( + &self, + user_id: &UserId, + ) -> Result> { self.db.get_devicelist_version(user_id) } @@ -591,9 +629,10 @@ impl Service { self.remove_device(user_id, &device_id?)?; } - // Set the password to "" to indicate a deactivated account. Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. + // Set the password to "" to indicate a deactivated account. Hashes will + // never result in an empty string, so the user will not be able + // to log in again. Systems like changing the password without + // logging in should check if the account is deactivated. self.db.set_password(user_id, None)?; // TODO: Unhook 3PID @@ -625,19 +664,23 @@ pub(crate) fn clean_signatures bool>( user_id: &UserId, allowed_signatures: F, ) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) + if let Some(signatures) = + cross_signing_key.get_mut("signatures").and_then(|v| v.as_object_mut()) { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped + // Don't allocate for the full size of the current signatures, but + // require at most one resize if nothing is dropped let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let sid = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { + for (user, signature) in mem::replace( + signatures, + serde_json::Map::with_capacity(new_capacity), + ) { + let sid = <&UserId>::try_from(user.as_str()).map_err(|_| { + Error::bad_database("Invalid user ID in database.") + })?; + if sender_user == Some(user_id) + || sid == user_id + || allowed_signatures(sid) + { signatures.insert(user, signature); } } diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 3a7af883..ca6442e3 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,13 +1,15 @@ -use crate::Result; +use std::collections::BTreeMap; + use ruma::{ api::client::{device::Device, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::AnyToDeviceEvent, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, - OwnedUserId, UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use std::collections::BTreeMap; + +use crate::Result; pub(crate) trait Data: Send + Sync { /// Check if a user has an account on this homeserver. @@ -20,39 +22,61 @@ pub(crate) trait Data: Send + Sync { fn count(&self) -> Result; /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result>; + fn find_from_token( + &self, + token: &str, + ) -> Result>; /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a>; + fn iter<'a>(&'a self) + -> Box> + 'a>; /// Returns a list of local users as list of usernames. /// - /// A user account is considered `local` if the length of it's password is greater then zero. + /// A user account is considered `local` if the length of it's password is + /// greater then zero. fn list_local_users(&self) -> Result>; /// Returns the password hash for the given user. fn password_hash(&self, user_id: &UserId) -> Result>; /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; + fn set_password( + &self, + user_id: &UserId, + password: Option<&str>, + ) -> Result<()>; /// Returns the displayname of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result>; - /// Sets a new `displayname` or removes it if `displayname` is `None`. You still need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; + /// Sets a new `displayname` or removes it if `displayname` is `None`. You + /// still need to nofify all rooms of this change. + fn set_displayname( + &self, + user_id: &UserId, + displayname: Option, + ) -> Result<()>; /// Get the `avatar_url` of a user. fn avatar_url(&self, user_id: &UserId) -> Result>; /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()>; + fn set_avatar_url( + &self, + user_id: &UserId, + avatar_url: Option, + ) -> Result<()>; /// Get the `blurhash` of a user. fn blurhash(&self, user_id: &UserId) -> Result>; /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; + fn set_blurhash( + &self, + user_id: &UserId, + blurhash: Option, + ) -> Result<()>; /// Adds a new device to a user. fn create_device( @@ -64,7 +88,11 @@ pub(crate) trait Data: Send + Sync { ) -> Result<()>; /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + fn remove_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<()>; /// Returns an iterator over all device ids of this user. fn all_device_ids<'a>( @@ -73,7 +101,12 @@ pub(crate) trait Data: Send + Sync { ) -> Box> + 'a>; /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; + fn set_token( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + ) -> Result<()>; fn add_one_time_key( &self, @@ -165,7 +198,10 @@ pub(crate) trait Data: Send + Sync { allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>>; - fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; + fn get_user_signing_key( + &self, + user_id: &UserId, + ) -> Result>>; fn add_to_device_event( &self, @@ -197,8 +233,11 @@ pub(crate) trait Data: Send + Sync { ) -> Result<()>; /// Get device metadata. - fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) - -> Result>; + fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>; fn get_devicelist_version(&self, user_id: &UserId) -> Result>; @@ -208,7 +247,15 @@ pub(crate) trait Data: Send + Sync { ) -> Box> + 'a>; /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result; + fn create_filter( + &self, + user_id: &UserId, + filter: &FilterDefinition, + ) -> Result; - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; + fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result>; } diff --git a/src/utils.rs b/src/utils.rs index 194f0eef..0b40acbb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,14 +1,17 @@ pub(crate) mod error; +use std::{ + cmp, + str::FromStr, + time::{SystemTime, UNIX_EPOCH}, +}; + use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; use ring::digest; -use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; -use std::{ - cmp, - str::FromStr, - time::{SystemTime, UNIX_EPOCH}, +use ruma::{ + canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, }; // Hopefully we have a better chat protocol in 530 years @@ -36,7 +39,7 @@ pub(crate) fn increment(old: Option<&[u8]>) -> Vec { pub(crate) fn generate_keypair() -> Vec { let mut value = random_string(8).as_bytes().to_vec(); - value.push(0xff); + value.push(0xFF); value.extend_from_slice( &ruma::signatures::Ed25519KeyPair::generate() .expect("Ed25519KeyPair generation always works (?)"), @@ -45,13 +48,17 @@ pub(crate) fn generate_keypair() -> Vec { } /// Parses the bytes into an u64. -pub(crate) fn u64_from_bytes(bytes: &[u8]) -> Result { +pub(crate) fn u64_from_bytes( + bytes: &[u8], +) -> Result { let array: [u8; 8] = bytes.try_into()?; Ok(u64::from_be_bytes(array)) } /// Parses the bytes into a string. -pub(crate) fn string_from_bytes(bytes: &[u8]) -> Result { +pub(crate) fn string_from_bytes( + bytes: &[u8], +) -> Result { String::from_utf8(bytes.to_vec()) } @@ -64,7 +71,9 @@ pub(crate) fn random_string(length: usize) -> String { } /// Calculate a new hash for the given password -pub(crate) fn calculate_password_hash(password: &str) -> Result { +pub(crate) fn calculate_password_hash( + password: &str, +) -> Result { let hashing_config = Config { variant: Variant::Argon2id, ..Default::default() @@ -77,7 +86,7 @@ pub(crate) fn calculate_password_hash(password: &str) -> Result Vec { // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xff); + let bytes = keys.join(&0xFF); let hash = digest::digest(&digest::SHA256, &bytes); hash.as_ref().to_owned() } @@ -92,7 +101,8 @@ where F: Fn(&[u8], &[u8]) -> Ordering, { let first_iterator = iterators.next()?; - let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); + let mut other_iterators = + iterators.map(Iterator::peekable).collect::>(); Some(first_iterator.filter(move |target| { other_iterators.iter_mut().all(|it| { @@ -113,7 +123,8 @@ where })) } -/// Fallible conversion from any value that implements `Serialize` to a `CanonicalJsonObject`. +/// Fallible conversion from any value that implements `Serialize` to a +/// `CanonicalJsonObject`. /// /// `value` must serialize to an `serde_json::Value::Object`. pub(crate) fn to_canonical_object( @@ -138,11 +149,18 @@ pub(crate) fn deserialize_from_str< deserializer: D, ) -> Result { struct Visitor, E>(std::marker::PhantomData); - impl, Err: std::fmt::Display> serde::de::Visitor<'_> for Visitor { + impl, Err: std::fmt::Display> serde::de::Visitor<'_> + for Visitor + { type Value = T; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { write!(formatter, "a parsable string") } + fn visit_str(self, v: &str) -> Result where E: serde::de::Error, diff --git a/src/utils/error.rs b/src/utils/error.rs index c0655de5..581975ff 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -19,13 +19,19 @@ pub(crate) type Result = std::result::Result; #[allow(clippy::error_impl_error)] pub(crate) enum Error { #[cfg(feature = "sqlite")] - #[error("There was a problem with the connection to the sqlite database: {source}")] + #[error( + "There was a problem with the connection to the sqlite database: \ + {source}" + )] Sqlite { #[from] source: rusqlite::Error, }, #[cfg(feature = "rocksdb")] - #[error("There was a problem with the connection to the rocksdb database: {source}")] + #[error( + "There was a problem with the connection to the rocksdb database: \ + {source}" + )] RocksDb { #[from] source: rocksdb::Error, @@ -91,9 +97,10 @@ impl Error { pub(crate) fn to_response(&self) -> RumaResponse { use ErrorKind::{ - Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, - ThreepidAuthFailed, ThreepidDenied, TooLarge, Unauthorized, Unknown, UnknownToken, - Unrecognized, UserDeactivated, WrongRoomKeysVersion, + Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, + NotFound, ThreepidAuthFailed, ThreepidDenied, TooLarge, + Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, + WrongRoomKeysVersion, }; if let Self::Uiaa(uiaainfo) = self { @@ -115,15 +122,23 @@ impl Error { Self::BadRequest(kind, _) => ( kind.clone(), match kind { - WrongRoomKeysVersion { .. } + WrongRoomKeysVersion { + .. + } | Forbidden | GuestAccessForbidden | ThreepidAuthFailed | UserDeactivated | ThreepidDenied => StatusCode::FORBIDDEN, - Unauthorized | UnknownToken { .. } | MissingToken => StatusCode::UNAUTHORIZED, + Unauthorized + | UnknownToken { + .. + } + | MissingToken => StatusCode::UNAUTHORIZED, NotFound | Unrecognized => StatusCode::NOT_FOUND, - LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, TooLarge => StatusCode::PAYLOAD_TOO_LARGE, _ => StatusCode::BAD_REQUEST, }, @@ -135,7 +150,10 @@ impl Error { info!("Returning an error: {}: {}", status_code, message); RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { kind, message }, + body: ErrorBody::Standard { + kind, + message, + }, status_code, })) } @@ -146,12 +164,22 @@ impl Error { match self { #[cfg(feature = "sqlite")] - Self::Sqlite { .. } => db_error, + Self::Sqlite { + .. + } => db_error, #[cfg(feature = "rocksdb")] - Self::RocksDb { .. } => db_error, - Self::Io { .. } => db_error, - Self::BadConfig { .. } => db_error, - Self::BadDatabase { .. } => db_error, + Self::RocksDb { + .. + } => db_error, + Self::Io { + .. + } => db_error, + Self::BadConfig { + .. + } => db_error, + Self::BadDatabase { + .. + } => db_error, _ => self.to_string(), } }