change rustfmt configuration

This change is fully automated, except the `rustfmt.toml` changes and
a few clippy directives to allow specific functions with too many lines
because they are longer now.
This commit is contained in:
Charles Hall 2024-05-16 01:19:04 -07:00
parent 40d6ce230d
commit 0afc1d2f50
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
123 changed files with 7881 additions and 4687 deletions

View file

@ -1,2 +1,17 @@
unstable_features = true edition = "2021"
imports_granularity="Crate"
condense_wildcard_suffixes = true
format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
format_strings = true
group_imports = "StdExternalCrate"
hex_literal_case = "Upper"
imports_granularity = "Crate"
max_width = 80
newline_style = "Unix"
reorder_impl_items = true
use_field_init_shorthand = true
use_small_heuristics = "Off"
use_try_shorthand = true
wrap_comments = true

View file

@ -1,14 +1,18 @@
use crate::{services, utils, Error, Result}; use std::{fmt::Debug, mem, time::Duration};
use bytes::BytesMut; use bytes::BytesMut;
use ruma::api::{ use ruma::api::{
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken,
}; };
use std::{fmt::Debug, mem, time::Duration};
use tracing::warn; use tracing::warn;
use crate::{services, utils, Error, Result};
/// Sends a request to an appservice /// Sends a request to an appservice
/// ///
/// Only returns None if there is no url specified in the appservice registration file /// Only returns None if there is no url specified in the appservice
/// registration file
#[tracing::instrument(skip(request))] #[tracing::instrument(skip(request))]
pub(crate) async fn send_request<T: OutgoingRequest>( pub(crate) async fn send_request<T: OutgoingRequest>(
registration: Registration, registration: Registration,
@ -45,7 +49,8 @@ where
.parse() .parse()
.unwrap(), .unwrap(),
); );
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); *http_request.uri_mut() =
parts.try_into().expect("our manipulation is always valid");
let mut reqwest_request = reqwest::Request::try_from(http_request)?; let mut reqwest_request = reqwest::Request::try_from(http_request)?;
@ -70,9 +75,8 @@ where
// reqwest::Response -> http::Response conversion // reqwest::Response -> http::Response conversion
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder() let mut http_response_builder =
.status(status) http::Response::builder().status(status).version(response.version());
.version(response.version());
mem::swap( mem::swap(
response.headers_mut(), response.headers_mut(),
http_response_builder http_response_builder

View file

@ -1,22 +1,25 @@
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use register::RegistrationKind;
use crate::{api::client_server, services, utils, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
account::{ account::{
change_password, deactivate, get_3pids, get_username_availability, change_password, deactivate, get_3pids, get_username_availability,
register::{self, LoginType}, register::{self, LoginType},
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, request_3pid_management_token_via_email,
whoami, ThirdPartyIdRemovalStatus, request_3pid_management_token_via_msisdn, whoami,
ThirdPartyIdRemovalStatus,
}, },
error::ErrorKind, error::ErrorKind,
uiaa::{AuthFlow, AuthType, UiaaInfo}, uiaa::{AuthFlow, AuthType, UiaaInfo},
}, },
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, events::{
room::message::RoomMessageEventContent, GlobalAccountDataEventType,
},
push, UserId, push, UserId,
}; };
use tracing::{info, warn}; use tracing::{info, warn};
use register::RegistrationKind; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};
const RANDOM_USER_ID_LENGTH: usize = 10; const RANDOM_USER_ID_LENGTH: usize = 10;
@ -29,7 +32,8 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
/// - The server name of the user id matches this server /// - The server name of the user id matches this server
/// - No user or appservice on this server already claimed this username /// - No user or appservice on this server already claimed this username
/// ///
/// Note: This will not reserve the username, so the username might become invalid when trying to register /// Note: This will not reserve the username, so the username might become
/// invalid when trying to register
pub(crate) async fn get_register_available_route( pub(crate) async fn get_register_available_route(
body: Ruma<get_username_availability::v3::Request>, body: Ruma<get_username_availability::v3::Request>,
) -> Result<get_username_availability::v3::Response> { ) -> Result<get_username_availability::v3::Response> {
@ -40,7 +44,8 @@ pub(crate) async fn get_register_available_route(
) )
.ok() .ok()
.filter(|user_id| { .filter(|user_id| {
!user_id.is_historical() && user_id.server_name() == services().globals.server_name() !user_id.is_historical()
&& user_id.server_name() == services().globals.server_name()
}) })
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidUsername, ErrorKind::InvalidUsername,
@ -58,27 +63,35 @@ pub(crate) async fn get_register_available_route(
// TODO add check for appservice namespaces // TODO add check for appservice namespaces
// If no if check is true we have an username that's available to be used. // If no if check is true we have an username that's available to be used.
Ok(get_username_availability::v3::Response { available: true }) Ok(get_username_availability::v3::Response {
available: true,
})
} }
/// # `POST /_matrix/client/r0/register` /// # `POST /_matrix/client/r0/register`
/// ///
/// Register an account on this homeserver. /// Register an account on this homeserver.
/// ///
/// You can use [`GET /_matrix/client/r0/register/available`](get_register_available_route) /// You can use [`GET
/// /_matrix/client/r0/register/available`](get_register_available_route)
/// to check if the user id is valid and available. /// to check if the user id is valid and available.
/// ///
/// - Only works if registration is enabled /// - Only works if registration is enabled
/// - If type is guest: ignores all parameters except `initial_device_display_name` /// - If type is guest: ignores all parameters except
/// `initial_device_display_name`
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage) /// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
/// - If type is not guest and no username is given: Always fails after UIAA check /// - If type is not guest and no username is given: Always fails after UIAA
/// check
/// - Creates a new account and populates it with default account data /// - Creates a new account and populates it with default account data
/// - If `inhibit_login` is false: Creates a device and returns `device_id` and `access_token` /// - If `inhibit_login` is false: Creates a device and returns `device_id` and
/// `access_token`
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn register_route( pub(crate) async fn register_route(
body: Ruma<register::v3::Request>, body: Ruma<register::v3::Request>,
) -> Result<register::v3::Response> { ) -> Result<register::v3::Response> {
if !services().globals.allow_registration() && body.appservice_info.is_none() { if !services().globals.allow_registration()
&& body.appservice_info.is_none()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Registration has been disabled.", "Registration has been disabled.",
@ -158,7 +171,8 @@ pub(crate) async fn register_route(
}; };
body.appservice_info.is_some() body.appservice_info.is_some()
} else { } else {
// No registration token necessary, but clients must still go through the flow // No registration token necessary, but clients must still go through
// the flow
uiaainfo = UiaaInfo { uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
stages: vec![AuthType::Dummy], stages: vec![AuthType::Dummy],
@ -174,8 +188,11 @@ pub(crate) async fn register_route(
if !skip_auth { if !skip_auth {
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
&UserId::parse_with_server_name("", services().globals.server_name()) &UserId::parse_with_server_name(
.expect("we know this is valid"), "",
services().globals.server_name(),
)
.expect("we know this is valid"),
"".into(), "".into(),
auth, auth,
&uiaainfo, &uiaainfo,
@ -187,8 +204,11 @@ pub(crate) async fn register_route(
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create( services().uiaa.create(
&UserId::parse_with_server_name("", services().globals.server_name()) &UserId::parse_with_server_name(
.expect("we know this is valid"), "",
services().globals.server_name(),
)
.expect("we know this is valid"),
"".into(), "".into(),
&uiaainfo, &uiaainfo,
&json, &json,
@ -211,9 +231,7 @@ pub(crate) async fn register_route(
// Default to pretty displayname // Default to pretty displayname
let displayname = user_id.localpart().to_owned(); let displayname = user_id.localpart().to_owned();
services() services().users.set_displayname(&user_id, Some(displayname.clone()))?;
.users
.set_displayname(&user_id, Some(displayname.clone()))?;
// Initial account data // Initial account data
services().account_data.update( services().account_data.update(
@ -260,29 +278,24 @@ pub(crate) async fn register_route(
info!("New user {} registered on this server.", user_id); info!("New user {} registered on this server.", user_id);
if body.appservice_info.is_none() && !is_guest { if body.appservice_info.is_none() && !is_guest {
services() services().admin.send_message(RoomMessageEventContent::notice_plain(
.admin format!("New user {user_id} registered on this server."),
.send_message(RoomMessageEventContent::notice_plain(format!( ));
"New user {user_id} registered on this server."
)));
} }
// If this is the first real user, grant them admin privileges // If this is the first real user, grant them admin privileges
// Note: the server user, @grapevine:servername, is generated first // Note: the server user, @grapevine:servername, is generated first
if !is_guest { if !is_guest {
if let Some(admin_room) = services().admin.get_admin_room()? { if let Some(admin_room) = services().admin.get_admin_room()? {
if services() if services().rooms.state_cache.room_joined_count(&admin_room)?
.rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1) == Some(1)
{ {
services() services().admin.make_user_admin(&user_id, displayname).await?;
.admin
.make_user_admin(&user_id, displayname)
.await?;
warn!("Granting {} admin privileges as the first user", user_id); warn!(
"Granting {} admin privileges as the first user",
user_id
);
} }
} }
} }
@ -302,19 +315,23 @@ pub(crate) async fn register_route(
/// ///
/// - Requires UIAA to verify user password /// - Requires UIAA to verify user password
/// - Changes the password of the sender user /// - Changes the password of the sender user
/// - The password hash is calculated using argon2 with 32 character salt, the plain password is /// - The password hash is calculated using argon2 with 32 character salt, the
/// plain password is
/// not saved /// not saved
/// ///
/// If `logout_devices` is true it does the following for each device except the sender device: /// If `logout_devices` is true it does the following for each device except the
/// sender device:
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device ID, device display name, last seen IP, last seen timestamp) /// - Deletes device metadata (device ID, device display name, last seen IP,
/// last seen timestamp)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn change_password_route( pub(crate) async fn change_password_route(
body: Ruma<change_password::v3::Request>, body: Ruma<change_password::v3::Request>,
) -> Result<change_password::v3::Response> { ) -> Result<change_password::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
@ -327,27 +344,25 @@ pub(crate) async fn change_password_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(
services() sender_user,
.uiaa sender_device,
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; auth,
&uiaainfo,
)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services() services().users.set_password(sender_user, Some(&body.new_password))?;
.users
.set_password(sender_user, Some(&body.new_password))?;
if body.logout_devices { if body.logout_devices {
// Logout all devices except the current one // Logout all devices except the current one
@ -362,11 +377,9 @@ pub(crate) async fn change_password_route(
} }
info!("User {} changed their password.", sender_user); info!("User {} changed their password.", sender_user);
services() services().admin.send_message(RoomMessageEventContent::notice_plain(
.admin format!("User {sender_user} changed their password."),
.send_message(RoomMessageEventContent::notice_plain(format!( ));
"User {sender_user} changed their password."
)));
Ok(change_password::v3::Response {}) Ok(change_password::v3::Response {})
} }
@ -376,14 +389,17 @@ pub(crate) async fn change_password_route(
/// Get `user_id` of the sender user. /// Get `user_id` of the sender user.
/// ///
/// Note: Also works for Application Services /// Note: Also works for Application Services
pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> { pub(crate) async fn whoami_route(
body: Ruma<whoami::v3::Request>,
) -> Result<whoami::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device_id = body.sender_device.as_ref().cloned(); let device_id = body.sender_device.as_ref().cloned();
Ok(whoami::v3::Response { Ok(whoami::v3::Response {
user_id: sender_user.clone(), user_id: sender_user.clone(),
device_id, device_id,
is_guest: services().users.is_deactivated(sender_user)? && body.appservice_info.is_none(), is_guest: services().users.is_deactivated(sender_user)?
&& body.appservice_info.is_none(),
}) })
} }
@ -393,7 +409,8 @@ pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoa
/// ///
/// - Leaves all rooms and rejects all invitations /// - Leaves all rooms and rejects all invitations
/// - Invalidates all access tokens /// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes all device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets all to-device events /// - Forgets all to-device events
/// - Triggers device list updates /// - Triggers device list updates
/// - Removes ability to log in again /// - Removes ability to log in again
@ -401,7 +418,8 @@ pub(crate) async fn deactivate_route(
body: Ruma<deactivate::v3::Request>, body: Ruma<deactivate::v3::Request>,
) -> Result<deactivate::v3::Response> { ) -> Result<deactivate::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
@ -414,19 +432,19 @@ pub(crate) async fn deactivate_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(
services() sender_user,
.uiaa sender_device,
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; auth,
&uiaainfo,
)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -439,11 +457,9 @@ pub(crate) async fn deactivate_route(
services().users.deactivate_account(sender_user)?; services().users.deactivate_account(sender_user)?;
info!("User {} deactivated their account.", sender_user); info!("User {} deactivated their account.", sender_user);
services() services().admin.send_message(RoomMessageEventContent::notice_plain(
.admin format!("User {sender_user} deactivated their account."),
.send_message(RoomMessageEventContent::notice_plain(format!( ));
"User {sender_user} deactivated their account."
)));
Ok(deactivate::v3::Response { Ok(deactivate::v3::Response {
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
@ -458,16 +474,19 @@ pub(crate) async fn deactivate_route(
pub(crate) async fn third_party_route( pub(crate) async fn third_party_route(
body: Ruma<get_3pids::v3::Request>, body: Ruma<get_3pids::v3::Request>,
) -> Result<get_3pids::v3::Response> { ) -> Result<get_3pids::v3::Response> {
let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); let _sender_user =
body.sender_user.as_ref().expect("user is authenticated");
Ok(get_3pids::v3::Response::new(Vec::new())) Ok(get_3pids::v3::Response::new(Vec::new()))
} }
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken` /// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
/// ///
/// "This API should be used to request validation tokens when adding an email address to an account" /// "This API should be used to request validation tokens when adding an email
/// address to an account"
/// ///
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. /// - 403 signals that The homeserver does not allow the third party identifier
/// as a contact option.
pub(crate) async fn request_3pid_management_token_via_email_route( pub(crate) async fn request_3pid_management_token_via_email_route(
_body: Ruma<request_3pid_management_token_via_email::v3::Request>, _body: Ruma<request_3pid_management_token_via_email::v3::Request>,
) -> Result<request_3pid_management_token_via_email::v3::Response> { ) -> Result<request_3pid_management_token_via_email::v3::Response> {
@ -479,9 +498,11 @@ pub(crate) async fn request_3pid_management_token_via_email_route(
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken` /// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
/// ///
/// "This API should be used to request validation tokens when adding an phone number to an account" /// "This API should be used to request validation tokens when adding an phone
/// number to an account"
/// ///
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. /// - 403 signals that The homeserver does not allow the third party identifier
/// as a contact option.
pub(crate) async fn request_3pid_management_token_via_msisdn_route( pub(crate) async fn request_3pid_management_token_via_msisdn_route(
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>, _body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> { ) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {

View file

@ -1,4 +1,3 @@
use crate::{services, Error, Result, Ruma};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use ruma::{ use ruma::{
api::{ api::{
@ -12,6 +11,8 @@ use ruma::{
OwnedRoomAliasId, OwnedRoomAliasId,
}; };
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/directory/room/{roomAlias}` /// # `PUT /_matrix/client/r0/directory/room/{roomAlias}`
/// ///
/// Creates a new room alias on this server. /// Creates a new room alias on this server.
@ -32,30 +33,18 @@ pub(crate) async fn create_alias_route(
"Room alias is not in namespace.", "Room alias is not in namespace.",
)); ));
} }
} else if services() } else if services().appservice.is_exclusive_alias(&body.room_alias).await {
.appservice
.is_exclusive_alias(&body.room_alias)
.await
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Exclusive, ErrorKind::Exclusive,
"Room alias reserved by appservice.", "Room alias reserved by appservice.",
)); ));
} }
if services() if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.is_some()
{
return Err(Error::Conflict("Alias already exists.")); return Err(Error::Conflict("Alias already exists."));
} }
services() services().rooms.alias.set_alias(&body.room_alias, &body.room_id)?;
.rooms
.alias
.set_alias(&body.room_alias, &body.room_id)?;
Ok(create_alias::v3::Response::new()) Ok(create_alias::v3::Response::new())
} }
@ -83,11 +72,7 @@ pub(crate) async fn delete_alias_route(
"Room alias is not in namespace.", "Room alias is not in namespace.",
)); ));
} }
} else if services() } else if services().appservice.is_exclusive_alias(&body.room_alias).await {
.appservice
.is_exclusive_alias(&body.room_alias)
.await
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Exclusive, ErrorKind::Exclusive,
"Room alias reserved by appservice.", "Room alias reserved by appservice.",
@ -157,7 +142,10 @@ pub(crate) async fn get_alias_helper(
.alias .alias
.resolve_local_alias(&room_alias)? .resolve_local_alias(&room_alias)?
.ok_or_else(|| { .ok_or_else(|| {
Error::bad_config("Appservice lied to us. Room does not exist.") Error::bad_config(
"Appservice lied to us. Room does not \
exist.",
)
})?, })?,
); );
break; break;

View file

@ -1,15 +1,16 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
backup::{ backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
create_backup_version, delete_backup_keys, delete_backup_keys_for_room, create_backup_version, delete_backup_keys, delete_backup_keys_for_room,
delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys, delete_backup_keys_for_session, delete_backup_version, get_backup_info,
get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
update_backup_version, get_latest_backup_info, update_backup_version,
}, },
error::ErrorKind, error::ErrorKind,
}; };
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/room_keys/version` /// # `POST /_matrix/client/r0/room_keys/version`
/// ///
/// Creates a new backup. /// Creates a new backup.
@ -17,23 +18,27 @@ pub(crate) async fn create_backup_version_route(
body: Ruma<create_backup_version::v3::Request>, body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> { ) -> Result<create_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = services() let version =
.key_backups services().key_backups.create_backup(sender_user, &body.algorithm)?;
.create_backup(sender_user, &body.algorithm)?;
Ok(create_backup_version::v3::Response { version }) Ok(create_backup_version::v3::Response {
version,
})
} }
/// # `PUT /_matrix/client/r0/room_keys/version/{version}` /// # `PUT /_matrix/client/r0/room_keys/version/{version}`
/// ///
/// Update information about an existing backup. Only `auth_data` can be modified. /// Update information about an existing backup. Only `auth_data` can be
/// modified.
pub(crate) async fn update_backup_version_route( pub(crate) async fn update_backup_version_route(
body: Ruma<update_backup_version::v3::Request>, body: Ruma<update_backup_version::v3::Request>,
) -> Result<update_backup_version::v3::Response> { ) -> Result<update_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.update_backup(
.key_backups sender_user,
.update_backup(sender_user, &body.version, &body.algorithm)?; &body.version,
&body.algorithm,
)?;
Ok(update_backup_version::v3::Response {}) Ok(update_backup_version::v3::Response {})
} }
@ -88,9 +93,7 @@ pub(crate) async fn get_backup_info_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
version: body.version.clone(), version: body.version.clone(),
}) })
} }
@ -99,15 +102,14 @@ pub(crate) async fn get_backup_info_route(
/// ///
/// Delete an existing key backup. /// Delete an existing key backup.
/// ///
/// - Deletes both information about the backup, as well as all key data related to the backup /// - Deletes both information about the backup, as well as all key data related
/// to the backup
pub(crate) async fn delete_backup_version_route( pub(crate) async fn delete_backup_version_route(
body: Ruma<delete_backup_version::v3::Request>, body: Ruma<delete_backup_version::v3::Request>,
) -> Result<delete_backup_version::v3::Response> { ) -> Result<delete_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.delete_backup(sender_user, &body.version)?;
.key_backups
.delete_backup(sender_user, &body.version)?;
Ok(delete_backup_version::v3::Response {}) Ok(delete_backup_version::v3::Response {})
} }
@ -116,7 +118,8 @@ pub(crate) async fn delete_backup_version_route(
/// ///
/// Add the received backup keys to the database. /// Add the received backup keys to the database.
/// ///
/// - Only manipulating the most recently created version of the backup is allowed /// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_route( pub(crate) async fn add_backup_keys_route(
@ -132,7 +135,8 @@ pub(crate) async fn add_backup_keys_route(
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the \
backup.",
)); ));
} }
@ -154,9 +158,7 @@ pub(crate) async fn add_backup_keys_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -164,7 +166,8 @@ pub(crate) async fn add_backup_keys_route(
/// ///
/// Add the received backup keys to the database. /// Add the received backup keys to the database.
/// ///
/// - Only manipulating the most recently created version of the backup is allowed /// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_room_route( pub(crate) async fn add_backup_keys_for_room_route(
@ -180,7 +183,8 @@ pub(crate) async fn add_backup_keys_for_room_route(
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the \
backup.",
)); ));
} }
@ -200,9 +204,7 @@ pub(crate) async fn add_backup_keys_for_room_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -210,7 +212,8 @@ pub(crate) async fn add_backup_keys_for_room_route(
/// ///
/// Add the received backup key to the database. /// Add the received backup key to the database.
/// ///
/// - Only manipulating the most recently created version of the backup is allowed /// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_session_route( pub(crate) async fn add_backup_keys_for_session_route(
@ -226,7 +229,8 @@ pub(crate) async fn add_backup_keys_for_session_route(
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the \
backup.",
)); ));
} }
@ -244,9 +248,7 @@ pub(crate) async fn add_backup_keys_for_session_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -260,7 +262,9 @@ pub(crate) async fn get_backup_keys_route(
let rooms = services().key_backups.get_all(sender_user, &body.version)?; let rooms = services().key_backups.get_all(sender_user, &body.version)?;
Ok(get_backup_keys::v3::Response { rooms }) Ok(get_backup_keys::v3::Response {
rooms,
})
} }
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}` /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
@ -271,11 +275,15 @@ pub(crate) async fn get_backup_keys_for_room_route(
) -> Result<get_backup_keys_for_room::v3::Response> { ) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = services() let sessions = services().key_backups.get_room(
.key_backups sender_user,
.get_room(sender_user, &body.version, &body.room_id)?; &body.version,
&body.room_id,
)?;
Ok(get_backup_keys_for_room::v3::Response { sessions }) Ok(get_backup_keys_for_room::v3::Response {
sessions,
})
} }
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
@ -288,13 +296,20 @@ pub(crate) async fn get_backup_keys_for_session_route(
let key_data = services() let key_data = services()
.key_backups .key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .get_session(
sender_user,
&body.version,
&body.room_id,
&body.session_id,
)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"Backup key not found for this user's session.", "Backup key not found for this user's session.",
))?; ))?;
Ok(get_backup_keys_for_session::v3::Response { key_data }) Ok(get_backup_keys_for_session::v3::Response {
key_data,
})
} }
/// # `DELETE /_matrix/client/r0/room_keys/keys` /// # `DELETE /_matrix/client/r0/room_keys/keys`
@ -305,9 +320,7 @@ pub(crate) async fn delete_backup_keys_route(
) -> Result<delete_backup_keys::v3::Response> { ) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.delete_all_keys(sender_user, &body.version)?;
.key_backups
.delete_all_keys(sender_user, &body.version)?;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: services() count: services()
@ -315,9 +328,7 @@ pub(crate) async fn delete_backup_keys_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -329,9 +340,11 @@ pub(crate) async fn delete_backup_keys_for_room_route(
) -> Result<delete_backup_keys_for_room::v3::Response> { ) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.delete_room_keys(
.key_backups sender_user,
.delete_room_keys(sender_user, &body.version, &body.room_id)?; &body.version,
&body.room_id,
)?;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: services() count: services()
@ -339,9 +352,7 @@ pub(crate) async fn delete_backup_keys_for_room_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -366,8 +377,6 @@ pub(crate) async fn delete_backup_keys_for_session_route(
.count_keys(sender_user, &body.version)? .count_keys(sender_user, &body.version)?
.try_into() .try_into()
.expect("count should fit in UInt"), .expect("count should fit in UInt"),
etag: services() etag: services().key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }

View file

@ -1,12 +1,15 @@
use crate::{services, Result, Ruma}; use std::collections::BTreeMap;
use ruma::api::client::discovery::get_capabilities::{ use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability, self, Capabilities, RoomVersionStability, RoomVersionsCapability,
}; };
use std::collections::BTreeMap;
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/r0/capabilities` /// # `GET /_matrix/client/r0/capabilities`
/// ///
/// Get information on the supported feature set and other relevent capabilities of this server. /// Get information on the supported feature set and other relevent capabilities
/// of this server.
pub(crate) async fn get_capabilities_route( pub(crate) async fn get_capabilities_route(
_body: Ruma<get_capabilities::v3::Request>, _body: Ruma<get_capabilities::v3::Request>,
) -> Result<get_capabilities::v3::Response> { ) -> Result<get_capabilities::v3::Response> {
@ -24,5 +27,7 @@ pub(crate) async fn get_capabilities_route(
available, available,
}; };
Ok(get_capabilities::v3::Response { capabilities }) Ok(get_capabilities::v3::Response {
capabilities,
})
} }

View file

@ -1,18 +1,21 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
config::{ config::{
get_global_account_data, get_room_account_data, set_global_account_data, get_global_account_data, get_room_account_data,
set_room_account_data, set_global_account_data, set_room_account_data,
}, },
error::ErrorKind, error::ErrorKind,
}, },
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent}, events::{
AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent,
},
serde::Raw, serde::Raw,
}; };
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, value::RawValue as RawJsonValue}; use serde_json::{json, value::RawValue as RawJsonValue};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
/// ///
/// Sets some account data for the sender user. /// Sets some account data for the sender user.
@ -22,7 +25,9 @@ pub(crate) async fn set_global_account_data_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get()) let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; .map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Data is invalid.")
})?;
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
@ -48,7 +53,9 @@ pub(crate) async fn set_room_account_data_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get()) let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; .map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Data is invalid.")
})?;
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
@ -78,11 +85,16 @@ pub(crate) async fn get_global_account_data_route(
.get(None, sender_user, body.event_type.to_string().into())? .get(None, sender_user, body.event_type.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get()) let account_data =
.map_err(|_| Error::bad_database("Invalid account data event in db."))? serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.content; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?
.content;
Ok(get_global_account_data::v3::Response { account_data }) Ok(get_global_account_data::v3::Response {
account_data,
})
} }
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
@ -98,11 +110,16 @@ pub(crate) async fn get_room_account_data_route(
.get(Some(&body.room_id), sender_user, body.event_type.clone())? .get(Some(&body.room_id), sender_user, body.event_type.clone())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get()) let account_data =
.map_err(|_| Error::bad_database("Invalid account data event in db."))? serde_json::from_str::<ExtractRoomEventContent>(event.get())
.content; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?
.content;
Ok(get_room_account_data::v3::Response { account_data }) Ok(get_room_account_data::v3::Response {
account_data,
})
} }
#[derive(Deserialize)] #[derive(Deserialize)]

View file

@ -1,60 +1,57 @@
use crate::{services, Error, Result, Ruma}; use std::collections::HashSet;
use ruma::{ use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, api::client::{
context::get_context, error::ErrorKind, filter::LazyLoadOptions,
},
events::StateEventType, events::StateEventType,
uint, uint,
}; };
use std::collections::HashSet;
use tracing::error; use tracing::error;
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// # `GET /_matrix/client/r0/rooms/{roomId}/context`
/// ///
/// Allows loading room history around an event. /// Allows loading room history around an event.
/// ///
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// - Only works if the user is joined (TODO: always allow, but only show events
/// if the user was
/// joined, depending on `history_visibility`) /// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn get_context_route( pub(crate) async fn get_context_route(
body: Ruma<get_context::v3::Request>, body: Ruma<get_context::v3::Request>,
) -> Result<get_context::v3::Response> { ) -> Result<get_context::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { let (lazy_load_enabled, lazy_load_send_redundant) =
LazyLoadOptions::Enabled { match &body.filter.lazy_load_options {
include_redundant_members, LazyLoadOptions::Enabled {
} => (true, *include_redundant_members), include_redundant_members,
LazyLoadOptions::Disabled => (false, false), } => (true, *include_redundant_members),
}; LazyLoadOptions::Disabled => (false, false),
};
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let base_token = services() let base_token =
.rooms services().rooms.timeline.get_pdu_count(&body.event_id)?.ok_or(
.timeline Error::BadRequest(ErrorKind::NotFound, "Base event id not found."),
.get_pdu_count(&body.event_id)? )?;
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event id not found.",
))?;
let base_event = let base_event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or(
services() Error::BadRequest(ErrorKind::NotFound, "Base event not found."),
.rooms )?;
.timeline
.get_pdu(&body.event_id)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event not found.",
))?;
let room_id = base_event.room_id.clone(); let room_id = base_event.room_id.clone();
if !services() if !services().rooms.state_accessor.user_can_see_event(
.rooms sender_user,
.state_accessor &room_id,
.user_can_see_event(sender_user, &room_id, &body.event_id)? &body.event_id,
{ )? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this event.", "You don't have permission to view this event.",
@ -72,8 +69,8 @@ pub(crate) async fn get_context_route(
} }
// Use limit with maximum 100 // Use limit with maximum 100
let half_limit = let half_limit = usize::try_from(body.limit.min(uint!(100)) / uint!(2))
usize::try_from(body.limit.min(uint!(100)) / uint!(2)).expect("0-50 should fit in usize"); .expect("0-50 should fit in usize");
let base_event = base_event.to_room_event(); let base_event = base_event.to_room_event();
@ -108,10 +105,8 @@ pub(crate) async fn get_context_route(
.last() .last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_before: Vec<_> = events_before let events_before: Vec<_> =
.into_iter() events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_after: Vec<_> = services() let events_after: Vec<_> = services()
.rooms .rooms
@ -140,41 +135,33 @@ pub(crate) async fn get_context_route(
} }
} }
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( let shortstatehash =
events_after match services().rooms.state_accessor.pdu_shortstatehash(
.last() events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id),
.map_or(&*body.event_id, |(_, e)| &*e.event_id), )? {
)? { Some(s) => s,
Some(s) => s, None => services()
None => services() .rooms
.rooms .state
.state .get_room_shortstatehash(&room_id)?
.get_room_shortstatehash(&room_id)? .expect("All rooms have state"),
.expect("All rooms have state"), };
};
let state_ids = services() let state_ids =
.rooms services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
.state_accessor
.state_full_ids(shortstatehash)
.await?;
let end_token = events_after let end_token = events_after
.last() .last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_after: Vec<_> = events_after let events_after: Vec<_> =
.into_iter() events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let mut state = Vec::new(); let mut state = Vec::new();
for (shortstatekey, id) in state_ids { for (shortstatekey, id) in state_ids {
let (event_type, state_key) = services() let (event_type, state_key) =
.rooms services().rooms.short.get_statekey_from_short(shortstatekey)?;
.short
.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {

View file

@ -1,11 +1,14 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, device::{
self, delete_device, delete_devices, get_device, get_devices,
update_device,
},
error::ErrorKind, error::ErrorKind,
uiaa::{AuthFlow, AuthType, UiaaInfo}, uiaa::{AuthFlow, AuthType, UiaaInfo},
}; };
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/devices` /// # `GET /_matrix/client/r0/devices`
/// ///
@ -21,7 +24,9 @@ pub(crate) async fn get_devices_route(
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
Ok(get_devices::v3::Response { devices }) Ok(get_devices::v3::Response {
devices,
})
} }
/// # `GET /_matrix/client/r0/devices/{deviceId}` /// # `GET /_matrix/client/r0/devices/{deviceId}`
@ -37,7 +42,9 @@ pub(crate) async fn get_device_route(
.get_device_metadata(sender_user, &body.body.device_id)? .get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
Ok(get_device::v3::Response { device }) Ok(get_device::v3::Response {
device,
})
} }
/// # `PUT /_matrix/client/r0/devices/{deviceId}` /// # `PUT /_matrix/client/r0/devices/{deviceId}`
@ -55,9 +62,11 @@ pub(crate) async fn update_device_route(
device.display_name = body.display_name.clone(); device.display_name = body.display_name.clone();
services() services().users.update_device_metadata(
.users sender_user,
.update_device_metadata(sender_user, &body.device_id, &device)?; &body.device_id,
&device,
)?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
@ -68,14 +77,16 @@ pub(crate) async fn update_device_route(
/// ///
/// - Requires UIAA to verify user password /// - Requires UIAA to verify user password
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn delete_device_route( pub(crate) async fn delete_device_route(
body: Ruma<delete_device::v3::Request>, body: Ruma<delete_device::v3::Request>,
) -> Result<delete_device::v3::Response> { ) -> Result<delete_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
@ -89,27 +100,25 @@ pub(crate) async fn delete_device_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(
services() sender_user,
.uiaa sender_device,
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; auth,
&uiaainfo,
)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services() services().users.remove_device(sender_user, &body.device_id)?;
.users
.remove_device(sender_user, &body.device_id)?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -122,14 +131,16 @@ pub(crate) async fn delete_device_route(
/// ///
/// For each device: /// For each device:
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn delete_devices_route( pub(crate) async fn delete_devices_route(
body: Ruma<delete_devices::v3::Request>, body: Ruma<delete_devices::v3::Request>,
) -> Result<delete_devices::v3::Response> { ) -> Result<delete_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
@ -143,19 +154,19 @@ pub(crate) async fn delete_devices_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(
services() sender_user,
.uiaa sender_device,
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; auth,
&uiaainfo,
)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));

View file

@ -1,10 +1,9 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
directory::{ directory::{
get_public_rooms, get_public_rooms_filtered, get_room_visibility, get_public_rooms, get_public_rooms_filtered,
set_room_visibility, get_room_visibility, set_room_visibility,
}, },
error::ErrorKind, error::ErrorKind,
room, room,
@ -18,7 +17,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent, canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent, create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent}, guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent},
topic::RoomTopicEventContent, topic::RoomTopicEventContent,
}, },
@ -28,6 +29,8 @@ use ruma::{
}; };
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/publicRooms` /// # `POST /_matrix/client/r0/publicRooms`
/// ///
/// Lists the public rooms on this server. /// Lists the public rooms on this server.
@ -91,7 +94,9 @@ pub(crate) async fn set_room_visibility_route(
services().rooms.directory.set_public(&body.room_id)?; services().rooms.directory.set_public(&body.room_id)?;
info!("{} made {} public", sender_user, body.room_id); info!("{} made {} public", sender_user, body.room_id);
} }
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, room::Visibility::Private => {
services().rooms.directory.set_not_public(&body.room_id)?;
}
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -115,7 +120,11 @@ pub(crate) async fn get_room_visibility_route(
} }
Ok(get_room_visibility::v3::Response { Ok(get_room_visibility::v3::Response {
visibility: if services().rooms.directory.is_public_room(&body.room_id)? { visibility: if services()
.rooms
.directory
.is_public_room(&body.room_id)?
{
room::Visibility::Public room::Visibility::Public
} else { } else {
room::Visibility::Private room::Visibility::Private
@ -131,8 +140,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
filter: &Filter, filter: &Filter,
_network: &RoomNetwork, _network: &RoomNetwork,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(other_server) = if let Some(other_server) = server
server.filter(|server| *server != services().globals.server_name().as_str()) .filter(|server| *server != services().globals.server_name().as_str())
{ {
let response = services() let response = services()
.sending .sending
@ -174,10 +183,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
} }
}; };
num_since = characters num_since = characters.collect::<String>().parse().map_err(|_| {
.collect::<String>() Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token.")
.parse() })?;
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?;
if backwards { if backwards {
num_since = num_since.saturating_sub(limit); num_since = num_since.saturating_sub(limit);
@ -195,12 +203,19 @@ pub(crate) async fn get_public_rooms_filtered_helper(
canonical_alias: services() canonical_alias: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .room_state_get(
&room_id,
&StateEventType::RoomCanonicalAlias,
"",
)?
.map_or(Ok(None), |s| { .map_or(Ok(None), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias) .map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| { .map_err(|_| {
Error::bad_database("Invalid canonical alias event in database.") Error::bad_database(
"Invalid canonical alias event in \
database.",
)
}) })
})?, })?,
name: services().rooms.state_accessor.get_name(&room_id)?, name: services().rooms.state_accessor.get_name(&room_id)?,
@ -222,36 +237,55 @@ pub(crate) async fn get_public_rooms_filtered_helper(
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic)) .map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|_| { .map_err(|_| {
error!("Invalid room topic event in database for room {}", room_id); error!(
Error::bad_database("Invalid room topic event in database.") "Invalid room topic event in database for \
room {}",
room_id
);
Error::bad_database(
"Invalid room topic event in database.",
)
}) })
})?, })?,
world_readable: services() world_readable: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(
&room_id,
&StateEventType::RoomHistoryVisibility,
"",
)?
.map_or(Ok(false), |s| { .map_or(Ok(false), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| { .map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable c.history_visibility
== HistoryVisibility::WorldReadable
}) })
.map_err(|_| { .map_err(|_| {
Error::bad_database( Error::bad_database(
"Invalid room history visibility event in database.", "Invalid room history visibility event in \
database.",
) )
}) })
})?, })?,
guest_can_join: services() guest_can_join: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .room_state_get(
&room_id,
&StateEventType::RoomGuestAccess,
"",
)?
.map_or(Ok(false), |s| { .map_or(Ok(false), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| { .map(|c: RoomGuestAccessEventContent| {
c.guest_access == GuestAccess::CanJoin c.guest_access == GuestAccess::CanJoin
}) })
.map_err(|_| { .map_err(|_| {
Error::bad_database("Invalid room guest access event in database.") Error::bad_database(
"Invalid room guest access event in \
database.",
)
}) })
})?, })?,
avatar_url: services() avatar_url: services()
@ -262,7 +296,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomAvatarEventContent| c.url) .map(|c: RoomAvatarEventContent| c.url)
.map_err(|_| { .map_err(|_| {
Error::bad_database("Invalid room avatar event in database.") Error::bad_database(
"Invalid room avatar event in database.",
)
}) })
}) })
.transpose()? .transpose()?
@ -270,33 +306,59 @@ pub(crate) async fn get_public_rooms_filtered_helper(
join_rule: services() join_rule: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(
&room_id,
&StateEventType::RoomJoinRules,
"",
)?
.map(|s| { .map(|s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| match c.join_rule { .map(|c: RoomJoinRulesEventContent| {
JoinRule::Public => Some(PublicRoomJoinRule::Public), match c.join_rule {
JoinRule::Knock => Some(PublicRoomJoinRule::Knock), JoinRule::Public => {
_ => None, Some(PublicRoomJoinRule::Public)
}
JoinRule::Knock => {
Some(PublicRoomJoinRule::Knock)
}
_ => None,
}
}) })
.map_err(|e| { .map_err(|e| {
error!("Invalid room join rule event in database: {}", e); error!(
Error::BadDatabase("Invalid room join rule event in database.") "Invalid room join rule event in \
database: {}",
e
);
Error::BadDatabase(
"Invalid room join rule event in database.",
)
}) })
}) })
.transpose()? .transpose()?
.flatten() .flatten()
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, .ok_or_else(|| {
Error::bad_database(
"Missing room join rule event for room.",
)
})?,
room_type: services() room_type: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")? .room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.map(|s| { .map(|s| {
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err( serde_json::from_str::<RoomCreateEventContent>(
|e| { s.content.get(),
error!("Invalid room create event in database: {}", e);
Error::BadDatabase("Invalid room create event in database.")
},
) )
.map_err(|e| {
error!(
"Invalid room create event in database: {}",
e
);
Error::BadDatabase(
"Invalid room create event in database.",
)
})
}) })
.transpose()? .transpose()?
.and_then(|e| e.room_type), .and_then(|e| e.room_type),
@ -306,10 +368,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}) })
.filter_map(Result::<_>::ok) .filter_map(Result::<_>::ok)
.filter(|chunk| { .filter(|chunk| {
if let Some(query) = filter if let Some(query) =
.generic_search_term filter.generic_search_term.as_ref().map(|q| q.to_lowercase())
.as_ref()
.map(|q| q.to_lowercase())
{ {
if let Some(name) = &chunk.name { if let Some(name) = &chunk.name {
if name.as_str().to_lowercase().contains(&query) { if name.as_str().to_lowercase().contains(&query) {
@ -324,7 +384,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
} }
if let Some(canonical_alias) = &chunk.canonical_alias { if let Some(canonical_alias) = &chunk.canonical_alias {
if canonical_alias.as_str().to_lowercase().contains(&query) { if canonical_alias.as_str().to_lowercase().contains(&query)
{
return true; return true;
} }
} }
@ -339,7 +400,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
let total_room_count_estimate = all_rooms.len().try_into().unwrap_or(UInt::MAX); let total_room_count_estimate =
all_rooms.len().try_into().unwrap_or(UInt::MAX);
let chunk: Vec<_> = all_rooms let chunk: Vec<_> = all_rooms
.into_iter() .into_iter()
@ -353,11 +415,12 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Some(format!("p{num_since}")) Some(format!("p{num_since}"))
}; };
let next_batch = if chunk.len() < limit.try_into().expect("UInt should fit in usize") { let next_batch =
None if chunk.len() < limit.try_into().expect("UInt should fit in usize") {
} else { None
Some(format!("n{}", num_since + limit)) } else {
}; Some(format!("n{}", num_since + limit))
};
Ok(get_public_rooms_filtered::v3::Response { Ok(get_public_rooms_filtered::v3::Response {
chunk, chunk,

View file

@ -1,9 +1,10 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
filter::{create_filter, get_filter}, filter::{create_filter, get_filter},
}; };
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
/// ///
/// Loads a filter that was previously created. /// Loads a filter that was previously created.
@ -13,8 +14,13 @@ pub(crate) async fn get_filter_route(
body: Ruma<get_filter::v3::Request>, body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> { ) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let Some(filter) = services().users.get_filter(sender_user, &body.filter_id)? else { let Some(filter) =
return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); services().users.get_filter(sender_user, &body.filter_id)?
else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Filter not found.",
));
}; };
Ok(get_filter::v3::Response::new(filter)) Ok(get_filter::v3::Response::new(filter))

View file

@ -1,13 +1,16 @@
use super::SESSION_ID_LENGTH; use std::{
use crate::{services, utils, Error, Result, Ruma}; collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
error::ErrorKind, error::ErrorKind,
keys::{ keys::{
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, claim_keys, get_key_changes, get_keys, upload_keys,
upload_signing_keys, upload_signatures, upload_signing_keys,
}, },
uiaa::{AuthFlow, AuthType, UiaaInfo}, uiaa::{AuthFlow, AuthType, UiaaInfo},
}, },
@ -17,28 +20,32 @@ use ruma::{
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
}; };
use serde_json::json; use serde_json::json;
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use tracing::debug; use tracing::debug;
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/keys/upload` /// # `POST /_matrix/client/r0/keys/upload`
/// ///
/// Publish end-to-end encryption keys for the sender device. /// Publish end-to-end encryption keys for the sender device.
/// ///
/// - Adds one time keys /// - Adds one time keys
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) /// - If there are no device keys yet: Adds device keys (TODO: merge with
/// existing keys?)
pub(crate) async fn upload_keys_route( pub(crate) async fn upload_keys_route(
body: Ruma<upload_keys::v3::Request>, body: Ruma<upload_keys::v3::Request>,
) -> Result<upload_keys::v3::Response> { ) -> Result<upload_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services() services().users.add_one_time_key(
.users sender_user,
.add_one_time_key(sender_user, sender_device, key_key, key_value)?; sender_device,
key_key,
key_value,
)?;
} }
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
@ -49,9 +56,11 @@ pub(crate) async fn upload_keys_route(
.get_device_keys(sender_user, sender_device)? .get_device_keys(sender_user, sender_device)?
.is_none() .is_none()
{ {
services() services().users.add_device_keys(
.users sender_user,
.add_device_keys(sender_user, sender_device, device_keys)?; sender_device,
device_keys,
)?;
} }
} }
@ -68,14 +77,17 @@ pub(crate) async fn upload_keys_route(
/// ///
/// - Always fetches users from other servers over federation /// - Always fetches users from other servers over federation
/// - Gets master keys, self-signing keys, user signing keys and device keys. /// - Gets master keys, self-signing keys, user signing keys and device keys.
/// - The master and self-signing keys contain signatures that the user is allowed to see /// - The master and self-signing keys contain signatures that the user is
/// allowed to see
pub(crate) async fn get_keys_route( pub(crate) async fn get_keys_route(
body: Ruma<get_keys::v3::Request>, body: Ruma<get_keys::v3::Request>,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let response = let response = get_keys_helper(Some(sender_user), &body.device_keys, |u| {
get_keys_helper(Some(sender_user), &body.device_keys, |u| u == sender_user).await?; u == sender_user
})
.await?;
Ok(response) Ok(response)
} }
@ -100,7 +112,8 @@ pub(crate) async fn upload_signing_keys_route(
body: Ruma<upload_signing_keys::v3::Request>, body: Ruma<upload_signing_keys::v3::Request>,
) -> Result<upload_signing_keys::v3::Response> { ) -> Result<upload_signing_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
@ -114,19 +127,19 @@ pub(crate) async fn upload_signing_keys_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(
services() sender_user,
.uiaa sender_device,
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; auth,
&uiaainfo,
)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -156,8 +169,9 @@ pub(crate) async fn upload_signatures_route(
for (user_id, keys) in &body.signed_keys { for (user_id, keys) in &body.signed_keys {
for (key_id, key) in keys { for (key_id, key) in keys {
let key = serde_json::to_value(key) let key = serde_json::to_value(key).map_err(|_| {
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON")
})?;
for signature in key for signature in key
.get("signatures") .get("signatures")
@ -189,9 +203,12 @@ pub(crate) async fn upload_signatures_route(
))? ))?
.to_owned(), .to_owned(),
); );
services() services().users.sign_key(
.users user_id,
.sign_key(user_id, key_id, signature, sender_user)?; key_id,
signature,
sender_user,
)?;
} }
} }
} }
@ -204,7 +221,8 @@ pub(crate) async fn upload_signatures_route(
/// # `POST /_matrix/client/r0/keys/changes` /// # `POST /_matrix/client/r0/keys/changes`
/// ///
/// Gets a list of users who have updated their device identity keys since the previous sync token. /// Gets a list of users who have updated their device identity keys since the
/// previous sync token.
/// ///
/// - TODO: left users /// - TODO: left users
pub(crate) async fn get_key_changes_route( pub(crate) async fn get_key_changes_route(
@ -219,14 +237,15 @@ pub(crate) async fn get_key_changes_route(
.users .users
.keys_changed( .keys_changed(
sender_user.as_str(), sender_user.as_str(),
body.from body.from.parse().map_err(|_| {
.parse() Error::BadRequest(
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, ErrorKind::InvalidParam,
Some( "Invalid `from`.",
body.to )
.parse() })?,
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, Some(body.to.parse().map_err(|_| {
), Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
})?),
) )
.filter_map(Result::ok), .filter_map(Result::ok),
); );
@ -243,10 +262,16 @@ pub(crate) async fn get_key_changes_route(
.keys_changed( .keys_changed(
room_id.as_ref(), room_id.as_ref(),
body.from.parse().map_err(|_| { body.from.parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.") Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid `from`.",
)
})?, })?,
Some(body.to.parse().map_err(|_| { Some(body.to.parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.") Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid `to`.",
)
})?), })?),
) )
.filter_map(Result::ok), .filter_map(Result::ok),
@ -287,16 +312,24 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for device_id in services().users.all_device_ids(user_id) { for device_id in services().users.all_device_ids(user_id) {
let device_id = device_id?; let device_id = device_id?;
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { if let Some(mut keys) =
services().users.get_device_keys(user_id, &device_id)?
{
let metadata = services() let metadata = services()
.users .users
.get_device_metadata(user_id, &device_id)? .get_device_metadata(user_id, &device_id)?
.ok_or_else(|| { .ok_or_else(|| {
Error::bad_database("all_device_keys contained nonexistent device.") Error::bad_database(
"all_device_keys contained nonexistent device.",
)
})?; })?;
add_unsigned_device_display_name(&mut keys, metadata) add_unsigned_device_display_name(&mut keys, metadata)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| {
Error::bad_database(
"invalid device keys in database",
)
})?;
container.insert(device_id, keys); container.insert(device_id, keys);
} }
} }
@ -304,7 +337,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} else { } else {
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { if let Some(mut keys) =
services().users.get_device_keys(user_id, device_id)?
{
let metadata = services() let metadata = services()
.users .users
.get_device_metadata(user_id, device_id)? .get_device_metadata(user_id, device_id)?
@ -314,29 +349,35 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
))?; ))?;
add_unsigned_device_display_name(&mut keys, metadata) add_unsigned_device_display_name(&mut keys, metadata)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| {
Error::bad_database(
"invalid device keys in database",
)
})?;
container.insert(device_id.to_owned(), keys); container.insert(device_id.to_owned(), keys);
} }
device_keys.insert(user_id.to_owned(), container); device_keys.insert(user_id.to_owned(), container);
} }
} }
if let Some(master_key) = if let Some(master_key) = services().users.get_master_key(
services() sender_user,
.users user_id,
.get_master_key(sender_user, user_id, &allowed_signatures)? &allowed_signatures,
{ )? {
master_keys.insert(user_id.to_owned(), master_key); master_keys.insert(user_id.to_owned(), master_key);
} }
if let Some(self_signing_key) = if let Some(self_signing_key) = services().users.get_self_signing_key(
services() sender_user,
.users user_id,
.get_self_signing_key(sender_user, user_id, &allowed_signatures)? &allowed_signatures,
{ )? {
self_signing_keys.insert(user_id.to_owned(), self_signing_key); self_signing_keys.insert(user_id.to_owned(), self_signing_key);
} }
if Some(user_id) == sender_user { if Some(user_id) == sender_user {
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { if let Some(user_signing_key) =
services().users.get_user_signing_key(user_id)?
{
user_signing_keys.insert(user_id.to_owned(), user_signing_key); user_signing_keys.insert(user_id.to_owned(), user_signing_key);
} }
} }
@ -345,17 +386,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
let back_off = |id| async { let back_off = |id| async {
match services() match services().globals.bad_query_ratelimiter.write().await.entry(id) {
.globals
.bad_query_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
} }
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1 + 1);
}
} }
}; };
@ -370,7 +407,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
.get(server) .get(server)
{ {
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); let mut min_elapsed_duration =
Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24); min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
} }
@ -379,7 +417,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
debug!("Backing off query from {:?}", server); debug!("Backing off query from {:?}", server);
return ( return (
server, server,
Err(Error::BadServerResponse("bad query, still backing off")), Err(Error::BadServerResponse(
"bad query, still backing off",
)),
); );
} }
} }
@ -417,15 +457,19 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
&user, &user,
&allowed_signatures, &allowed_signatures,
)? { )? {
let (_, our_master_key) = let (_, our_master_key) = services()
services().users.parse_master_key(&user, &our_master_key)?; .users
.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures); master_key.signatures.extend(our_master_key.signatures);
} }
let json = serde_json::to_value(master_key).expect("to_value always works"); let json = serde_json::to_value(master_key)
let raw = serde_json::from_value(json).expect("Raw::from_value always works"); .expect("to_value always works");
let raw = serde_json::from_value(json)
.expect("Raw::from_value always works");
services().users.add_cross_signing_keys( services().users.add_cross_signing_keys(
&user, &raw, &None, &None, &user, &raw, &None, &None,
// Dont notify. A notification would trigger another key request resulting in an endless loop // Dont notify. A notification would trigger another key
// request resulting in an endless loop
false, false,
)?; )?;
master_keys.insert(user, raw); master_keys.insert(user, raw);
@ -454,11 +498,13 @@ fn add_unsigned_device_display_name(
metadata: ruma::api::client::device::Device, metadata: ruma::api::client::device::Device,
) -> serde_json::Result<()> { ) -> serde_json::Result<()> {
if let Some(display_name) = metadata.display_name { if let Some(display_name) = metadata.display_name {
let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?; let mut object = keys
.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
let unsigned = object.entry("unsigned").or_insert_with(|| json!({})); let unsigned = object.entry("unsigned").or_insert_with(|| json!({}));
if let serde_json::Value::Object(unsigned_object) = unsigned { if let serde_json::Value::Object(unsigned_object) = unsigned {
unsigned_object.insert("device_display_name".to_owned(), display_name.into()); unsigned_object
.insert("device_display_name".to_owned(), display_name.into());
} }
*keys = Raw::from_json(serde_json::value::to_raw_value(&object)?); *keys = Raw::from_json(serde_json::value::to_raw_value(&object)?);
@ -468,7 +514,10 @@ fn add_unsigned_device_display_name(
} }
pub(crate) async fn claim_keys_helper( pub(crate) async fn claim_keys_helper(
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, one_time_keys_input: &BTreeMap<
OwnedUserId,
BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>,
>,
) -> Result<claim_keys::v3::Response> { ) -> Result<claim_keys::v3::Response> {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
@ -484,11 +533,11 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = if let Some(one_time_keys) = services().users.take_one_time_key(
services() user_id,
.users device_id,
.take_one_time_key(user_id, device_id, key_algorithm)? key_algorithm,
{ )? {
let mut c = BTreeMap::new(); let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1); c.insert(one_time_keys.0, one_time_keys.1);
container.insert(device_id.clone(), c); container.insert(device_id.clone(), c);

View file

@ -1,14 +1,15 @@
use std::time::Duration; use std::time::Duration;
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
media::{ media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail, create_content, get_content, get_content_as_filename,
get_media_config, get_content_thumbnail, get_media_config,
}, },
}; };
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
const MXC_LENGTH: usize = 32; const MXC_LENGTH: usize = 32;
/// # `GET /_matrix/media/r0/config` /// # `GET /_matrix/media/r0/config`
@ -110,9 +111,12 @@ pub(crate) async fn get_content_route(
content_disposition, content_disposition,
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
}) })
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let remote_content_response = let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(remote_content_response) Ok(remote_content_response)
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
@ -130,21 +134,32 @@ pub(crate) async fn get_content_as_filename_route(
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type,
file,
..
}) = services().media.get(mxc.clone()).await? }) = services().media.get(mxc.clone()).await?
{ {
Ok(get_content_as_filename::v3::Response { Ok(get_content_as_filename::v3::Response {
file, file,
content_type, content_type,
content_disposition: Some(format!("inline; filename={}", body.filename)), content_disposition: Some(format!(
"inline; filename={}",
body.filename
)),
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
}) })
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let remote_content_response = let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(get_content_as_filename::v3::Response { Ok(get_content_as_filename::v3::Response {
content_disposition: Some(format!("inline: filename={}", body.filename)), content_disposition: Some(format!(
"inline: filename={}",
body.filename
)),
content_type: remote_content_response.content_type, content_type: remote_content_response.content_type,
file: remote_content_response.file, file: remote_content_response.file,
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
@ -165,17 +180,19 @@ pub(crate) async fn get_content_thumbnail_route(
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type,
file,
..
}) = services() }) = services()
.media .media
.get_thumbnail( .get_thumbnail(
mxc.clone(), mxc.clone(),
body.width body.width.try_into().map_err(|_| {
.try_into() Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, })?,
body.height body.height.try_into().map_err(|_| {
.try_into() Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, })?,
) )
.await? .await?
{ {
@ -184,7 +201,9 @@ pub(crate) async fn get_content_thumbnail_route(
content_type, content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
}) })
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let get_thumbnail_response = services() let get_thumbnail_response = services()
.sending .sending
.send_federation_request( .send_federation_request(

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,8 @@
use crate::{ use std::{
service::{pdu::PduBuilder, rooms::timeline::PduCount}, collections::{BTreeMap, HashSet},
services, utils, Error, Result, Ruma, sync::Arc,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -10,18 +11,21 @@ use ruma::{
events::{StateEventType, TimelineEventType}, events::{StateEventType, TimelineEventType},
uint, uint,
}; };
use std::{
collections::{BTreeMap, HashSet}, use crate::{
sync::Arc, service::{pdu::PduBuilder, rooms::timeline::PduCount},
services, utils, Error, Result, Ruma,
}; };
/// # `PUT /_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}`
/// ///
/// Send a message event into the room. /// Send a message event into the room.
/// ///
/// - Is a NOOP if the txn id was already used before and returns the same event id again /// - Is a NOOP if the txn id was already used before and returns the same event
/// id again
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
pub(crate) async fn send_message_event_route( pub(crate) async fn send_message_event_route(
body: Ruma<send_message_event::v3::Request>, body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> { ) -> Result<send_message_event::v3::Response> {
@ -50,29 +54,37 @@ pub(crate) async fn send_message_event_route(
} }
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = if let Some(response) = services().transaction_ids.existing_txnid(
services() sender_user,
.transaction_ids sender_device,
.existing_txnid(sender_user, sender_device, &body.txn_id)? &body.txn_id,
{ )? {
// The client might have sent a txnid of the /sendToDevice endpoint // The client might have sent a txnid of the /sendToDevice endpoint
// This txnid has no response associated with it // This txnid has no response associated with it
if response.is_empty() { if response.is_empty() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Tried to use txn id already used for an incompatible endpoint.", "Tried to use txn id already used for an incompatible \
endpoint.",
)); ));
} }
let event_id = utils::string_from_bytes(&response) let event_id = utils::string_from_bytes(&response)
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? .map_err(|_| {
Error::bad_database("Invalid txnid bytes in database.")
})?
.try_into() .try_into()
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; .map_err(|_| {
return Ok(send_message_event::v3::Response { event_id }); Error::bad_database("Invalid event id in txnid data.")
})?;
return Ok(send_message_event::v3::Response {
event_id,
});
} }
let mut unsigned = BTreeMap::new(); let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); unsigned
.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let event_id = services() let event_id = services()
.rooms .rooms
@ -81,7 +93,12 @@ pub(crate) async fn send_message_event_route(
PduBuilder { PduBuilder {
event_type: body.event_type.to_string().into(), event_type: body.event_type.to_string().into(),
content: serde_json::from_str(body.body.body.json().get()) content: serde_json::from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, .map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid JSON body.",
)
})?,
unsigned: Some(unsigned), unsigned: Some(unsigned),
state_key: None, state_key: None,
redacts: None, redacts: None,
@ -101,23 +118,23 @@ pub(crate) async fn send_message_event_route(
drop(state_lock); drop(state_lock);
Ok(send_message_event::v3::Response::new( Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
(*event_id).to_owned(),
))
} }
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
/// ///
/// Allows paginating through room history. /// Allows paginating through room history.
/// ///
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// - Only works if the user is joined (TODO: always allow, but only show events
/// where the user was
/// joined, depending on `history_visibility`) /// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn get_message_events_route( pub(crate) async fn get_message_events_route(
body: Ruma<get_message_events::v3::Request>, body: Ruma<get_message_events::v3::Request>,
) -> Result<get_message_events::v3::Response> { ) -> Result<get_message_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?, Some(from) => PduCount::try_from_string(&from)?,
@ -127,15 +144,17 @@ pub(crate) async fn get_message_events_route(
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
services() services()
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) .lazy_load_confirm_delivery(
sender_user,
sender_device,
&body.room_id,
from,
)
.await?; .await?;
let limit = body let limit = body
@ -162,7 +181,11 @@ pub(crate) async fn get_message_events_route(
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(
sender_user,
&body.room_id,
&pdu.event_id,
)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) .take_while(|&(k, _)| Some(k) != to)
@ -214,7 +237,11 @@ pub(crate) async fn get_message_events_route(
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(
sender_user,
&body.room_id,
&pdu.event_id,
)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) .take_while(|&(k, _)| Some(k) != to)
@ -254,11 +281,13 @@ pub(crate) async fn get_message_events_route(
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = services().rooms.state_accessor.room_state_get( if let Some(member_event) =
&body.room_id, services().rooms.state_accessor.room_state_get(
&StateEventType::RoomMember, &body.room_id,
ll_id.as_str(), &StateEventType::RoomMember,
)? { ll_id.as_str(),
)?
{
resp.state.push(member_event.to_state_event()); resp.state.push(member_event.to_state_event());
} }
} }

View file

@ -1,20 +1,25 @@
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; use std::sync::Arc;
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
error::ErrorKind, error::ErrorKind,
profile::{ profile::{
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, get_avatar_url, get_display_name, get_profile, set_avatar_url,
set_display_name,
}, },
}, },
federation::{self, query::get_profile_information::v1::ProfileField}, federation::{self, query::get_profile_information::v1::ProfileField},
}, },
events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType}, events::{
room::member::RoomMemberEventContent, StateEventType, TimelineEventType,
},
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use std::sync::Arc;
use tracing::warn; use tracing::warn;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// # `PUT /_matrix/client/r0/profile/{userId}/displayname`
/// ///
/// Updates the displayname. /// Updates the displayname.
@ -25,9 +30,7 @@ pub(crate) async fn set_displayname_route(
) -> Result<set_display_name::v3::Response> { ) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().users.set_displayname(sender_user, body.displayname.clone())?;
.users
.set_displayname(sender_user, body.displayname.clone())?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_rooms_joined: Vec<_> = services() let all_rooms_joined: Vec<_> = services()
@ -53,14 +56,18 @@ pub(crate) async fn set_displayname_route(
)? )?
.ok_or_else(|| { .ok_or_else(|| {
Error::bad_database( Error::bad_database(
"Tried to send displayname update for user not in the \ "Tried to send displayname update for \
room.", user not in the room.",
) )
})? })?
.content .content
.get(), .get(),
) )
.map_err(|_| Error::bad_database("Database contains invalid PDU."))? .map_err(|_| {
Error::bad_database(
"Database contains invalid PDU.",
)
})?
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -88,7 +95,12 @@ pub(crate) async fn set_displayname_route(
if let Err(error) = services() if let Err(error) = services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) .build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await .await
{ {
warn!(%error, "failed to add PDU"); warn!(%error, "failed to add PDU");
@ -138,13 +150,9 @@ pub(crate) async fn set_avatar_url_route(
) -> Result<set_avatar_url::v3::Response> { ) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().users.set_avatar_url(sender_user, body.avatar_url.clone())?;
.users
.set_avatar_url(sender_user, body.avatar_url.clone())?;
services() services().users.set_blurhash(sender_user, body.blurhash.clone())?;
.users
.set_blurhash(sender_user, body.blurhash.clone())?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_joined_rooms: Vec<_> = services() let all_joined_rooms: Vec<_> = services()
@ -170,14 +178,18 @@ pub(crate) async fn set_avatar_url_route(
)? )?
.ok_or_else(|| { .ok_or_else(|| {
Error::bad_database( Error::bad_database(
"Tried to send displayname update for user not in the \ "Tried to send displayname update for \
room.", user not in the room.",
) )
})? })?
.content .content
.get(), .get(),
) )
.map_err(|_| Error::bad_database("Database contains invalid PDU."))? .map_err(|_| {
Error::bad_database(
"Database contains invalid PDU.",
)
})?
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -205,7 +217,12 @@ pub(crate) async fn set_avatar_url_route(
if let Err(error) = services() if let Err(error) = services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) .build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await .await
{ {
warn!(%error, "failed to add PDU"); warn!(%error, "failed to add PDU");
@ -219,7 +236,8 @@ pub(crate) async fn set_avatar_url_route(
/// ///
/// Returns the `avatar_url` and `blurhash` of the user. /// Returns the `avatar_url` and `blurhash` of the user.
/// ///
/// - If user is on another server: Fetches `avatar_url` and `blurhash` over federation /// - If user is on another server: Fetches `avatar_url` and `blurhash` over
/// federation
pub(crate) async fn get_avatar_url_route( pub(crate) async fn get_avatar_url_route(
body: Ruma<get_avatar_url::v3::Request>, body: Ruma<get_avatar_url::v3::Request>,
) -> Result<get_avatar_url::v3::Response> { ) -> Result<get_avatar_url::v3::Response> {

View file

@ -1,17 +1,18 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
push::{ push::{
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions,
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions, get_pushrule_enabled, get_pushrules_all, set_pusher, set_pushrule,
set_pushrule_enabled, RuleScope, set_pushrule_actions, set_pushrule_enabled, RuleScope,
}, },
}, },
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
push::{AnyPushRuleRef, InsertPushRuleError, RemovePushRuleError}, push::{AnyPushRuleRef, InsertPushRuleError, RemovePushRuleError},
}; };
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/pushrules` /// # `GET /_matrix/client/r0/pushrules`
/// ///
/// Retrieves the push rules event for this user. /// Retrieves the push rules event for this user.
@ -71,12 +72,11 @@ pub(crate) async fn get_pushrule_route(
.map(Into::into); .map(Into::into);
if let Some(rule) = rule { if let Some(rule) = rule {
Ok(get_pushrule::v3::Response { rule }) Ok(get_pushrule::v3::Response {
rule,
})
} else { } else {
Err(Error::BadRequest( Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
ErrorKind::NotFound,
"Push rule not found.",
))
} }
} }
@ -109,7 +109,9 @@ pub(crate) async fn set_pushrule_route(
))?; ))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if let Err(error) = account_data.content.global.insert( if let Err(error) = account_data.content.global.insert(
body.rule.clone(), body.rule.clone(),
@ -119,16 +121,20 @@ pub(crate) async fn set_pushrule_route(
let err = match error { let err = match error {
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Rule IDs starting with a dot are reserved for server-default rules.", "Rule IDs starting with a dot are reserved for server-default \
rules.",
), ),
InsertPushRuleError::InvalidRuleId => Error::BadRequest( InsertPushRuleError::InvalidRuleId => Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Rule ID containing invalid characters.", "Rule ID containing invalid characters.",
), ),
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest( InsertPushRuleError::RelativeToServerDefaultRule => {
ErrorKind::InvalidParam, Error::BadRequest(
"Can't place a push rule relatively to a server-default rule.", ErrorKind::InvalidParam,
), "Can't place a push rule relatively to a server-default \
rule.",
)
}
InsertPushRuleError::UnknownRuleId => Error::BadRequest( InsertPushRuleError::UnknownRuleId => Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"The before or after rule could not be found.", "The before or after rule could not be found.",
@ -147,7 +153,8 @@ pub(crate) async fn set_pushrule_route(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
Ok(set_pushrule::v3::Response {}) Ok(set_pushrule::v3::Response {})
@ -193,7 +200,9 @@ pub(crate) async fn get_pushrule_actions_route(
"Push rule not found.", "Push rule not found.",
))?; ))?;
Ok(get_pushrule_actions::v3::Response { actions }) Ok(get_pushrule_actions::v3::Response {
actions,
})
} }
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
@ -224,7 +233,9 @@ pub(crate) async fn set_pushrule_actions_route(
))?; ))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if account_data if account_data
.content .content
@ -242,7 +253,8 @@ pub(crate) async fn set_pushrule_actions_route(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
Ok(set_pushrule_actions::v3::Response {}) Ok(set_pushrule_actions::v3::Response {})
@ -276,7 +288,9 @@ pub(crate) async fn get_pushrule_enabled_route(
))?; ))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
let global = account_data.content.global; let global = account_data.content.global;
let enabled = global let enabled = global
@ -287,7 +301,9 @@ pub(crate) async fn get_pushrule_enabled_route(
"Push rule not found.", "Push rule not found.",
))?; ))?;
Ok(get_pushrule_enabled::v3::Response { enabled }) Ok(get_pushrule_enabled::v3::Response {
enabled,
})
} }
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
@ -318,7 +334,9 @@ pub(crate) async fn set_pushrule_enabled_route(
))?; ))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if account_data if account_data
.content .content
@ -336,7 +354,8 @@ pub(crate) async fn set_pushrule_enabled_route(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
Ok(set_pushrule_enabled::v3::Response {}) Ok(set_pushrule_enabled::v3::Response {})
@ -370,12 +389,12 @@ pub(crate) async fn delete_pushrule_route(
))?; ))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if let Err(error) = account_data if let Err(error) =
.content account_data.content.global.remove(body.kind.clone(), &body.rule_id)
.global
.remove(body.kind.clone(), &body.rule_id)
{ {
let err = match error { let err = match error {
RemovePushRuleError::ServerDefault => Error::BadRequest( RemovePushRuleError::ServerDefault => Error::BadRequest(
@ -395,7 +414,8 @@ pub(crate) async fn delete_pushrule_route(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
Ok(delete_pushrule::v3::Response {}) Ok(delete_pushrule::v3::Response {})
@ -424,9 +444,7 @@ pub(crate) async fn set_pushers_route(
) -> Result<set_pusher::v3::Response> { ) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().pusher.set_pusher(sender_user, body.action.clone())?;
.pusher
.set_pusher(sender_user, body.action.clone())?;
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }

View file

@ -1,20 +1,27 @@
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma}; use std::collections::BTreeMap;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, api::client::{
error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt,
},
events::{ events::{
receipt::{ReceiptThread, ReceiptType}, receipt::{ReceiptThread, ReceiptType},
RoomAccountDataEventType, RoomAccountDataEventType,
}, },
MilliSecondsSinceUnixEpoch, MilliSecondsSinceUnixEpoch,
}; };
use std::collections::BTreeMap;
use crate::{
service::rooms::timeline::PduCount, services, Error, Result, Ruma,
};
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
/// ///
/// Sets different types of read markers. /// Sets different types of read markers.
/// ///
/// - Updates fully-read account data event to `fully_read` /// - Updates fully-read account data event to `fully_read`
/// - If `read_receipt` is set: Update private marker and public read receipt EDU /// - If `read_receipt` is set: Update private marker and public read receipt
/// EDU
pub(crate) async fn set_read_marker_route( pub(crate) async fn set_read_marker_route(
body: Ruma<set_read_marker::v3::Request>, body: Ruma<set_read_marker::v3::Request>,
) -> Result<set_read_marker::v3::Response> { ) -> Result<set_read_marker::v3::Response> {
@ -30,7 +37,8 @@ pub(crate) async fn set_read_marker_route(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event)
.expect("to json value always works"),
)?; )?;
} }
@ -42,14 +50,9 @@ pub(crate) async fn set_read_marker_route(
} }
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
let count = services() let count = services().rooms.timeline.get_pdu_count(event)?.ok_or(
.rooms Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."),
.timeline )?;
.get_pdu_count(event)?
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Event does not exist.",
))?;
let count = match count { let count = match count {
PduCount::Backfilled(_) => { PduCount::Backfilled(_) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -59,11 +62,11 @@ pub(crate) async fn set_read_marker_route(
} }
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services() services().rooms.edus.read_receipt.private_read_set(
.rooms &body.room_id,
.edus sender_user,
.read_receipt count,
.private_read_set(&body.room_id, sender_user, count)?; )?;
} }
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
@ -86,7 +89,9 @@ pub(crate) async fn set_read_marker_route(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(
receipt_content,
),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )?;
@ -105,7 +110,8 @@ pub(crate) async fn create_receipt_route(
if matches!( if matches!(
&body.receipt_type, &body.receipt_type,
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate create_receipt::v3::ReceiptType::Read
| create_receipt::v3::ReceiptType::ReadPrivate
) { ) {
services() services()
.rooms .rooms
@ -124,7 +130,8 @@ pub(crate) async fn create_receipt_route(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event)
.expect("to json value always works"),
)?; )?;
} }
create_receipt::v3::ReceiptType::Read => { create_receipt::v3::ReceiptType::Read => {
@ -146,7 +153,9 @@ pub(crate) async fn create_receipt_route(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(
receipt_content,
),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )?;

View file

@ -1,13 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
use ruma::{ use ruma::{
api::client::redact::redact_event, api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
/// ///
/// Tries to send a redaction event into the room. /// Tries to send a redaction event into the room.
@ -54,5 +54,7 @@ pub(crate) async fn redact_event_route(
drop(state_lock); drop(state_lock);
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(redact_event::v3::Response { event_id }) Ok(redact_event::v3::Response {
event_id,
})
} }

View file

@ -23,10 +23,7 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body
@ -36,27 +33,22 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
.try_into() .try_into()
.expect("0-100 should fit in usize"); .expect("0-100 should fit in usize");
let res = services() let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, Some(&body.event_type),
&body.room_id, Some(&body.rel_type),
&body.event_id, from,
Some(&body.event_type), to,
Some(&body.rel_type), limit,
from, )?;
to,
limit,
)?;
Ok( Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk,
chunk: res.chunk, next_batch: res.next_batch,
next_batch: res.next_batch, prev_batch: res.prev_batch,
prev_batch: res.prev_batch, })
},
)
} }
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
@ -74,10 +66,7 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body
@ -87,19 +76,16 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
.try_into() .try_into()
.expect("0-100 should fit in usize"); .expect("0-100 should fit in usize");
let res = services() let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, None,
&body.room_id, Some(&body.rel_type),
&body.event_id, from,
None, to,
Some(&body.rel_type), limit,
from, )?;
to,
limit,
)?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
@ -123,10 +109,7 @@ pub(crate) async fn get_relating_events_route(
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body
@ -136,17 +119,14 @@ pub(crate) async fn get_relating_events_route(
.try_into() .try_into()
.expect("0-100 should fit in usize"); .expect("0-100 should fit in usize");
services() services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, None,
&body.room_id, None,
&body.event_id, from,
None, to,
None, limit,
from, )
to,
limit,
)
} }

View file

@ -1,14 +1,14 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, room::report_content}, api::client::{error::ErrorKind, room::report_content},
events::room::message, events::room::message,
int, int,
}; };
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}` /// # `POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}`
/// ///
/// Reports an inappropriate event to homeserver admins /// Reports an inappropriate event to homeserver admins
///
pub(crate) async fn report_event_route( pub(crate) async fn report_event_route(
body: Ruma<report_content::v3::Request>, body: Ruma<report_content::v3::Request>,
) -> Result<report_content::v3::Response> { ) -> Result<report_content::v3::Response> {

View file

@ -1,6 +1,5 @@
use crate::{ use std::{cmp::max, collections::BTreeMap, sync::Arc};
api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma,
};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -11,7 +10,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent, canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent, create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent}, guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent}, member::{MembershipState, RoomMemberEventContent},
name::RoomNameEventContent, name::RoomNameEventContent,
@ -26,9 +27,13 @@ use ruma::{
CanonicalJsonObject, OwnedRoomAliasId, RoomAliasId, RoomId, RoomVersionId, CanonicalJsonObject, OwnedRoomAliasId, RoomAliasId, RoomId, RoomVersionId,
}; };
use serde_json::{json, value::to_raw_value}; use serde_json::{json, value::to_raw_value};
use std::{cmp::max, collections::BTreeMap, sync::Arc};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{
api::client_server::invite_helper, service::pdu::PduBuilder, services,
Error, Result, Ruma,
};
/// # `POST /_matrix/client/r0/createRoom` /// # `POST /_matrix/client/r0/createRoom`
/// ///
/// Creates a new room. /// Creates a new room.
@ -79,32 +84,27 @@ pub(crate) async fn create_room_route(
} }
let alias: Option<OwnedRoomAliasId> = let alias: Option<OwnedRoomAliasId> =
body.room_alias_name body.room_alias_name.as_ref().map_or(Ok(None), |localpart| {
.as_ref() // TODO: Check for invalid characters and maximum length
.map_or(Ok(None), |localpart| { let alias = RoomAliasId::parse(format!(
// TODO: Check for invalid characters and maximum length "#{}:{}",
let alias = RoomAliasId::parse(format!( localpart,
"#{}:{}", services().globals.server_name()
localpart, ))
services().globals.server_name() .map_err(|_| {
)) Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.")
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
if services()
.rooms
.alias
.resolve_local_alias(&alias)?
.is_some()
{
Err(Error::BadRequest(
ErrorKind::RoomInUse,
"Room alias already exists.",
))
} else {
Ok(Some(alias))
}
})?; })?;
if services().rooms.alias.resolve_local_alias(&alias)?.is_some() {
Err(Error::BadRequest(
ErrorKind::RoomInUse,
"Room alias already exists.",
))
} else {
Ok(Some(alias))
}
})?;
if let Some(alias) = &alias { if let Some(alias) = &alias {
if let Some(info) = &body.appservice_info { if let Some(info) = &body.appservice_info {
if !info.aliases.is_match(alias.as_str()) { if !info.aliases.is_match(alias.as_str()) {
@ -159,7 +159,10 @@ pub(crate) async fn create_room_route(
content.insert( content.insert(
"creator".into(), "creator".into(),
json!(&sender_user).try_into().map_err(|_| { json!(&sender_user).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?, })?,
); );
} }
@ -171,7 +174,10 @@ pub(crate) async fn create_room_route(
content.insert( content.insert(
"room_version".into(), "room_version".into(),
json!(room_version.as_str()).try_into().map_err(|_| { json!(room_version.as_str()).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?, })?,
); );
content content
@ -187,20 +193,30 @@ pub(crate) async fn create_room_route(
| RoomVersionId::V7 | RoomVersionId::V7
| RoomVersionId::V8 | RoomVersionId::V8
| RoomVersionId::V9 | RoomVersionId::V9
| RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), | RoomVersionId::V10 => {
RoomCreateEventContent::new_v1(sender_user.clone())
}
RoomVersionId::V11 => RoomCreateEventContent::new_v11(), RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => unreachable!("Validity of room version already checked"), _ => unreachable!("Validity of room version already checked"),
}; };
let mut content = serde_json::from_str::<CanonicalJsonObject>( let mut content = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&content) to_raw_value(&content)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))? .map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?
.get(), .get(),
) )
.unwrap(); .unwrap();
content.insert( content.insert(
"room_version".into(), "room_version".into(),
json!(room_version.as_str()).try_into().map_err(|_| { json!(room_version.as_str()).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?, })?,
); );
content content
@ -209,9 +225,7 @@ pub(crate) async fn create_room_route(
// Validate creation content // Validate creation content
let de_result = serde_json::from_str::<CanonicalJsonObject>( let de_result = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&content) to_raw_value(&content).expect("Invalid creation content").get(),
.expect("Invalid creation content")
.get(),
); );
if de_result.is_err() { if de_result.is_err() {
@ -228,7 +242,8 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomCreate, event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"), content: to_raw_value(&content)
.expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
redacts: None, redacts: None,
@ -285,17 +300,24 @@ pub(crate) async fn create_room_route(
} }
} }
let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { let mut power_levels_content =
users, serde_json::to_value(RoomPowerLevelsEventContent {
..Default::default() users,
}) ..Default::default()
.expect("event is valid, we just created it"); })
.expect("event is valid, we just created it");
if let Some(power_level_content_override) = &body.power_level_content_override { if let Some(power_level_content_override) =
let json: JsonObject = serde_json::from_str(power_level_content_override.json().get()) &body.power_level_content_override
.map_err(|_| { {
Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.") let json: JsonObject =
})?; serde_json::from_str(power_level_content_override.json().get())
.map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid power_level_content_override.",
)
})?;
for (key, value) in json { for (key, value) in json {
power_levels_content[key] = value; power_levels_content[key] = value;
@ -353,11 +375,13 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomJoinRules, event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { content: to_raw_value(&RoomJoinRulesEventContent::new(
RoomPreset::PublicChat => JoinRule::Public, match preset {
// according to spec "invite" is the default RoomPreset::PublicChat => JoinRule::Public,
_ => JoinRule::Invite, // according to spec "invite" is the default
})) _ => JoinRule::Invite,
},
))
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
@ -397,10 +421,12 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomGuestAccess, event_type: TimelineEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { content: to_raw_value(&RoomGuestAccessEventContent::new(
RoomPreset::PublicChat => GuestAccess::Forbidden, match preset {
_ => GuestAccess::CanJoin, RoomPreset::PublicChat => GuestAccess::Forbidden,
})) _ => GuestAccess::CanJoin,
},
))
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
@ -414,10 +440,14 @@ pub(crate) async fn create_room_route(
// 6. Events listed in initial_state // 6. Events listed in initial_state
for event in &body.initial_state { for event in &body.initial_state {
let mut pdu_builder = event.deserialize_as::<PduBuilder>().map_err(|e| { let mut pdu_builder =
warn!("Invalid initial state event: {:?}", e); event.deserialize_as::<PduBuilder>().map_err(|e| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.") warn!("Invalid initial state event: {:?}", e);
})?; Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid initial state event.",
)
})?;
// Implicit state key defaults to "" // Implicit state key defaults to ""
pdu_builder.state_key.get_or_insert_with(String::new); pdu_builder.state_key.get_or_insert_with(String::new);
@ -432,7 +462,12 @@ pub(crate) async fn create_room_route(
services() services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) .build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await?; .await?;
} }
@ -444,8 +479,10 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomName, event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(name.clone())) content: to_raw_value(&RoomNameEventContent::new(
.expect("event is valid, we just created it"), name.clone(),
))
.expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
redacts: None, redacts: None,
@ -483,7 +520,8 @@ pub(crate) async fn create_room_route(
drop(state_lock); drop(state_lock);
for user_id in &body.invite { for user_id in &body.invite {
if let Err(error) = if let Err(error) =
invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await invite_helper(sender_user, user_id, &room_id, None, body.is_direct)
.await
{ {
warn!(%error, "invite helper failed"); warn!(%error, "invite helper failed");
}; };
@ -507,20 +545,19 @@ pub(crate) async fn create_room_route(
/// ///
/// Gets a single event. /// Gets a single event.
/// ///
/// - You have to currently be joined to the room (TODO: Respect history visibility) /// - You have to currently be joined to the room (TODO: Respect history
/// visibility)
pub(crate) async fn get_room_event_route( pub(crate) async fn get_room_event_route(
body: Ruma<get_room_event::v3::Request>, body: Ruma<get_room_event::v3::Request>,
) -> Result<get_room_event::v3::Response> { ) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(
.rooms || {
.timeline
.get_pdu(&body.event_id)?
.ok_or_else(|| {
warn!("Event not found, event ID: {:?}", &body.event_id); warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.") Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?; },
)?;
if !services().rooms.state_accessor.user_can_see_event( if !services().rooms.state_accessor.user_can_see_event(
sender_user, sender_user,
@ -545,17 +582,14 @@ pub(crate) async fn get_room_event_route(
/// ///
/// Lists all aliases of the room. /// Lists all aliases of the room.
/// ///
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if `history_visibility` is world readable /// - Only users joined to the room are allowed to call this TODO: Allow any
/// user to call it if `history_visibility` is world readable
pub(crate) async fn get_room_aliases_route( pub(crate) async fn get_room_aliases_route(
body: Ruma<aliases::v3::Request>, body: Ruma<aliases::v3::Request>,
) -> Result<aliases::v3::Response> { ) -> Result<aliases::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -588,10 +622,7 @@ pub(crate) async fn upgrade_room_route(
) -> Result<upgrade_room::v3::Response> { ) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().globals.supported_room_versions().contains(&body.new_version)
.globals
.supported_room_versions()
.contains(&body.new_version)
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnsupportedRoomVersion, ErrorKind::UnsupportedRoomVersion,
@ -601,10 +632,7 @@ pub(crate) async fn upgrade_room_route(
// Create a replacement room // Create a replacement room
let replacement_room = RoomId::new(services().globals.server_name()); let replacement_room = RoomId::new(services().globals.server_name());
services() services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
.rooms
.short
.get_or_create_shortroomid(&replacement_room)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
services() services()
@ -617,8 +645,9 @@ pub(crate) async fn upgrade_room_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Send a m.room.tombstone event to the old room to indicate that it is not
// Fail if the sender does not have the required permissions // intended to be used any further Fail if the sender does not have the
// required permissions
let tombstone_event_id = services() let tombstone_event_id = services()
.rooms .rooms
.timeline .timeline
@ -659,7 +688,9 @@ pub(crate) async fn upgrade_room_route(
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .ok_or_else(|| {
Error::bad_database("Found room without m.room.create event.")
})?
.content .content
.get(), .get(),
) )
@ -671,7 +702,8 @@ pub(crate) async fn upgrade_room_route(
(*tombstone_event_id).to_owned(), (*tombstone_event_id).to_owned(),
)); ));
// Send a m.room.create event containing a predecessor field and the applicable room_version // Send a m.room.create event containing a predecessor field and the
// applicable room_version
match body.new_version { match body.new_version {
RoomVersionId::V1 RoomVersionId::V1
| RoomVersionId::V2 | RoomVersionId::V2
@ -686,7 +718,10 @@ pub(crate) async fn upgrade_room_route(
create_event_content.insert( create_event_content.insert(
"creator".into(), "creator".into(),
json!(&sender_user).try_into().map_err(|_| { json!(&sender_user).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") Error::BadRequest(
ErrorKind::BadJson,
"Error forming creation event",
)
})?, })?,
); );
} }
@ -698,15 +733,21 @@ pub(crate) async fn upgrade_room_route(
} }
create_event_content.insert( create_event_content.insert(
"room_version".into(), "room_version".into(),
json!(&body.new_version) json!(&body.new_version).try_into().map_err(|_| {
.try_into() Error::BadRequest(
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, ErrorKind::BadJson,
"Error forming creation event",
)
})?,
); );
create_event_content.insert( create_event_content.insert(
"predecessor".into(), "predecessor".into(),
json!(predecessor) json!(predecessor).try_into().map_err(|_| {
.try_into() Error::BadRequest(
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, ErrorKind::BadJson,
"Error forming creation event",
)
})?,
); );
// Validate creation event content // Validate creation event content
@ -784,16 +825,15 @@ pub(crate) async fn upgrade_room_route(
// Replicate transferable state events to the new room // Replicate transferable state events to the new room
for event_type in transferable_state_events { for event_type in transferable_state_events {
let event_content = let event_content = match services()
match services() .rooms
.rooms .state_accessor
.state_accessor .room_state_get(&body.room_id, &event_type, "")?
.room_state_get(&body.room_id, &event_type, "")? {
{ Some(v) => v.content.clone(),
Some(v) => v.content.clone(), // Skipping missing events.
// Skipping missing events. None => continue,
None => continue, };
};
services() services()
.rooms .rooms
@ -820,30 +860,39 @@ pub(crate) async fn upgrade_room_route(
.local_aliases_for_room(&body.room_id) .local_aliases_for_room(&body.room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
{ {
services() services().rooms.alias.set_alias(&alias, &replacement_room)?;
.rooms
.alias
.set_alias(&alias, &replacement_room)?;
} }
// Get the old room power levels // Get the old room power levels
let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( let mut power_levels_event_content: RoomPowerLevelsEventContent =
services() serde_json::from_str(
.rooms services()
.state_accessor .rooms
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .state_accessor
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .room_state_get(
.content &body.room_id,
.get(), &StateEventType::RoomPowerLevels,
) "",
.map_err(|_| Error::bad_database("Invalid room event in database."))?; )?
.ok_or_else(|| {
Error::bad_database(
"Found room without m.room.create event.",
)
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
// Setting events_default and invite to the greater of 50 and users_default + 1 // Setting events_default and invite to the greater of 50 and users_default
let new_level = max(int!(50), power_levels_event_content.users_default + int!(1)); // + 1
let new_level =
max(int!(50), power_levels_event_content.users_default + int!(1));
power_levels_event_content.events_default = new_level; power_levels_event_content.events_default = new_level;
power_levels_event_content.invite = new_level; power_levels_event_content.invite = new_level;
// Modify the power levels in the old room to prevent sending of events and inviting new users // Modify the power levels in the old room to prevent sending of events and
// inviting new users
let _ = services() let _ = services()
.rooms .rooms
.timeline .timeline
@ -865,5 +914,7 @@ pub(crate) async fn upgrade_room_route(
drop(state_lock); drop(state_lock);
// Return the replacement room id // Return the replacement room id
Ok(upgrade_room::v3::Response { replacement_room }) Ok(upgrade_room::v3::Response {
replacement_room,
})
} }

View file

@ -1,22 +1,27 @@
use crate::{services, Error, Result, Ruma}; use std::collections::BTreeMap;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
search::search_events::{ search::search_events::{
self, self,
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, v3::{
EventContextResult, ResultCategories, ResultRoomEvents,
SearchResult,
},
}, },
}, },
uint, uint,
}; };
use std::collections::BTreeMap; use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/search` /// # `POST /_matrix/client/r0/search`
/// ///
/// Searches rooms for messages. /// Searches rooms for messages.
/// ///
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility) /// - Only works if the user is currently joined to the room (TODO: Respect
/// history visibility)
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn search_events_route( pub(crate) async fn search_events_route(
body: Ruma<search_events::v3::Request>, body: Ruma<search_events::v3::Request>,
@ -46,11 +51,7 @@ pub(crate) async fn search_events_route(
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in room_ids { for room_id in room_ids {
if !services() if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
.rooms
.state_cache
.is_joined(sender_user, &room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -102,7 +103,11 @@ pub(crate) async fn search_events_route(
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .user_can_see_event(
sender_user,
&pdu.room_id,
&pdu.event_id,
)
.unwrap_or(false) .unwrap_or(false)
}) })
.map(|pdu| pdu.to_room_event()) .map(|pdu| pdu.to_room_event())

View file

@ -1,5 +1,3 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -17,6 +15,9 @@ use ruma::{
use serde::Deserialize; use serde::Deserialize;
use tracing::{info, warn}; use tracing::{info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Claims { struct Claims {
sub: String, sub: String,
@ -24,30 +25,36 @@ struct Claims {
/// # `GET /_matrix/client/r0/login` /// # `GET /_matrix/client/r0/login`
/// ///
/// Get the supported login types of this server. One of these should be used as the `type` field /// Get the supported login types of this server. One of these should be used as
/// when logging in. /// the `type` field when logging in.
pub(crate) async fn get_login_types_route( pub(crate) async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>, _body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> { ) -> Result<get_login_types::v3::Response> {
Ok(get_login_types::v3::Response::new(vec![ Ok(get_login_types::v3::Response::new(vec![
get_login_types::v3::LoginType::Password(PasswordLoginType::default()), get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()), get_login_types::v3::LoginType::ApplicationService(
ApplicationServiceLoginType::default(),
),
])) ]))
} }
/// # `POST /_matrix/client/r0/login` /// # `POST /_matrix/client/r0/login`
/// ///
/// Authenticates the user and returns an access token it can use in subsequent requests. /// Authenticates the user and returns an access token it can use in subsequent
/// requests.
/// ///
/// - The user needs to authenticate using their password (or if enabled using a json web token) /// - The user needs to authenticate using their password (or if enabled using a
/// json web token)
/// - If `device_id` is known: invalidates old access token of that device /// - If `device_id` is known: invalidates old access token of that device
/// - If `device_id` is unknown: creates a new device /// - If `device_id` is unknown: creates a new device
/// - Returns access token that is associated with the user and device /// - Returns access token that is associated with the user and device
/// ///
/// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to see /// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to
/// supported login types. /// see supported login types.
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> { pub(crate) async fn login_route(
body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
// To allow deprecated login methods // To allow deprecated login methods
#![allow(deprecated)] #![allow(deprecated)]
// Validate login method // Validate login method
@ -59,18 +66,29 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
user, user,
.. ..
}) => { }) => {
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { let user_id =
UserId::parse_with_server_name( if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) =
user_id.to_lowercase(), identifier
services().globals.server_name(), {
) UserId::parse_with_server_name(
} else if let Some(user) = user { user_id.to_lowercase(),
UserId::parse(user) services().globals.server_name(),
} else { )
warn!("Bad login type: {:?}", &body.login_info); } else if let Some(user) = user {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); UserId::parse(user)
} } else {
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Bad login type.",
));
}
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await { if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -79,13 +97,12 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
)); ));
} }
let hash = services() let hash = services().users.password_hash(&user_id)?.ok_or(
.users Error::BadRequest(
.password_hash(&user_id)?
.ok_or(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Wrong username or password.", "Wrong username or password.",
))?; ),
)?;
if hash.is_empty() { if hash.is_empty() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -94,7 +111,9 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
)); ));
} }
let hash_matches = argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes())
.unwrap_or(false);
if !hash_matches { if !hash_matches {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -105,20 +124,34 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
user_id user_id
} }
login::v3::LoginInfo::Token(login::v3::Token { token }) => { login::v3::LoginInfo::Token(login::v3::Token {
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { token,
}) => {
if let Some(jwt_decoding_key) =
services().globals.jwt_decoding_key()
{
let token = jsonwebtoken::decode::<Claims>( let token = jsonwebtoken::decode::<Claims>(
token, token,
jwt_decoding_key, jwt_decoding_key,
&jsonwebtoken::Validation::default(), &jsonwebtoken::Validation::default(),
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; .map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Token is invalid.",
)
})?;
let username = token.claims.sub.to_lowercase(); let username = token.claims.sub.to_lowercase();
let user_id = let user_id = UserId::parse_with_server_name(
UserId::parse_with_server_name(username, services().globals.server_name()) username,
.map_err(|_| { services().globals.server_name(),
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") )
})?; .map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await { if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -131,26 +164,40 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
} else { } else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Unknown, ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).", "Token login is not supported (server has no jwt decoding \
key).",
)); ));
} }
} }
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { login::v3::LoginInfo::ApplicationService(
identifier, login::v3::ApplicationService {
user, identifier,
}) => { user,
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { },
UserId::parse_with_server_name( ) => {
user_id.to_lowercase(), let user_id =
services().globals.server_name(), if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) =
) identifier
} else if let Some(user) = user { {
UserId::parse(user) UserId::parse_with_server_name(
} else { user_id.to_lowercase(),
warn!("Bad login type: {:?}", &body.login_info); services().globals.server_name(),
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); )
} } else if let Some(user) = user {
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Bad login type.",
));
}
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if let Some(info) = &body.appservice_info { if let Some(info) = &body.appservice_info {
if !info.is_user_match(&user_id) { if !info.is_user_match(&user_id) {
@ -225,12 +272,16 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
/// Log out the current device. /// Log out the current device.
/// ///
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> { pub(crate) async fn logout_route(
body: Ruma<logout::v3::Request>,
) -> Result<logout::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
if let Some(info) = &body.appservice_info { if let Some(info) = &body.appservice_info {
if !info.is_user_match(sender_user) { if !info.is_user_match(sender_user) {
@ -251,12 +302,13 @@ pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logo
/// Log out all devices of this user. /// Log out all devices of this user.
/// ///
/// - Invalidates all access tokens /// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes all device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets all to-device events /// - Forgets all to-device events
/// - Triggers device list updates /// - Triggers device list updates
/// ///
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](logout_route) /// Note: This is equivalent to calling [`GET
/// from each device of this user. /// /_matrix/client/r0/logout`](logout_route) from each device of this user.
pub(crate) async fn logout_all_route( pub(crate) async fn logout_all_route(
body: Ruma<logout_all::v3::Request>, body: Ruma<logout_all::v3::Request>,
) -> Result<logout_all::v3::Response> { ) -> Result<logout_all::v3::Response> {

View file

@ -1,19 +1,18 @@
use crate::{services, Result, Ruma};
use ruma::{api::client::space::get_hierarchy, uint}; use ruma::{api::client::space::get_hierarchy, uint};
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`
/// ///
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space. /// Paginates over the space tree in a depth-first manner to locate child rooms
/// of a given space.
pub(crate) async fn get_hierarchy_route( pub(crate) async fn get_hierarchy_route(
body: Ruma<get_hierarchy::v1::Request>, body: Ruma<get_hierarchy::v1::Request>,
) -> Result<get_hierarchy::v1::Response> { ) -> Result<get_hierarchy::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let skip = body let skip =
.from body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
.as_ref()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
let limit = body let limit = body
.limit .limit
@ -23,8 +22,10 @@ pub(crate) async fn get_hierarchy_route(
.expect("0-100 should fit in usize"); .expect("0-100 should fit in usize");
// Plus one to skip the space room itself // Plus one to skip the space room itself
let max_depth = usize::try_from(body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3))) let max_depth = usize::try_from(
.expect("0-10 should fit in usize") body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3)),
)
.expect("0-10 should fit in usize")
+ 1; + 1;
services() services()

View file

@ -1,25 +1,30 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
state::{get_state_events, get_state_events_for_key, send_state_event}, state::{get_state_events, get_state_events_for_key, send_state_event},
}, },
events::{ events::{
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType, room::canonical_alias::RoomCanonicalAliasEventContent,
AnyStateEventContent, StateEventType,
}, },
serde::Raw, serde::Raw,
EventId, RoomId, UserId, EventId, RoomId, UserId,
}; };
use tracing::log::warn; use tracing::log::warn;
use crate::{
service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse,
};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
/// ///
/// Sends a state event into the room. /// Sends a state event into the room.
/// ///
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect /// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_key_route( pub(crate) async fn send_state_event_for_key_route(
body: Ruma<send_state_event::v3::Request>, body: Ruma<send_state_event::v3::Request>,
@ -37,7 +42,9 @@ pub(crate) async fn send_state_event_for_key_route(
.await?; .await?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }) Ok(send_state_event::v3::Response {
event_id,
})
} }
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
@ -45,7 +52,8 @@ pub(crate) async fn send_state_event_for_key_route(
/// Sends a state event into the room. /// Sends a state event into the room.
/// ///
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect /// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_empty_key_route( pub(crate) async fn send_state_event_for_empty_key_route(
body: Ruma<send_state_event::v3::Request>, body: Ruma<send_state_event::v3::Request>,
@ -53,7 +61,9 @@ pub(crate) async fn send_state_event_for_empty_key_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Forbid m.room.encryption if encryption is disabled // Forbid m.room.encryption if encryption is disabled
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { if body.event_type == StateEventType::RoomEncryption
&& !services().globals.allow_encryption()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Encryption has been disabled", "Encryption has been disabled",
@ -70,14 +80,18 @@ pub(crate) async fn send_state_event_for_empty_key_route(
.await?; .await?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }.into()) Ok(send_state_event::v3::Response {
event_id,
}
.into())
} }
/// # `GET /_matrix/client/r0/rooms/{roomid}/state` /// # `GET /_matrix/client/r0/rooms/{roomid}/state`
/// ///
/// Get all state events for a room. /// Get all state events for a room.
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_route( pub(crate) async fn get_state_events_route(
body: Ruma<get_state_events::v3::Request>, body: Ruma<get_state_events::v3::Request>,
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
@ -110,7 +124,8 @@ pub(crate) async fn get_state_events_route(
/// ///
/// Get single state event of a room. /// Get single state event of a room.
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_for_key_route( pub(crate) async fn get_state_events_for_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
@ -140,8 +155,9 @@ pub(crate) async fn get_state_events_for_key_route(
})?; })?;
Ok(get_state_events_for_key::v3::Response { Ok(get_state_events_for_key::v3::Response {
content: serde_json::from_str(event.content.get()) content: serde_json::from_str(event.content.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid event content in database"))?, Error::bad_database("Invalid event content in database")
})?,
}) })
} }
@ -149,7 +165,8 @@ pub(crate) async fn get_state_events_for_key_route(
/// ///
/// Get single state event of a room. /// Get single state event of a room.
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_for_empty_key_route( pub(crate) async fn get_state_events_for_empty_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
@ -179,8 +196,9 @@ pub(crate) async fn get_state_events_for_empty_key_route(
})?; })?;
Ok(get_state_events_for_key::v3::Response { Ok(get_state_events_for_key::v3::Response {
content: serde_json::from_str(event.content.get()) content: serde_json::from_str(event.content.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid event content in database"))?, Error::bad_database("Invalid event content in database")
})?,
} }
.into()) .into())
} }
@ -194,10 +212,11 @@ async fn send_state_event_for_key_helper(
) -> Result<Arc<EventId>> { ) -> Result<Arc<EventId>> {
let sender_user = sender; let sender_user = sender;
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it // TODO: Review this check, error if event is unparsable, use event type,
// previously existed // allow alias if it previously existed
if let Ok(canonical_alias) = if let Ok(canonical_alias) = serde_json::from_str::<
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) RoomCanonicalAliasEventContent,
>(json.json().get())
{ {
let mut aliases = canonical_alias.alt_aliases.clone(); let mut aliases = canonical_alias.alt_aliases.clone();
@ -216,8 +235,8 @@ async fn send_state_event_for_key_helper(
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You are only allowed to send canonical_alias \ "You are only allowed to send canonical_alias events when \
events when it's aliases already exists", it's aliases already exists",
)); ));
} }
} }
@ -240,7 +259,8 @@ async fn send_state_event_for_key_helper(
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: event_type.to_string().into(), event_type: event_type.to_string().into(),
content: serde_json::from_str(json.json().get()).expect("content is valid json"), content: serde_json::from_str(json.json().get())
.expect("content is valid json"),
unsigned: None, unsigned: None,
state_key: Some(state_key), state_key: Some(state_key),
redacts: None, redacts: None,

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,5 @@
use crate::{services, Error, Result, Ruma}; use std::collections::BTreeMap;
use ruma::{ use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags}, api::client::tag::{create_tag, delete_tag, get_tags},
events::{ events::{
@ -6,7 +7,8 @@ use ruma::{
RoomAccountDataEventType, RoomAccountDataEventType,
}, },
}; };
use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
/// ///
@ -33,8 +35,9 @@ pub(crate) async fn update_tag_route(
}) })
}, },
|e| { |e| {
serde_json::from_str(e.get()) serde_json::from_str(e.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Error::bad_database("Invalid account data event in db.")
})
}, },
)?; )?;
@ -78,8 +81,9 @@ pub(crate) async fn delete_tag_route(
}) })
}, },
|e| { |e| {
serde_json::from_str(e.get()) serde_json::from_str(e.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Error::bad_database("Invalid account data event in db.")
})
}, },
)?; )?;
@ -120,8 +124,9 @@ pub(crate) async fn get_tags_route(
}) })
}, },
|e| { |e| {
serde_json::from_str(e.get()) serde_json::from_str(e.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Error::bad_database("Invalid account data event in db.")
})
}, },
)?; )?;

View file

@ -1,7 +1,8 @@
use crate::{Result, Ruma}; use std::collections::BTreeMap;
use ruma::api::client::thirdparty::get_protocols; use ruma::api::client::thirdparty::get_protocols;
use std::collections::BTreeMap; use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/thirdparty/protocols` /// # `GET /_matrix/client/r0/thirdparty/protocols`
/// ///

View file

@ -9,11 +9,8 @@ pub(crate) async fn get_threads_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit =
.limit body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);
.and_then(|l| l.try_into().ok())
.unwrap_or(10)
.min(100);
let from = if let Some(from) = &body.from { let from = if let Some(from) = &body.from {
from.parse() from.parse()

View file

@ -1,6 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::{ api::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
@ -9,6 +8,8 @@ use ruma::{
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
}; };
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
/// ///
/// Send a to-device event to a set of client devices. /// Send a to-device event to a set of client devices.
@ -29,7 +30,8 @@ pub(crate) async fn send_event_to_device_route(
for (target_user_id, map) in &body.messages { for (target_user_id, map) in &body.messages {
for (target_device_id_maybe, event) in map { for (target_device_id_maybe, event) in map {
if target_user_id.server_name() != services().globals.server_name() { if target_user_id.server_name() != services().globals.server_name()
{
let mut map = BTreeMap::new(); let mut map = BTreeMap::new();
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
@ -38,14 +40,16 @@ pub(crate) async fn send_event_to_device_route(
services().sending.send_reliable_edu( services().sending.send_reliable_edu(
target_user_id.server_name(), target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( serde_json::to_vec(
DirectDeviceContent { &federation::transactions::edu::Edu::DirectToDevice(
sender: sender_user.clone(), DirectDeviceContent {
ev_type: body.event_type.clone(), sender: sender_user.clone(),
message_id: count.to_string().into(), ev_type: body.event_type.clone(),
messages, message_id: count.to_string().into(),
}, messages,
)) },
),
)
.expect("DirectToDevice EDU can be serialized"), .expect("DirectToDevice EDU can be serialized"),
count, count,
)?; )?;
@ -61,20 +65,28 @@ pub(crate) async fn send_event_to_device_route(
target_device_id, target_device_id,
&body.event_type.to_string(), &body.event_type.to_string(),
event.deserialize_as().map_err(|_| { event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
})?, })?,
)?; )?;
} }
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) { for target_device_id in
services().users.all_device_ids(target_user_id)
{
services().users.add_to_device_event( services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id?, &target_device_id?,
&body.event_type.to_string(), &body.event_type.to_string(),
event.deserialize_as().map_err(|_| { event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
})?, })?,
)?; )?;
} }
@ -84,9 +96,12 @@ pub(crate) async fn send_event_to_device_route(
} }
// Save transaction id with empty data // Save transaction id with empty data
services() services().transaction_ids.add_txnid(
.transaction_ids sender_user,
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; sender_device,
&body.txn_id,
&[],
)?;
Ok(send_event_to_device::v3::Response {}) Ok(send_event_to_device::v3::Response {})
} }

View file

@ -1,6 +1,7 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
use crate::{services, utils, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
/// ///
/// Sets the typing state of the sender user. /// Sets the typing state of the sender user.
@ -11,11 +12,7 @@ pub(crate) async fn create_typing_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You are not in this room.", "You are not in this room.",

View file

@ -6,14 +6,16 @@ use crate::{Result, Ruma};
/// # `GET /_matrix/client/versions` /// # `GET /_matrix/client/versions`
/// ///
/// Get the versions of the specification and unstable features supported by this server. /// Get the versions of the specification and unstable features supported by
/// this server.
/// ///
/// - Versions take the form MAJOR.MINOR.PATCH /// - Versions take the form MAJOR.MINOR.PATCH
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value /// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
/// - Unstable features are namespaced and may include version information in their name /// - Unstable features are namespaced and may include version information in
/// their name
/// ///
/// Note: Unstable features are used while developing new features. Clients should avoid using /// Note: Unstable features are used while developing new features. Clients
/// unstable features in their stable releases /// should avoid using unstable features in their stable releases
pub(crate) async fn get_supported_versions_route( pub(crate) async fn get_supported_versions_route(
_body: Ruma<get_supported_versions::Request>, _body: Ruma<get_supported_versions::Request>,
) -> Result<get_supported_versions::Response> { ) -> Result<get_supported_versions::Response> {
@ -27,7 +29,10 @@ pub(crate) async fn get_supported_versions_route(
"v1.4".to_owned(), "v1.4".to_owned(),
"v1.5".to_owned(), "v1.5".to_owned(),
], ],
unstable_features: BTreeMap::from_iter([("org.matrix.e2e_cross_signing".to_owned(), true)]), unstable_features: BTreeMap::from_iter([(
"org.matrix.e2e_cross_signing".to_owned(),
true,
)]),
}; };
Ok(resp) Ok(resp)

View file

@ -1,4 +1,3 @@
use crate::{services, Result, Ruma};
use ruma::{ use ruma::{
api::client::user_directory::search_users, api::client::user_directory::search_users,
events::{ events::{
@ -7,11 +6,14 @@ use ruma::{
}, },
}; };
use crate::{services, Result, Ruma};
/// # `POST /_matrix/client/r0/user_directory/search` /// # `POST /_matrix/client/r0/user_directory/search`
/// ///
/// Searches all known users for a match. /// Searches all known users for a match.
/// ///
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// - Hides any local users that aren't in any public rooms (i.e. those that
/// have the join rule set to public)
/// and don't share a room with the sender /// and don't share a room with the sender
pub(crate) async fn search_users_route( pub(crate) async fn search_users_route(
body: Ruma<search_users::v3::Request>, body: Ruma<search_users::v3::Request>,
@ -38,8 +40,7 @@ pub(crate) async fn search_users_route(
.display_name .display_name
.as_ref() .as_ref()
.filter(|name| { .filter(|name| {
name.to_lowercase() name.to_lowercase().contains(&body.search_term.to_lowercase())
.contains(&body.search_term.to_lowercase())
}) })
.is_some(); .is_some();
@ -62,10 +63,12 @@ pub(crate) async fn search_users_route(
.room_state_get(&room, &StateEventType::RoomJoinRules, "") .room_state_get(&room, &StateEventType::RoomJoinRules, "")
.map_or(false, |event| { .map_or(false, |event| {
event.map_or(false, |event| { event.map_or(false, |event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get()).map_or(
.map_or(false, |r: RoomJoinRulesEventContent| { false,
|r: RoomJoinRulesEventContent| {
r.join_rule == JoinRule::Public r.join_rule == JoinRule::Public
}) },
)
}) })
}) })
}); });
@ -96,5 +99,8 @@ pub(crate) async fn search_users_route(
let results = users.by_ref().take(limit).collect(); let results = users.by_ref().take(limit).collect();
let limited = users.next().is_some(); let limited = users.next().is_some();
Ok(search_users::v3::Response { results, limited }) Ok(search_users::v3::Response {
results,
limited,
})
} }

View file

@ -1,9 +1,11 @@
use crate::{services, Result, Ruma}; use std::time::{Duration, SystemTime};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
use sha1::Sha1; use sha1::Sha1;
use std::time::{Duration, SystemTime};
use crate::{services, Result, Ruma};
type HmacSha1 = Hmac<Sha1>; type HmacSha1 = Hmac<Sha1>;
@ -24,7 +26,8 @@ pub(crate) async fn turn_server_route(
) )
} else { } else {
let expiry = SecondsSinceUnixEpoch::from_system_time( let expiry = SecondsSinceUnixEpoch::from_system_time(
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), SystemTime::now()
+ Duration::from_secs(services().globals.turn_ttl()),
) )
.expect("time is valid"); .expect("time is valid");
@ -34,7 +37,8 @@ pub(crate) async fn turn_server_route(
.expect("HMAC can take key of any size"); .expect("HMAC can take key of any size");
mac.update(username.as_bytes()); mac.update(username.as_bytes());
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes()); let password: String =
general_purpose::STANDARD.encode(mac.finalize().into_bytes());
(username, password) (username, password)
}; };

View file

@ -1,10 +1,12 @@
use crate::{service::appservice::RegistrationInfo, Error};
use ruma::{
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
OwnedUserId,
};
use std::ops::Deref; use std::ops::Deref;
use ruma::{
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId,
OwnedServerName, OwnedUserId,
};
use crate::{service::appservice::RegistrationInfo, Error};
mod axum; mod axum;
/// Extractor for Ruma request structs /// Extractor for Ruma request structs

View file

@ -3,7 +3,9 @@ use std::{collections::BTreeMap, iter::FromIterator, str};
use axum::{ use axum::{
async_trait, async_trait,
body::{Full, HttpBody}, body::{Full, HttpBody},
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, extract::{
rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader,
},
headers::{ headers::{
authorization::{Bearer, Credentials}, authorization::{Bearer, Credentials},
Authorization, Authorization,
@ -14,7 +16,9 @@ use axum::{
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, api::{
client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse,
},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
@ -41,7 +45,10 @@ where
type Rejection = Error; type Rejection = Error;
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { async fn from_request(
req: Request<B>,
_state: &S,
) -> Result<Self, Self::Rejection> {
#[derive(Deserialize)] #[derive(Deserialize)]
struct QueryParams { struct QueryParams {
access_token: Option<String>, access_token: Option<String>,
@ -51,22 +58,23 @@ where
let (mut parts, mut body) = match req.with_limited_body() { let (mut parts, mut body) = match req.with_limited_body() {
Ok(limited_req) => { Ok(limited_req) => {
let (parts, body) = limited_req.into_parts(); let (parts, body) = limited_req.into_parts();
let body = to_bytes(body) let body = to_bytes(body).await.map_err(|_| {
.await Error::BadRequest(ErrorKind::MissingToken, "Missing token.")
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; })?;
(parts, body) (parts, body)
} }
Err(original_req) => { Err(original_req) => {
let (parts, body) = original_req.into_parts(); let (parts, body) = original_req.into_parts();
let body = to_bytes(body) let body = to_bytes(body).await.map_err(|_| {
.await Error::BadRequest(ErrorKind::MissingToken, "Missing token.")
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; })?;
(parts, body) (parts, body)
} }
}; };
let metadata = T::METADATA; let metadata = T::METADATA;
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?; let auth_header: Option<TypedHeader<Authorization<Bearer>>> =
parts.extract().await?;
let path_params: Path<Vec<String>> = parts.extract().await?; let path_params: Path<Vec<String>> = parts.extract().await?;
let query = parts.uri.query().unwrap_or_default(); let query = parts.uri.query().unwrap_or_default();
@ -87,9 +95,13 @@ where
}; };
let token = if let Some(token) = token { let token = if let Some(token) = token {
if let Some(reg_info) = services().appservice.find_from_token(token).await { if let Some(reg_info) =
services().appservice.find_from_token(token).await
{
Token::Appservice(Box::new(reg_info.clone())) Token::Appservice(Box::new(reg_info.clone()))
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { } else if let Some((user_id, device_id)) =
services().users.find_from_token(token)?
{
Token::User((user_id, OwnedDeviceId::from(device_id))) Token::User((user_id, OwnedDeviceId::from(device_id)))
} else { } else {
Token::Invalid Token::Invalid
@ -98,13 +110,16 @@ where
Token::None Token::None
}; };
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body =
serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let (sender_user, sender_device, sender_servername, appservice_info) = let (sender_user, sender_device, sender_servername, appservice_info) =
match (metadata.authentication, token) { match (metadata.authentication, token) {
(_, Token::Invalid) => { (_, Token::Invalid) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false }, ErrorKind::UnknownToken {
soft_logout: false,
},
"Unknown access token.", "Unknown access token.",
)) ))
} }
@ -121,7 +136,10 @@ where
UserId::parse, UserId::parse,
) )
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?; })?;
if !info.is_user_match(&user_id) { if !info.is_user_match(&user_id) {
@ -153,7 +171,9 @@ where
)); ));
} }
( (
AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, AuthScheme::AccessToken
| AuthScheme::AccessTokenOptional
| AuthScheme::None,
Token::User((user_id, device_id)), Token::User((user_id, device_id)),
) => (Some(user_id), Some(device_id), None, None), ) => (Some(user_id), Some(device_id), None, None),
(AuthScheme::ServerSignatures, Token::None) => { (AuthScheme::ServerSignatures, Token::None) => {
@ -161,7 +181,10 @@ where
.extract::<TypedHeader<Authorization<XMatrix>>>() .extract::<TypedHeader<Authorization<XMatrix>>>()
.await .await
.map_err(|e| { .map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e); warn!(
"Missing or invalid Authorization header: {}",
e
);
let msg = match e.reason() { let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => { TypedHeaderRejectionReason::Missing => {
@ -189,7 +212,9 @@ where
let mut request_map = BTreeMap::from_iter([ let mut request_map = BTreeMap::from_iter([
( (
"method".to_owned(), "method".to_owned(),
CanonicalJsonValue::String(parts.method.to_string()), CanonicalJsonValue::String(
parts.method.to_string(),
),
), ),
( (
"uri".to_owned(), "uri".to_owned(),
@ -197,12 +222,18 @@ where
), ),
( (
"origin".to_owned(), "origin".to_owned(),
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), CanonicalJsonValue::String(
x_matrix.origin.as_str().to_owned(),
),
), ),
( (
"destination".to_owned(), "destination".to_owned(),
CanonicalJsonValue::String( CanonicalJsonValue::String(
services().globals.server_name().as_str().to_owned(), services()
.globals
.server_name()
.as_str()
.to_owned(),
), ),
), ),
( (
@ -212,13 +243,17 @@ where
]); ]);
if let Some(json_body) = &json_body { if let Some(json_body) = &json_body {
request_map.insert("content".to_owned(), json_body.clone()); request_map
.insert("content".to_owned(), json_body.clone());
}; };
let keys_result = services() let keys_result = services()
.rooms .rooms
.event_handler .event_handler
.fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.clone()]) .fetch_signing_keys(
&x_matrix.origin,
vec![x_matrix.key.clone()],
)
.await; .await;
let keys = match keys_result { let keys = match keys_result {
@ -232,22 +267,29 @@ where
} }
}; };
let pub_key_map = let pub_key_map = BTreeMap::from_iter([(
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); x_matrix.origin.as_str().to_owned(),
keys,
)]);
match ruma::signatures::verify_json(&pub_key_map, &request_map) { match ruma::signatures::verify_json(
&pub_key_map,
&request_map,
) {
Ok(()) => (None, None, Some(x_matrix.origin), None), Ok(()) => (None, None, Some(x_matrix.origin), None),
Err(e) => { Err(e) => {
warn!( warn!(
"Failed to verify json request from {}: {}\n{:?}", "Failed to verify json request from {}: \
{}\n{:?}",
x_matrix.origin, e, request_map x_matrix.origin, e, request_map
); );
if parts.uri.to_string().contains('@') { if parts.uri.to_string().contains('@') {
warn!( warn!(
"Request uri contained '@' character. Make sure your \ "Request uri contained '@' character. \
reverse proxy gives Grapevine the raw uri (apache: use \ Make sure your reverse proxy gives \
nocanon)" Grapevine the raw uri (apache: use \
nocanon)"
); );
} }
@ -264,27 +306,36 @@ where
| AuthScheme::AccessTokenOptional, | AuthScheme::AccessTokenOptional,
Token::None, Token::None,
) => (None, None, None, None), ) => (None, None, None, None),
(AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { (
AuthScheme::ServerSignatures,
Token::Appservice(_) | Token::User(_),
) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Unauthorized, ErrorKind::Unauthorized,
"Only server signatures should be used on this endpoint.", "Only server signatures should be used on this \
endpoint.",
)); ));
} }
(AuthScheme::AppserviceToken, Token::User(_)) => { (AuthScheme::AppserviceToken, Token::User(_)) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Unauthorized, ErrorKind::Unauthorized,
"Only appservice access tokens should be used on this endpoint.", "Only appservice access tokens should be used on this \
endpoint.",
)); ));
} }
}; };
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); let mut http_request =
http::Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers; *http_request.headers_mut().unwrap() = parts.headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| { let user_id = sender_user.clone().unwrap_or_else(|| {
UserId::parse_with_server_name("", services().globals.server_name()) UserId::parse_with_server_name(
.expect("we know this is valid") "",
services().globals.server_name(),
)
.expect("we know this is valid")
}); });
let uiaa_request = json_body let uiaa_request = json_body
@ -300,14 +351,17 @@ where
) )
}); });
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { if let Some(CanonicalJsonValue::Object(initial_request)) =
uiaa_request
{
for (key, value) in initial_request { for (key, value) in initial_request {
json_body.entry(key).or_insert(value); json_body.entry(key).or_insert(value);
} }
} }
let mut buf = BytesMut::new().writer(); let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); serde_json::to_writer(&mut buf, json_body)
.expect("value serialization can't fail");
body = buf.into_inner().freeze(); body = buf.into_inner().freeze();
} }
@ -315,11 +369,15 @@ where
debug!("{:?}", http_request); debug!("{:?}", http_request);
let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { let body = T::try_from_http_request(http_request, &path_params)
warn!("try_from_http_request failed: {:?}", e); .map_err(|e| {
debug!("JSON body: {:?}", json_body); warn!("try_from_http_request failed: {:?}", e);
Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") debug!("JSON body: {:?}", json_body);
})?; Error::BadRequest(
ErrorKind::BadJson,
"Failed to deserialize request.",
)
})?;
Ok(Ruma { Ok(Ruma {
body, body,
@ -345,7 +403,8 @@ impl Credentials for XMatrix {
fn decode(value: &http::HeaderValue) -> Option<Self> { fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!( debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "), value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", "HeaderValue to decode should start with \"X-Matrix ..\", \
received = {value:?}",
); );
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
@ -359,8 +418,9 @@ impl Credentials for XMatrix {
for entry in parameters.split_terminator(',') { for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?; let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not in the spec, // It's not at all clear why some fields are quoted and others not
// let's simply accept either form for every field. // in the spec, let's simply accept either form for
// every field.
let value = value let value = value
.strip_prefix('"') .strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"')) .and_then(|rest| rest.strip_suffix('"'))

File diff suppressed because it is too large Load diff

View file

@ -108,7 +108,10 @@ impl Config {
} }
if was_deprecated { if was_deprecated {
warn!("Read grapevine documentation and check your configuration if any new configuration parameters should be adjusted"); warn!(
"Read grapevine documentation and check your configuration if \
any new configuration parameters should be adjusted"
);
} }
} }
} }

View file

@ -24,9 +24,10 @@ use crate::Result;
/// ## Include vs. Exclude /// ## Include vs. Exclude
/// If include is an empty list, it is assumed to be `["*"]`. /// If include is an empty list, it is assumed to be `["*"]`.
/// ///
/// If a domain matches both the exclude and include list, the proxy will only be used if it was /// If a domain matches both the exclude and include list, the proxy will only
/// included because of a more specific rule than it was excluded. In the above example, the proxy /// be used if it was included because of a more specific rule than it was
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. /// excluded. In the above example, the proxy would be used for
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[derive(Default)] #[derive(Default)]
@ -43,7 +44,9 @@ impl ProxyConfig {
pub(crate) fn to_proxy(&self) -> Result<Option<Proxy>> { pub(crate) fn to_proxy(&self) -> Result<Option<Proxy>> {
Ok(match self.clone() { Ok(match self.clone() {
ProxyConfig::None => None, ProxyConfig::None => None,
ProxyConfig::Global { url } => Some(Proxy::all(url)?), ProxyConfig::Global {
url,
} => Some(Proxy::all(url)?),
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
// first matching proxy // first matching proxy
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned()
@ -112,25 +115,32 @@ impl WildCardedDomain {
WildCardedDomain::Exact(d) => domain == d, WildCardedDomain::Exact(d) => domain == d,
} }
} }
pub(crate) fn more_specific_than(&self, other: &Self) -> bool { pub(crate) fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) { match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true, (_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => {
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { other.matches(a)
a != b && a.ends_with(b)
} }
(
WildCardedDomain::WildCarded(a),
WildCardedDomain::WildCarded(b),
) => a != b && a.ends_with(b),
_ => false, _ => false,
} }
} }
} }
impl std::str::FromStr for WildCardedDomain { impl std::str::FromStr for WildCardedDomain {
type Err = std::convert::Infallible; type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
// maybe do some domain validation? // maybe do some domain validation?
Ok(s.strip_prefix("*.") Ok(s.strip_prefix("*.")
.map(|x| WildCardedDomain::WildCarded(x.to_owned())) .map(|x| WildCardedDomain::WildCarded(x.to_owned()))
.or_else(|| (s == "*").then(|| WildCardedDomain::WildCarded(String::new()))) .or_else(|| {
(s == "*").then(|| WildCardedDomain::WildCarded(String::new()))
})
.unwrap_or_else(|| WildCardedDomain::Exact(s.to_owned()))) .unwrap_or_else(|| WildCardedDomain::Exact(s.to_owned())))
} }
} }

View file

@ -1,23 +1,6 @@
pub(crate) mod abstraction; pub(crate) mod abstraction;
pub(crate) mod key_value; pub(crate) mod key_value;
use crate::{
service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services,
SERVICES,
};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
UserId,
};
use std::{ use std::{
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
fs, fs,
@ -27,8 +10,25 @@ use std::{
sync::{Arc, Mutex, RwLock}, sync::{Arc, Mutex, RwLock},
}; };
use abstraction::{KeyValueDatabaseEngine, KvTree};
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId,
OwnedUserId, RoomId, UserId,
};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{
service::rooms::timeline::PduCount, services, utils, Config, Error,
PduEvent, Result, Services, SERVICES,
};
pub(crate) struct KeyValueDatabase { pub(crate) struct KeyValueDatabase {
db: Arc<dyn KeyValueDatabaseEngine>, db: Arc<dyn KeyValueDatabaseEngine>,
@ -74,8 +74,9 @@ pub(crate) struct KeyValueDatabase {
// Trees "owned" by `self::key_value::uiaa` // Trees "owned" by `self::key_value::uiaa`
// User-interactive authentication // User-interactive authentication
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>,
pub(super) userdevicesessionid_uiaarequest: pub(super) userdevicesessionid_uiaarequest: RwLock<
RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>, BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>,
>,
// Trees "owned" by `self::key_value::rooms::edus` // Trees "owned" by `self::key_value::rooms::edus`
// ReadReceiptId = RoomId + Count + UserId // ReadReceiptId = RoomId + Count + UserId
@ -169,13 +170,15 @@ pub(crate) struct KeyValueDatabase {
pub(super) statehash_shortstatehash: Arc<dyn KvTree>, pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
// StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 +
// (shortstatekey+shorteventid--)
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, pub(super) shortstatehash_statediff: Arc<dyn KvTree>,
pub(super) shorteventid_authchain: Arc<dyn KvTree>, pub(super) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU. /// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. /// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub(super) eventid_outlierpdu: Arc<dyn KvTree>, pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
pub(super) softfailedeventids: Arc<dyn KvTree>, pub(super) softfailedeventids: Arc<dyn KvTree>,
@ -214,10 +217,12 @@ pub(crate) struct KeyValueDatabase {
// EduCount: Count of last EDU sync // EduCount: Count of last EDU sync
pub(super) servername_educount: Arc<dyn KvTree>, pub(super) servername_educount: Arc<dyn KvTree>,
// ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id
// (for edus), Data = EDU content
pub(super) servernameevent_data: Arc<dyn KvTree>, pub(super) servernameevent_data: Arc<dyn KvTree>,
// ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for
// edus), Data = EDU content
pub(super) servercurrentevent_data: Arc<dyn KvTree>, pub(super) servercurrentevent_data: Arc<dyn KvTree>,
// Trees "owned" by `self::key_value::appservice` // Trees "owned" by `self::key_value::appservice`
@ -231,10 +236,14 @@ pub(crate) struct KeyValueDatabase {
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>, pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>, pub(super) statekeyshort_cache:
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>, Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>, pub(super) shortstatekey_cache:
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>, Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache:
RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache:
RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>, pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
} }
@ -271,13 +280,15 @@ impl KeyValueDatabase {
if sqlite_exists && config.database_backend != "sqlite" { if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config( return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.", "Found sqlite at database_path, but is not specified in \
config.",
)); ));
} }
if rocksdb_exists && config.database_backend != "rocksdb" { if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config( return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.", "Found rocksdb at database_path, but is not specified in \
config.",
)); ));
} }
@ -294,19 +305,30 @@ impl KeyValueDatabase {
Self::check_db_setup(&config)?; Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() { if !Path::new(&config.database_path).exists() {
std::fs::create_dir_all(&config.database_path) std::fs::create_dir_all(&config.database_path).map_err(|_| {
.map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; Error::BadConfig(
"Database folder doesn't exists and couldn't be created \
(e.g. due to missing permissions). Please create the \
database folder yourself.",
)
})?;
} }
#[cfg_attr( #[cfg_attr(
not(any(feature = "rocksdb", feature = "sqlite")), not(any(feature = "rocksdb", feature = "sqlite")),
allow(unused_variables) allow(unused_variables)
)] )]
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend { let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config
.database_backend
{
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
"sqlite" => Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?), "sqlite" => {
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
#[cfg(feature = "rocksdb")] #[cfg(feature = "rocksdb")]
"rocksdb" => Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?), "rocksdb" => {
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
_ => { _ => {
return Err(Error::BadConfig("Database backend not found.")); return Err(Error::BadConfig("Database backend not found."));
} }
@ -327,28 +349,38 @@ impl KeyValueDatabase {
userid_avatarurl: builder.open_tree("userid_avatarurl")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?, userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?, userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userdeviceid_metadata: builder
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, .open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder
.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, onetimekeyid_onetimekeys: builder
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, .open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder
.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?, keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, userid_selfsigningkeyid: builder
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, .open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder
.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?, todeviceid_events: builder.open_tree("todeviceid_events")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaainfo: builder
.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, readreceiptid_readreceipt: builder
.open_tree("readreceiptid_readreceipt")?,
// "Private" read receipt // "Private" read receipt
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, roomuserid_privateread: builder
.open_tree("roomuserid_privateread")?,
roomuserid_lastprivatereadupdate: builder roomuserid_lastprivatereadupdate: builder
.open_tree("roomuserid_lastprivatereadupdate")?, .open_tree("roomuserid_lastprivatereadupdate")?,
presenceid_presence: builder.open_tree("presenceid_presence")?, presenceid_presence: builder.open_tree("presenceid_presence")?,
userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, userid_lastpresenceupdate: builder
.open_tree("userid_lastpresenceupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?, pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?, eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
@ -367,9 +399,12 @@ impl KeyValueDatabase {
roomuserid_joined: builder.open_tree("roomuserid_joined")?, roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, roomuseroncejoinedids: builder
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, .open_tree("roomuseroncejoinedids")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, userroomid_invitestate: builder
.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder
.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
@ -377,41 +412,57 @@ impl KeyValueDatabase {
lazyloadedids: builder.open_tree("lazyloadedids")?, lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, userroomid_notificationcount: builder
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, .open_tree("userroomid_notificationcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, userroomid_highlightcount: builder
.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder
.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, statekey_shortstatekey: builder
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, .open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder
.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, shorteventid_authchain: builder
.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, shortstatehash_statediff: builder
.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, shorteventid_shortstatehash: builder
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, .open_tree("shorteventid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, roomid_shortstatehash: builder
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, .open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder
.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder
.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?, softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?, tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?, referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomuserdataid_accountdata: builder
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, .open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder
.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?, mediaid_file: builder.open_tree("mediaid_file")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?, backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, userdevicetxnid_response: builder
.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?, servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?, servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, servercurrentevent_data: builder
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, .open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder
.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?, senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?, global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
@ -489,11 +540,13 @@ impl KeyValueDatabase {
if !services().users.exists(&grapevine_user)? { if !services().users.exists(&grapevine_user)? {
error!( error!(
"The {} server user does not exist, and the database is not new.", "The {} server user does not exist, and the database is \
not new.",
grapevine_user grapevine_user
); );
return Err(Error::bad_database( return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first." "Cannot reuse an existing database after changing the \
server name, please delete the old one first.",
)); ));
} }
} }
@ -505,14 +558,15 @@ impl KeyValueDatabase {
// MIGRATIONS // MIGRATIONS
if services().globals.database_version()? < 1 { if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() { for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xff); let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = parts.next().expect("split always returns one element"); let room_id =
parts.next().expect("split always returns one element");
let Some(servername) = parts.next() else { let Some(servername) = parts.next() else {
error!("Migration: Invalid roomserverid in db."); error!("Migration: Invalid roomserverid in db.");
continue; continue;
}; };
let mut serverroomid = servername.to_vec(); let mut serverroomid = servername.to_vec();
serverroomid.push(0xff); serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id); serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?; db.serverroomids.insert(&serverroomid, &[])?;
@ -524,13 +578,16 @@ impl KeyValueDatabase {
} }
if services().globals.database_version()? < 2 { if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just "" // We accidentally inserted hashed versions of "" into the db
// instead of just ""
for (userid, password) in db.userid_password.iter() { for (userid, password) in db.userid_password.iter() {
let password = utils::string_from_bytes(&password); let password = utils::string_from_bytes(&password);
let empty_hashed_password = password.map_or(false, |password| { let empty_hashed_password =
argon2::verify_encoded(&password, b"").unwrap_or(false) password.map_or(false, |password| {
}); argon2::verify_encoded(&password, b"")
.unwrap_or(false)
});
if empty_hashed_password { if empty_hashed_password {
db.userid_password.insert(&userid, b"")?; db.userid_password.insert(&userid, b"")?;
@ -567,10 +624,16 @@ impl KeyValueDatabase {
if services().users.is_deactivated(&our_user)? { if services().users.is_deactivated(&our_user)? {
continue; continue;
} }
for room in services().rooms.state_cache.rooms_joined(&our_user) { for room in
for user in services().rooms.state_cache.room_members(&room?) { services().rooms.state_cache.rooms_joined(&our_user)
{
for user in
services().rooms.state_cache.room_members(&room?)
{
let user = user?; let user = user?;
if user.server_name() != services().globals.server_name() { if user.server_name()
!= services().globals.server_name()
{
info!(?user, "Migration: creating user"); info!(?user, "Migration: creating user");
services().users.create(&user, None)?; services().users.create(&user, None)?;
} }
@ -585,16 +648,18 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 5 { if services().globals.database_version()? < 5 {
// Upgrade user data store // Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter()
let mut parts = roomuserdataid.split(|&b| b == 0xff); {
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap(); let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap(); let event_type =
roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec(); let mut key = room_id.to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id); key.extend_from_slice(user_id);
key.push(0xff); key.push(0xFF);
key.extend_from_slice(event_type); key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid db.roomusertype_roomuserdataid
@ -611,7 +676,10 @@ impl KeyValueDatabase {
for (roomid, _) in db.roomid_shortstatehash.iter() { for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap(); let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?; services()
.rooms
.state_cache
.update_joined_count(room_id)?;
} }
services().globals.bump_database_version(6)?; services().globals.bump_database_version(6)?;
@ -621,7 +689,8 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 7 { if services().globals.database_version()? < 7 {
// Upgrade state store // Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); let mut last_roomstates: HashMap<OwnedRoomId, u64> =
HashMap::new();
let mut current_sstatehash: Option<u64> = None; let mut current_sstatehash: Option<u64> = None;
let mut current_room = None; let mut current_room = None;
let mut current_state = HashSet::new(); let mut current_state = HashSet::new();
@ -633,7 +702,8 @@ impl KeyValueDatabase {
current_state: HashSet<_>, current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| { last_roomstates: &mut HashMap<_, _>| {
counter += 1; counter += 1;
let last_roomsstatehash = last_roomstates.get(current_room); let last_roomsstatehash =
last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else( let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
@ -641,12 +711,16 @@ impl KeyValueDatabase {
services() services()
.rooms .rooms
.state_compressor .state_compressor
.load_shortstatehash_info(last_roomsstatehash) .load_shortstatehash_info(
last_roomsstatehash,
)
}, },
)?; )?;
let (statediffnew, statediffremoved) = let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() { if let Some(parent_stateinfo) =
states_parents.last()
{
let statediffnew = current_state let statediffnew = current_state
.difference(&parent_stateinfo.1) .difference(&parent_stateinfo.1)
.copied() .copied()
@ -663,21 +737,28 @@ impl KeyValueDatabase {
(current_state, HashSet::new()) (current_state, HashSet::new())
}; };
services().rooms.state_compressor.save_state_from_diff( services()
current_sstatehash, .rooms
Arc::new(statediffnew), .state_compressor
Arc::new(statediffremoved), .save_state_from_diff(
// every state change is 2 event changes on average current_sstatehash,
2, Arc::new(statediffnew),
states_parents, Arc::new(statediffremoved),
)?; // every state change is 2 event changes on
// average
2,
states_parents,
)?;
Ok::<_, Error>(()) Ok::<_, Error>(())
}; };
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() { for (k, seventid) in
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]) db.db.open_tree("stateid_shorteventid")?.iter()
.expect("number of bytes is correct"); {
let sstatehash =
utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec(); let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash { if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash { if let Some(current_sstatehash) = current_sstatehash {
@ -687,15 +768,23 @@ impl KeyValueDatabase {
current_state, current_state,
&mut last_roomstates, &mut last_roomstates,
)?; )?;
last_roomstates last_roomstates.insert(
.insert(current_room.clone().unwrap(), current_sstatehash); current_room.clone().unwrap(),
current_sstatehash,
);
} }
current_state = HashSet::new(); current_state = HashSet::new();
current_sstatehash = Some(sstatehash); current_sstatehash = Some(sstatehash);
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); let event_id = db
let string = utils::string_from_bytes(&event_id).unwrap(); .shorteventid_eventid
let event_id = <&EventId>::try_from(string.as_str()).unwrap(); .get(&seventid)
.unwrap()
.unwrap();
let string =
utils::string_from_bytes(&event_id).unwrap();
let event_id =
<&EventId>::try_from(string.as_str()).unwrap();
let pdu = services() let pdu = services()
.rooms .rooms
.timeline .timeline
@ -710,7 +799,8 @@ impl KeyValueDatabase {
let mut val = sstatekey; let mut val = sstatekey;
val.extend_from_slice(&seventid); val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct")); current_state
.insert(val.try_into().expect("size is correct"));
} }
if let Some(current_sstatehash) = current_sstatehash { if let Some(current_sstatehash) = current_sstatehash {
@ -730,7 +820,8 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 8 { if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms // Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() { for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes(); let shortroomid =
services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?; db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8"); info!("Migration: 8");
} }
@ -739,7 +830,7 @@ impl KeyValueDatabase {
if !key.starts_with(b"!") { if !key.starts_with(b"!") {
return None; return None;
} }
let mut parts = key.splitn(2, |&b| b == 0xff); let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap(); let count = parts.next().unwrap();
@ -757,25 +848,26 @@ impl KeyValueDatabase {
db.pduid_pdu.insert_batch(&mut batch)?; db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { let mut batch2 =
if !value.starts_with(b"!") { db.eventid_pduid.iter().filter_map(|(k, value)| {
return None; if !value.starts_with(b"!") {
} return None;
let mut parts = value.splitn(2, |&b| b == 0xff); }
let room_id = parts.next().unwrap(); let mut parts = value.splitn(2, |&b| b == 0xFF);
let count = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db let short_room_id = db
.roomid_shortroomid .roomid_shortroomid
.get(room_id) .get(room_id)
.unwrap() .unwrap()
.expect("shortroomid should exist"); .expect("shortroomid should exist");
let mut new_value = short_room_id; let mut new_value = short_room_id;
new_value.extend_from_slice(count); new_value.extend_from_slice(count);
Some((k, new_value)) Some((k, new_value))
}); });
db.eventid_pduid.insert_batch(&mut batch2)?; db.eventid_pduid.insert_batch(&mut batch2)?;
@ -793,7 +885,7 @@ impl KeyValueDatabase {
if !key.starts_with(b"!") { if !key.starts_with(b"!") {
return None; return None;
} }
let mut parts = key.splitn(4, |&b| b == 0xff); let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let word = parts.next().unwrap(); let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap(); let _pdu_id_room = parts.next().unwrap();
@ -806,7 +898,7 @@ impl KeyValueDatabase {
.expect("shortroomid should exist"); .expect("shortroomid should exist");
let mut new_key = short_room_id; let mut new_key = short_room_id;
new_key.extend_from_slice(word); new_key.extend_from_slice(word);
new_key.push(0xff); new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count); new_key.extend_from_slice(pdu_id_count);
Some((new_key, Vec::new())) Some((new_key, Vec::new()))
}) })
@ -836,12 +928,15 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 10 { if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys // Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { for (statekey, shortstatekey) in
db.statekey_shortstatekey.iter()
{
db.shortstatekey_statekey db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?; .insert(&shortstatekey, &statekey)?;
} }
// Force E2EE device list updates so we can send them over federation // Force E2EE device list updates so we can send them over
// federation
for user_id in services().users.iter().filter_map(Result::ok) { for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?; services().users.mark_device_key_update(&user_id)?;
} }
@ -852,9 +947,7 @@ impl KeyValueDatabase {
} }
if services().globals.database_version()? < 11 { if services().globals.database_version()? < 11 {
db.db db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?;
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
services().globals.bump_database_version(11)?; services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished"); warn!("Migration: 10 -> 11 finished");
@ -878,24 +971,34 @@ impl KeyValueDatabase {
.get( .get(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules
.to_string()
.into(),
) )
.unwrap() .unwrap()
.expect("Username is invalid"); .expect("Username is invalid");
let mut account_data = let mut account_data =
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); serde_json::from_str::<PushRulesEvent>(
raw_rules_list.get(),
)
.unwrap();
let rules_list = &mut account_data.content.global; let rules_list = &mut account_data.content.global;
//content rule //content rule
{ {
let content_rule_transformation = let content_rule_transformation = [
[".m.rules.contains_user_name", ".m.rule.contains_user_name"]; ".m.rules.contains_user_name",
".m.rule.contains_user_name",
];
let rule = rules_list.content.get(content_rule_transformation[0]); let rule = rules_list
.content
.get(content_rule_transformation[0]);
if rule.is_some() { if rule.is_some() {
let mut rule = rule.unwrap().clone(); let mut rule = rule.unwrap().clone();
rule.rule_id = content_rule_transformation[1].to_owned(); rule.rule_id =
content_rule_transformation[1].to_owned();
rules_list rules_list
.content .content
.shift_remove(content_rule_transformation[0]); .shift_remove(content_rule_transformation[0]);
@ -907,7 +1010,10 @@ impl KeyValueDatabase {
{ {
let underride_rule_transformation = [ let underride_rule_transformation = [
[".m.rules.call", ".m.rule.call"], [".m.rules.call", ".m.rule.call"],
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], [
".m.rules.room_one_to_one",
".m.rule.room_one_to_one",
],
[ [
".m.rules.encrypted_room_one_to_one", ".m.rules.encrypted_room_one_to_one",
".m.rule.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one",
@ -917,11 +1023,14 @@ impl KeyValueDatabase {
]; ];
for transformation in underride_rule_transformation { for transformation in underride_rule_transformation {
let rule = rules_list.underride.get(transformation[0]); let rule =
rules_list.underride.get(transformation[0]);
if let Some(rule) = rule { if let Some(rule) = rule {
let mut rule = rule.clone(); let mut rule = rule.clone();
rule.rule_id = transformation[1].to_owned(); rule.rule_id = transformation[1].to_owned();
rules_list.underride.shift_remove(transformation[0]); rules_list
.underride
.shift_remove(transformation[0]);
rules_list.underride.insert(rule); rules_list.underride.insert(rule);
} }
} }
@ -930,8 +1039,11 @@ impl KeyValueDatabase {
services().account_data.update( services().account_data.update(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules
&serde_json::to_value(account_data).expect("to json value always works"), .to_string()
.into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
} }
@ -940,7 +1052,8 @@ impl KeyValueDatabase {
warn!("Migration: 11 -> 12 finished"); warn!("Migration: 11 -> 12 finished");
} }
// This migration can be reused as-is anytime the server-default rules are updated. // This migration can be reused as-is anytime the server-default
// rules are updated.
if services().globals.database_version()? < 13 { if services().globals.database_version()? < 13 {
for username in services().users.list_local_users()? { for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name( let user = match UserId::parse_with_server_name(
@ -959,15 +1072,21 @@ impl KeyValueDatabase {
.get( .get(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules
.to_string()
.into(),
) )
.unwrap() .unwrap()
.expect("Username is invalid"); .expect("Username is invalid");
let mut account_data = let mut account_data =
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); serde_json::from_str::<PushRulesEvent>(
raw_rules_list.get(),
)
.unwrap();
let user_default_rules = ruma::push::Ruleset::server_default(&user); let user_default_rules =
ruma::push::Ruleset::server_default(&user);
account_data account_data
.content .content
.global .global
@ -976,8 +1095,11 @@ impl KeyValueDatabase {
services().account_data.update( services().account_data.update(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules
&serde_json::to_value(account_data).expect("to json value always works"), .to_string()
.into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?; )?;
} }
@ -1018,13 +1140,24 @@ impl KeyValueDatabase {
match set_emergency_access() { match set_emergency_access() {
Ok(pwd_set) => { Ok(pwd_set) => {
if pwd_set { if pwd_set {
warn!("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!"); warn!(
services().admin.send_message(RoomMessageEventContent::text_plain("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!")); "The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin account \
recovery!"
);
services().admin.send_message(
RoomMessageEventContent::text_plain(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin \
account recovery!",
),
);
} }
} }
Err(e) => { Err(e) => {
error!( error!(
"Could not set the configured emergency password for the grapevine user: {}", "Could not set the configured emergency password for the \
grapevine user: {}",
e e
); );
} }
@ -1050,15 +1183,15 @@ impl KeyValueDatabase {
#[tracing::instrument] #[tracing::instrument]
pub(crate) async fn start_cleanup_task() { pub(crate) async fn start_cleanup_task() {
use tokio::time::interval; use std::time::{Duration, Instant};
#[cfg(unix)] #[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind}; use tokio::signal::unix::{signal, SignalKind};
use tokio::time::interval;
use std::time::{Duration, Instant}; let timer_interval = Duration::from_secs(u64::from(
services().globals.config.cleanup_second_interval,
let timer_interval = ));
Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
tokio::spawn(async move { tokio::spawn(async move {
let mut i = interval(timer_interval); let mut i = interval(timer_interval);
@ -1092,11 +1225,14 @@ impl KeyValueDatabase {
} }
} }
/// Sets the emergency password and push rules for the @grapevine account in case emergency password is set /// Sets the emergency password and push rules for the @grapevine account in
/// case emergency password is set
fn set_emergency_access() -> Result<bool> { fn set_emergency_access() -> Result<bool> {
let grapevine_user = let grapevine_user = UserId::parse_with_server_name(
UserId::parse_with_server_name("grapevine", services().globals.server_name()) "grapevine",
.expect("@grapevine:server_name is a valid UserId"); services().globals.server_name(),
)
.expect("@grapevine:server_name is a valid UserId");
services().users.set_password( services().users.set_password(
&grapevine_user, &grapevine_user,
@ -1113,7 +1249,9 @@ fn set_emergency_access() -> Result<bool> {
&grapevine_user, &grapevine_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent { &serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent { global: ruleset }, content: PushRulesEventContent {
global: ruleset,
},
}) })
.expect("to json value always works"), .expect("to json value always works"),
)?; )?;

View file

@ -1,8 +1,8 @@
use std::{future::Future, pin::Pin, sync::Arc};
use super::Config; use super::Config;
use crate::Result; use crate::Result;
use std::{future::Future, pin::Pin, sync::Arc};
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
pub(crate) mod sqlite; pub(crate) mod sqlite;
@ -22,7 +22,8 @@ pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
Ok(()) Ok(())
} }
fn memory_usage(&self) -> Result<String> { fn memory_usage(&self) -> Result<String> {
Ok("Current database engine does not support memory usage reporting.".to_owned()) Ok("Current database engine does not support memory usage reporting."
.to_owned())
} }
fn clear_caches(&self) {} fn clear_caches(&self) {}
} }
@ -31,7 +32,10 @@ pub(crate) trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>; fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()>;
fn remove(&self, key: &[u8]) -> Result<()>; fn remove(&self, key: &[u8]) -> Result<()>;
@ -44,14 +48,20 @@ pub(crate) trait KvTree: Send + Sync {
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>; fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()>;
fn scan_prefix<'a>( fn scan_prefix<'a>(
&'a self, &'a self,
prefix: Vec<u8>, prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn clear(&self) -> Result<()> { fn clear(&self) -> Result<()> {
for (key, _) in self.iter() { for (key, _) in self.iter() {

View file

@ -1,17 +1,21 @@
use rocksdb::{
perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache,
ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType, DBRecoveryMode, DBWithThreadMode,
Direction, IteratorMode, MultiThreaded, Options, ReadOptions, WriteOptions,
};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{utils, Result};
use std::{ use std::{
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use rocksdb::{
perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache,
ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType,
DBRecoveryMode, DBWithThreadMode, Direction, IteratorMode, MultiThreaded,
Options, ReadOptions, WriteOptions,
};
use super::{
super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree,
};
use crate::{utils, Result};
pub(crate) struct Engine { pub(crate) struct Engine {
rocks: DBWithThreadMode<MultiThreaded>, rocks: DBWithThreadMode<MultiThreaded>,
max_open_files: i32, max_open_files: i32,
@ -38,7 +42,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &Cache) -> Options {
let mut db_opts = Options::default(); let mut db_opts = Options::default();
db_opts.set_block_based_table_factory(&block_based_options); db_opts.set_block_based_table_factory(&block_based_options);
db_opts.create_if_missing(true); db_opts.create_if_missing(true);
db_opts.increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX)); db_opts
.increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX));
db_opts.set_max_open_files(max_open_files); db_opts.set_max_open_files(max_open_files);
db_opts.set_compression_type(DBCompressionType::Lz4); db_opts.set_compression_type(DBCompressionType::Lz4);
db_opts.set_bottommost_compression_type(DBCompressionType::Zstd); db_opts.set_bottommost_compression_type(DBCompressionType::Zstd);
@ -69,13 +74,17 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_sign_loss, clippy::cast_sign_loss,
clippy::cast_possible_truncation clippy::cast_possible_truncation
)] )]
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let cache_capacity_bytes =
(config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let rocksdb_cache = Cache::new_lru_cache(cache_capacity_bytes); let rocksdb_cache = Cache::new_lru_cache(cache_capacity_bytes);
let db_opts = db_options(config.rocksdb_max_open_files, &rocksdb_cache); let db_opts = db_options(config.rocksdb_max_open_files, &rocksdb_cache);
let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(&db_opts, &config.database_path) let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(
.unwrap_or_default(); &db_opts,
&config.database_path,
)
.unwrap_or_default();
let db = DBWithThreadMode::<MultiThreaded>::open_cf_descriptors( let db = DBWithThreadMode::<MultiThreaded>::open_cf_descriptors(
&db_opts, &db_opts,
@ -119,14 +128,14 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(clippy::as_conversions, clippy::cast_precision_loss)] #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
fn memory_usage(&self) -> Result<String> { fn memory_usage(&self) -> Result<String> {
let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; let stats =
get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
Ok(format!( Ok(format!(
"Approximate memory usage of all the mem-tables: {:.3} MB\n\ "Approximate memory usage of all the mem-tables: {:.3} \
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\ MB\nApproximate memory usage of un-flushed mem-tables: {:.3} \
Approximate memory usage of all the table readers: {:.3} MB\n\ MB\nApproximate memory usage of all the table readers: {:.3} \
Approximate memory usage by cache: {:.3} MB\n\ MB\nApproximate memory usage by cache: {:.3} MB\nApproximate \
Approximate memory usage by cache pinned: {:.3} MB\n\ memory usage by cache pinned: {:.3} MB\n",
",
stats.mem_table_total as f64 / 1024.0 / 1024.0, stats.mem_table_total as f64 / 1024.0 / 1024.0,
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
@ -154,9 +163,7 @@ impl KvTree for RocksDbEngineTree<'_> {
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default(); let writeoptions = WriteOptions::default();
let lock = self.write_lock.read().unwrap(); let lock = self.write_lock.read().unwrap();
self.db self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
.rocks
.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
drop(lock); drop(lock);
self.watchers.wake(key); self.watchers.wake(key);
@ -164,12 +171,13 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(()) Ok(())
} }
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> { fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()> {
let writeoptions = WriteOptions::default(); let writeoptions = WriteOptions::default();
for (key, value) in iter { for (key, value) in iter {
self.db self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
.rocks
.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
} }
Ok(()) Ok(())
@ -177,10 +185,7 @@ impl KvTree for RocksDbEngineTree<'_> {
fn remove(&self, key: &[u8]) -> Result<()> { fn remove(&self, key: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default(); let writeoptions = WriteOptions::default();
Ok(self Ok(self.db.rocks.delete_cf_opt(&self.cf(), key, &writeoptions)?)
.db
.rocks
.delete_cf_opt(&self.cf(), key, &writeoptions)?)
} }
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> { fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
@ -230,26 +235,26 @@ impl KvTree for RocksDbEngineTree<'_> {
let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?; let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?;
let new = utils::increment(old.as_deref()); let new = utils::increment(old.as_deref());
self.db self.db.rocks.put_cf_opt(&self.cf(), key, &new, &writeoptions)?;
.rocks
.put_cf_opt(&self.cf(), key, &new, &writeoptions)?;
drop(lock); drop(lock);
Ok(new) Ok(new)
} }
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> { fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()> {
let readoptions = ReadOptions::default(); let readoptions = ReadOptions::default();
let writeoptions = WriteOptions::default(); let writeoptions = WriteOptions::default();
let lock = self.write_lock.write().unwrap(); let lock = self.write_lock.write().unwrap();
for key in iter { for key in iter {
let old = self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?; let old =
self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?;
let new = utils::increment(old.as_deref()); let new = utils::increment(old.as_deref());
self.db self.db.rocks.put_cf_opt(&self.cf(), key, new, &writeoptions)?;
.rocks
.put_cf_opt(&self.cf(), key, new, &writeoptions)?;
} }
drop(lock); drop(lock);
@ -277,7 +282,10 @@ impl KvTree for RocksDbEngineTree<'_> {
) )
} }
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix) self.watchers.watch(prefix)
} }
} }

View file

@ -1,7 +1,3 @@
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use std::{ use std::{
cell::RefCell, cell::RefCell,
future::Future, future::Future,
@ -9,9 +5,15 @@ use std::{
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
}; };
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use thread_local::ThreadLocal; use thread_local::ThreadLocal;
use tracing::debug; use tracing::debug;
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
thread_local! { thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
@ -68,7 +70,11 @@ impl Engine {
conn.pragma_update(Some(Main), "page_size", 2048)?; conn.pragma_update(Some(Main), "page_size", 2048)?;
conn.pragma_update(Some(Main), "journal_mode", "WAL")?; conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; conn.pragma_update(
Some(Main),
"cache_size",
-i64::from(cache_size_kb),
)?;
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
Ok(conn) Ok(conn)
@ -79,18 +85,23 @@ impl Engine {
} }
fn read_lock(&self) -> &Connection { fn read_lock(&self) -> &Connection {
self.read_conn_tls self.read_conn_tls.get_or(|| {
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()
})
} }
fn read_lock_iterator(&self) -> &Connection { fn read_lock_iterator(&self) -> &Connection {
self.read_iterator_conn_tls self.read_iterator_conn_tls.get_or(|| {
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()
})
} }
pub(crate) fn flush_wal(self: &Arc<Self>) -> Result<()> { pub(crate) fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock() self.write_lock().pragma_update(
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; Some(Main),
"wal_checkpoint",
"RESTART",
)?;
Ok(()) Ok(())
} }
} }
@ -108,7 +119,8 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
// calculates cache-size per permanent connection // calculates cache-size per permanent connection
// 1. convert MB to KiB // 1. convert MB to KiB
// 2. divide by permanent connections + permanent iter connections + write connection // 2. divide by permanent connections + permanent iter connections +
// write connection
// 3. round down to nearest integer // 3. round down to nearest integer
#[allow( #[allow(
clippy::as_conversions, clippy::as_conversions,
@ -117,9 +129,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_sign_loss clippy::cast_sign_loss
)] )]
let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0) let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0)
/ ((num_cpus::get() as f64 * 2.0) + 1.0)) as u32; / ((num_cpus::get() as f64 * 2.0) + 1.0))
as u32;
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); let writer =
Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
let arc = Arc::new(Engine { let arc = Arc::new(Engine {
writer, writer,
@ -133,7 +147,13 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
} }
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> { fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; self.write_lock().execute(
&format!(
"CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY \
KEY, \"value\" BLOB NOT NULL )"
),
[],
)?;
Ok(Arc::new(SqliteTable { Ok(Arc::new(SqliteTable {
engine: Arc::clone(self), engine: Arc::clone(self),
@ -161,14 +181,26 @@ pub(crate) struct SqliteTable {
type TupleOfBytes = (Vec<u8>, Vec<u8>); type TupleOfBytes = (Vec<u8>, Vec<u8>);
impl SqliteTable { impl SqliteTable {
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get_with_guard(
&self,
guard: &Connection,
key: &[u8],
) -> Result<Option<Vec<u8>>> {
Ok(guard Ok(guard
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .prepare(
format!("SELECT value FROM {} WHERE key = ?", self.name)
.as_str(),
)?
.query_row([key], |row| row.get(0)) .query_row([key], |row| row.get(0))
.optional()?) .optional()?)
} }
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { fn insert_with_guard(
&self,
guard: &Connection,
key: &[u8],
value: &[u8],
) -> Result<()> {
guard.execute( guard.execute(
format!( format!(
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
@ -222,7 +254,10 @@ impl KvTree for SqliteTable {
Ok(()) Ok(())
} }
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> { fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?; guard.execute("BEGIN", [])?;
@ -236,7 +271,10 @@ impl KvTree for SqliteTable {
Ok(()) Ok(())
} }
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> { fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?; guard.execute("BEGIN", [])?;
@ -282,7 +320,8 @@ impl KvTree for SqliteTable {
let statement = Box::leak(Box::new( let statement = Box::leak(Box::new(
guard guard
.prepare(&format!( .prepare(&format!(
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", "SELECT key, value FROM {} WHERE key <= ? ORDER BY \
key DESC",
&self.name &self.name
)) ))
.unwrap(), .unwrap(),
@ -292,7 +331,9 @@ impl KvTree for SqliteTable {
let iterator = Box::new( let iterator = Box::new(
statement statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .query_map([from], |row| {
Ok((row.get_unwrap(0), row.get_unwrap(1)))
})
.unwrap() .unwrap()
.map(Result::unwrap), .map(Result::unwrap),
); );
@ -304,7 +345,8 @@ impl KvTree for SqliteTable {
let statement = Box::leak(Box::new( let statement = Box::leak(Box::new(
guard guard
.prepare(&format!( .prepare(&format!(
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", "SELECT key, value FROM {} WHERE key >= ? ORDER BY \
key ASC",
&self.name &self.name
)) ))
.unwrap(), .unwrap(),
@ -314,7 +356,9 @@ impl KvTree for SqliteTable {
let iterator = Box::new( let iterator = Box::new(
statement statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .query_map([from], |row| {
Ok((row.get_unwrap(0), row.get_unwrap(1)))
})
.unwrap() .unwrap()
.map(Result::unwrap), .map(Result::unwrap),
); );
@ -338,14 +382,20 @@ impl KvTree for SqliteTable {
Ok(new) Ok(new)
} }
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
Box::new( Box::new(
self.iter_from(&prefix, false) self.iter_from(&prefix, false)
.take_while(move |(key, _)| key.starts_with(&prefix)), .take_while(move |(key, _)| key.starts_with(&prefix)),
) )
} }
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix) self.watchers.watch(prefix)
} }

View file

@ -4,12 +4,14 @@ use std::{
pin::Pin, pin::Pin,
sync::RwLock, sync::RwLock,
}; };
use tokio::sync::watch; use tokio::sync::watch;
#[derive(Default)] #[derive(Default)]
pub(super) struct Watchers { pub(super) struct Watchers {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
watchers: RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>, watchers:
RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>,
} }
impl Watchers { impl Watchers {
@ -17,7 +19,8 @@ impl Watchers {
&'a self, &'a self,
prefix: &[u8], prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec())
{
hash_map::Entry::Occupied(o) => o.get().1.clone(), hash_map::Entry::Occupied(o) => o.get().1.clone(),
hash_map::Entry::Vacant(v) => { hash_map::Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(()); let (tx, rx) = tokio::sync::watch::channel(());
@ -31,6 +34,7 @@ impl Watchers {
rx.changed().await.unwrap(); rx.changed().await.unwrap();
}) })
} }
pub(super) fn wake(&self, key: &[u8]) { pub(super) fn wake(&self, key: &[u8]) {
let watchers = self.watchers.read().unwrap(); let watchers = self.watchers.read().unwrap();
let mut triggered = Vec::new(); let mut triggered = Vec::new();

View file

@ -7,10 +7,13 @@ use ruma::{
RoomId, UserId, RoomId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::account_data::Data for KeyValueDatabase { impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update( fn update(
&self, &self,
@ -24,13 +27,14 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default() .unwrap_or_default()
.as_bytes() .as_bytes()
.to_vec(); .to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes()); prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
let mut roomuserdataid = prefix.clone(); let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid
roomuserdataid.push(0xff); .extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix; let mut key = prefix;
@ -45,13 +49,13 @@ impl service::account_data::Data for KeyValueDatabase {
self.roomuserdataid_accountdata.insert( self.roomuserdataid_accountdata.insert(
&roomuserdataid, &roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"), &serde_json::to_vec(&data)
.expect("to_vec always works on json values"),
)?; )?;
let prev = self.roomusertype_roomuserdataid.get(&key)?; let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
.insert(&key, &roomuserdataid)?;
// Remove old entry // Remove old entry
if let Some(prev) = prev { if let Some(prev) = prev {
@ -74,17 +78,15 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default() .unwrap_or_default()
.as_bytes() .as_bytes()
.to_vec(); .to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes()); key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid
.get(&key)? .get(&key)?
.and_then(|roomuserdataid| { .and_then(|roomuserdataid| {
self.roomuserdataid_accountdata self.roomuserdataid_accountdata.get(&roomuserdataid).transpose()
.get(&roomuserdataid)
.transpose()
}) })
.transpose()? .transpose()?
.map(|data| { .map(|data| {
@ -101,7 +103,8 @@ impl service::account_data::Data for KeyValueDatabase {
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
since: u64, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>
{
let mut userdata = HashMap::new(); let mut userdata = HashMap::new();
let mut prefix = room_id let mut prefix = room_id
@ -109,9 +112,9 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default() .unwrap_or_default()
.as_bytes() .as_bytes()
.to_vec(); .to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes()); prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time // Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone(); let mut first_possible = prefix.clone();
@ -124,14 +127,27 @@ impl service::account_data::Data for KeyValueDatabase {
.map(|(k, v)| { .map(|(k, v)| {
Ok::<_, Error>(( Ok::<_, Error>((
RoomAccountDataEventType::from( RoomAccountDataEventType::from(
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else( utils::string_from_bytes(
|| Error::bad_database("RoomUserData ID in db is invalid."), k.rsplit(|&b| b == 0xFF).next().ok_or_else(
)?) || {
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, Error::bad_database(
"RoomUserData ID in db is invalid.",
)
},
)?,
)
.map_err(|_| {
Error::bad_database(
"RoomUserData ID in db is invalid.",
)
})?,
), ),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| { serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
Error::bad_database("Database contains invalid account data.") .map_err(|_| {
})?, Error::bad_database(
"Database contains invalid account data.",
)
})?,
)) ))
}) })
{ {

View file

@ -20,8 +20,7 @@ impl service::appservice::Data for KeyValueDatabase {
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> { fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations self.id_appserviceregistrations.remove(service_name.as_bytes())?;
.remove(service_name.as_bytes())?;
Ok(()) Ok(())
} }
@ -30,20 +29,25 @@ impl service::appservice::Data for KeyValueDatabase {
.get(id.as_bytes())? .get(id.as_bytes())?
.map(|bytes| { .map(|bytes| {
serde_yaml::from_slice(&bytes).map_err(|_| { serde_yaml::from_slice(&bytes).map_err(|_| {
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") Error::bad_database(
"Invalid registration bytes in \
id_appserviceregistrations.",
)
}) })
}) })
.transpose() .transpose()
} }
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { fn iter_ids<'a>(
Ok(Box::new(self.id_appserviceregistrations.iter().map( &'a self,
|(id, _)| { ) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
utils::string_from_bytes(&id).map_err(|_| { Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
Error::bad_database("Invalid id bytes in id_appserviceregistrations.") utils::string_from_bytes(&id).map_err(|_| {
}) Error::bad_database(
}, "Invalid id bytes in id_appserviceregistrations.",
))) )
})
})))
} }
fn all(&self) -> Result<Vec<(String, Registration)>> { fn all(&self) -> Result<Vec<(String, Registration)>> {

View file

@ -6,10 +6,13 @@ use lru_cache::LruCache;
use ruma::{ use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey}, api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair, signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName,
UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
pub(crate) const COUNTER: &[u8] = b"c"; pub(crate) const COUNTER: &[u8] = b"c";
@ -27,14 +30,18 @@ impl service::globals::Data for KeyValueDatabase {
}) })
} }
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { async fn watch(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec(); let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone(); let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xff); userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone(); let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xff); userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new(); let mut futures = FuturesUnordered::new();
@ -46,10 +53,10 @@ impl service::globals::Data for KeyValueDatabase {
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push( futures.push(
self.userroomid_notificationcount self.userroomid_notificationcount.watch_prefix(&userid_prefix),
.watch_prefix(&userid_prefix),
); );
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); futures
.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in // Events for rooms we are in
for room_id in services() for room_id in services()
@ -70,17 +77,24 @@ impl service::globals::Data for KeyValueDatabase {
let roomid_bytes = room_id.as_bytes().to_vec(); let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone(); let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xff); roomid_prefix.push(0xFF);
// PDUs // PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs // EDUs
futures.push(Box::pin(async move { futures.push(Box::pin(async move {
let _result = services().rooms.edus.typing.wait_for_update(&room_id).await; let _result = services()
.rooms
.edus
.typing
.wait_for_update(&room_id)
.await;
})); }));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); futures.push(
self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix),
);
// Key changes // Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
@ -90,12 +104,11 @@ impl service::globals::Data for KeyValueDatabase {
roomuser_prefix.extend_from_slice(&userid_prefix); roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push( futures.push(
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix),
.watch_prefix(&roomuser_prefix),
); );
} }
let mut globaluserdata_prefix = vec![0xff]; let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix); globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push( futures.push(
@ -107,7 +120,8 @@ impl service::globals::Data for KeyValueDatabase {
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys // One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures
.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch())); futures.push(Box::pin(services().globals.rotate.watch()));
@ -126,10 +140,14 @@ impl service::globals::Data for KeyValueDatabase {
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len(); let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len(); let statekeyshort_cache =
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); self.statekeyshort_cache.lock().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); let our_real_users_cache =
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache =
self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache =
self.lasttimelinecount_cache.lock().unwrap().len();
let mut response = format!( let mut response = format!(
"\ "\
@ -194,27 +212,29 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|s| Ok(s.clone()), |s| Ok(s.clone()),
)?; )?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes( utils::string_from_bytes(
// 1. version // 1. version
parts parts.next().expect("splitn always returns at least one element"),
.next()
.expect("splitn always returns at least one element"),
) )
.map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| { .and_then(|version| {
// 2. key // 2. key
parts parts
.next() .next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) .ok_or_else(|| {
Error::bad_database("Invalid keypair format in database.")
})
.map(|key| (version, key)) .map(|key| (version, key))
}) })
.and_then(|(version, key)| { .and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version) Ed25519KeyPair::from_der(key, version).map_err(|_| {
.map_err(|_| Error::bad_database("Private or public keys are invalid.")) Error::bad_database("Private or public keys are invalid.")
})
}) })
} }
fn remove_keypair(&self) -> Result<()> { fn remove_keypair(&self) -> Result<()> {
self.global.remove(b"keypair") self.global.remove(b"keypair")
} }
@ -231,7 +251,10 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
.and_then(|keys| serde_json::from_slice(&keys).ok()) .and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| { .unwrap_or_else(|| {
// Just insert "now", it doesn't matter // Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) ServerSigningKeys::new(
origin.to_owned(),
MilliSecondsSinceUnixEpoch::now(),
)
}); });
let ServerSigningKeys { let ServerSigningKeys {
@ -245,7 +268,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
self.server_signingkeys.insert( self.server_signingkeys.insert(
origin.as_bytes(), origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), &serde_json::to_vec(&keys)
.expect("serversigningkeys can be serialized"),
)?; )?;
let mut tree = keys.verify_keys; let mut tree = keys.verify_keys;
@ -258,7 +282,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
Ok(tree) Ok(tree)
} }
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for( fn signing_keys_for(
&self, &self,
origin: &ServerName, origin: &ServerName,
@ -283,8 +308,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
fn database_version(&self) -> Result<u64> { fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| { self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version) utils::u64_from_bytes(&version).map_err(|_| {
.map_err(|_| Error::bad_database("Database version id is invalid.")) Error::bad_database("Database version id is invalid.")
})
}) })
} }

View file

@ -9,7 +9,9 @@ use ruma::{
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::key_backups::Data for KeyValueDatabase { impl service::key_backups::Data for KeyValueDatabase {
fn create_backup( fn create_backup(
@ -20,12 +22,13 @@ impl service::key_backups::Data for KeyValueDatabase {
let version = services().globals.next_count()?.to_string(); let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert( self.backupid_algorithm.insert(
&key, &key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata)
.expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())?;
@ -34,13 +37,13 @@ impl service::key_backups::Data for KeyValueDatabase {
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?; self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?; self.backupid_etag.remove(&key)?;
key.push(0xff); key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?; self.backupkeyid_backup.remove(&outdated_key)?;
@ -56,7 +59,7 @@ impl service::key_backups::Data for KeyValueDatabase {
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> { ) -> Result<String> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() { if self.backupid_algorithm.get(&key)?.is_none() {
@ -73,9 +76,12 @@ impl service::key_backups::Data for KeyValueDatabase {
Ok(version.to_owned()) Ok(version.to_owned())
} }
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { fn get_latest_backup_version(
&self,
user_id: &UserId,
) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -85,11 +91,13 @@ impl service::key_backups::Data for KeyValueDatabase {
.next() .next()
.map(|(key, _)| { .map(|(key, _)| {
utils::string_from_bytes( utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) .map_err(|_| {
Error::bad_database("backupid_algorithm key is invalid.")
})
}) })
.transpose() .transpose()
} }
@ -99,7 +107,7 @@ impl service::key_backups::Data for KeyValueDatabase {
user_id: &UserId, user_id: &UserId,
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { ) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -109,33 +117,42 @@ impl service::key_backups::Data for KeyValueDatabase {
.next() .next()
.map(|(key, value)| { .map(|(key, value)| {
let version = utils::string_from_bytes( let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; .map_err(|_| {
Error::bad_database("backupid_algorithm key is invalid.")
})?;
Ok(( Ok((
version, version,
serde_json::from_slice(&value).map_err(|_| { serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("Algorithm in backupid_algorithm is invalid.") Error::bad_database(
"Algorithm in backupid_algorithm is invalid.",
)
})?, })?,
)) ))
}) })
.transpose() .transpose()
} }
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> { fn get_backup(
&self,
user_id: &UserId,
version: &str,
) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
.get(&key)? serde_json::from_slice(&bytes).map_err(|_| {
.map_or(Ok(None), |bytes| { Error::bad_database(
serde_json::from_slice(&bytes) "Algorithm in backupid_algorithm is invalid.",
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) )
}) })
})
} }
fn add_key( fn add_key(
@ -147,7 +164,7 @@ impl service::key_backups::Data for KeyValueDatabase {
key_data: &Raw<KeyBackupData>, key_data: &Raw<KeyBackupData>,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() { if self.backupid_algorithm.get(&key)?.is_none() {
@ -160,9 +177,9 @@ impl service::key_backups::Data for KeyValueDatabase {
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(session_id.as_bytes()); key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup self.backupkeyid_backup
@ -173,7 +190,7 @@ impl service::key_backups::Data for KeyValueDatabase {
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
@ -181,7 +198,7 @@ impl service::key_backups::Data for KeyValueDatabase {
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes( Ok(utils::u64_from_bytes(
@ -200,40 +217,56 @@ impl service::key_backups::Data for KeyValueDatabase {
version: &str, version: &str,
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { ) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new(); let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self for result in
.backupkeyid_backup self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
.scan_prefix(prefix) let mut parts = key.rsplit(|&b| b == 0xFF);
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff);
let session_id = let session_id = utils::string_from_bytes(
utils::string_from_bytes(parts.next().ok_or_else(|| { parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.") Error::bad_database(
})?) "backupkeyid_backup key is invalid.",
.map_err(|_| { )
Error::bad_database("backupkeyid_backup session_id is invalid.") })?,
})?;
let room_id = RoomId::parse(
utils::string_from_bytes(parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.")
})?)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("backupkeyid_backup room_id is invalid room id.") Error::bad_database(
"backupkeyid_backup session_id is invalid.",
)
})?; })?;
let key_data = serde_json::from_slice(&value).map_err(|_| { let room_id = RoomId::parse(
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") utils::string_from_bytes(parts.next().ok_or_else(
|| {
Error::bad_database(
"backupkeyid_backup key is invalid.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"backupkeyid_backup room_id is invalid.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"backupkeyid_backup room_id is invalid room id.",
)
})?; })?;
let key_data =
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
})?;
Ok::<_, Error>((room_id, session_id, key_data)) Ok::<_, Error>((room_id, session_id, key_data))
}) })
{ {
@ -257,30 +290,38 @@ impl service::key_backups::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> { ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes()); prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
Ok(self Ok(self
.backupkeyid_backup .backupkeyid_backup
.scan_prefix(prefix) .scan_prefix(prefix)
.map(|(key, value)| { .map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = let session_id = utils::string_from_bytes(
utils::string_from_bytes(parts.next().ok_or_else(|| { parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.") Error::bad_database(
})?) "backupkeyid_backup key is invalid.",
.map_err(|_| { )
Error::bad_database("backupkeyid_backup session_id is invalid.") })?,
})?; )
.map_err(|_| {
let key_data = serde_json::from_slice(&value).map_err(|_| { Error::bad_database(
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") "backupkeyid_backup session_id is invalid.",
)
})?; })?;
let key_data =
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
})?;
Ok::<_, Error>((session_id, key_data)) Ok::<_, Error>((session_id, key_data))
}) })
.filter_map(Result::ok) .filter_map(Result::ok)
@ -295,18 +336,20 @@ impl service::key_backups::Data for KeyValueDatabase {
session_id: &str, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> { ) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(session_id.as_bytes()); key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup self.backupkeyid_backup
.get(&key)? .get(&key)?
.map(|value| { .map(|value| {
serde_json::from_slice(&value).map_err(|_| { serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
}) })
}) })
.transpose() .transpose()
@ -314,9 +357,9 @@ impl service::key_backups::Data for KeyValueDatabase {
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
key.push(0xff); key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?; self.backupkeyid_backup.remove(&outdated_key)?;
@ -325,13 +368,18 @@ impl service::key_backups::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { fn delete_room_keys(
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?; self.backupkeyid_backup.remove(&outdated_key)?;
@ -348,11 +396,11 @@ impl service::key_backups::Data for KeyValueDatabase {
session_id: &str, session_id: &str,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(session_id.as_bytes()); key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {

View file

@ -12,22 +12,19 @@ impl service::media::Data for KeyValueDatabase {
content_type: Option<&str>, content_type: Option<&str>,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec(); let mut key = mxc.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice( key.extend_from_slice(
content_disposition content_disposition
.as_ref() .as_ref()
.map(|f| f.as_bytes()) .map(|f| f.as_bytes())
.unwrap_or_default(), .unwrap_or_default(),
); );
key.push(0xff); key.push(0xFF);
key.extend_from_slice( key.extend_from_slice(
content_type content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default(),
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
); );
self.mediaid_file.insert(&key, &[])?; self.mediaid_file.insert(&key, &[])?;
@ -42,24 +39,25 @@ impl service::media::Data for KeyValueDatabase {
height: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> { ) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec(); let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes()); prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes()); prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xff); prefix.push(0xFF);
let (key, _) = self let (key, _) =
.mediaid_file self.mediaid_file.scan_prefix(prefix).next().ok_or(
.scan_prefix(prefix) Error::BadRequest(ErrorKind::NotFound, "Media not found"),
.next() )?;
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts let content_type = parts
.next() .next()
.map(|bytes| { .map(|bytes| {
utils::string_from_bytes(bytes).map_err(|_| { utils::string_from_bytes(bytes).map_err(|_| {
Error::bad_database("Content type in mediaid_file is invalid unicode.") Error::bad_database(
"Content type in mediaid_file is invalid unicode.",
)
}) })
}) })
.transpose()?; .transpose()?;
@ -71,11 +69,14 @@ impl service::media::Data for KeyValueDatabase {
let content_disposition = if content_disposition_bytes.is_empty() { let content_disposition = if content_disposition_bytes.is_empty() {
None None
} else { } else {
Some( Some(utils::string_from_bytes(content_disposition_bytes).map_err(
utils::string_from_bytes(content_disposition_bytes).map_err(|_| { |_| {
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.") Error::bad_database(
})?, "Content Disposition in mediaid_file is invalid \
) unicode.",
)
},
)?)
}; };
Ok((content_disposition, content_type, key)) Ok((content_disposition, content_type, key))
} }

View file

@ -6,30 +6,39 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::pusher::Data for KeyValueDatabase { impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { fn set_pusher(
&self,
sender: &UserId,
pusher: set_pusher::v3::PusherAction,
) -> Result<()> {
match &pusher { match &pusher {
set_pusher::v3::PusherAction::Post(data) => { set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec(); let mut key = sender.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher.insert( self.senderkey_pusher.insert(
&key, &key,
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), &serde_json::to_vec(&pusher)
.expect("Pusher is valid JSON value"),
)?; )?;
Ok(()) Ok(())
} }
set_pusher::v3::PusherAction::Delete(ids) => { set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec(); let mut key = sender.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes()); key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into) self.senderkey_pusher.remove(&key).map_err(Into::into)
} }
} }
} }
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec(); let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xff); senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes()); senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher self.senderkey_pusher
@ -43,7 +52,7 @@ impl service::pusher::Data for KeyValueDatabase {
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec(); let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
self.senderkey_pusher self.senderkey_pusher
.scan_prefix(prefix) .scan_prefix(prefix)
@ -59,16 +68,20 @@ impl service::pusher::Data for KeyValueDatabase {
sender: &UserId, sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>> + 'a> { ) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec(); let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xff); let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next(); let _senderkey = parts.next();
let push_key = parts let push_key = parts.next().ok_or_else(|| {
.next() Error::bad_database("Invalid senderkey_pusher in db")
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; })?;
let push_key_string = utils::string_from_bytes(push_key) let push_key_string =
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; utils::string_from_bytes(push_key).map_err(|_| {
Error::bad_database(
"Invalid pusher bytes in senderkey_pusher",
)
})?;
Ok(push_key_string) Ok(push_key_string)
})) }))

View file

@ -1,6 +1,11 @@
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use ruma::{
api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId,
RoomId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::alias::Data for KeyValueDatabase { impl service::rooms::alias::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -8,17 +13,20 @@ impl service::rooms::alias::Data for KeyValueDatabase {
self.alias_roomid self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?; .insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xff); aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); aliasid
.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { if let Some(room_id) =
self.alias_roomid.get(alias.alias().as_bytes())?
{
let mut prefix = room_id.clone(); let mut prefix = room_id.clone();
prefix.push(0xff); prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) { for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?; self.aliasid_alias.remove(&key)?;
@ -34,14 +42,23 @@ impl service::rooms::alias::Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { fn resolve_local_alias(
&self,
alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>> {
self.alias_roomid self.alias_roomid
.get(alias.alias().as_bytes())? .get(alias.alias().as_bytes())?
.map(|bytes| { .map(|bytes| {
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { RoomId::parse(utils::string_from_bytes(&bytes).map_err(
Error::bad_database("Room ID in alias_roomid is invalid unicode.") |_| {
})?) Error::bad_database(
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) "Room ID in alias_roomid is invalid unicode.",
)
},
)?)
.map_err(|_| {
Error::bad_database("Room ID in alias_roomid is invalid.")
})
}) })
.transpose() .transpose()
} }
@ -52,13 +69,17 @@ impl service::rooms::alias::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? .map_err(|_| {
Error::bad_database("Invalid alias bytes in aliasid_alias.")
})?
.try_into() .try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) .map_err(|_| {
Error::bad_database("Invalid alias in aliasid_alias.")
})
})) }))
} }
} }

View file

@ -3,9 +3,13 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{database::KeyValueDatabase, service, utils, Result}; use crate::{database::KeyValueDatabase, service, utils, Result};
impl service::rooms::auth_chain::Data for KeyValueDatabase { impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { fn get_cached_eventid_authchain(
&self,
key: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>> {
// Check RAM cache // Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key)
{
return Ok(Some(Arc::clone(result))); return Ok(Some(Arc::clone(result)));
} }
@ -18,7 +22,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
.map(|chain| { .map(|chain| {
chain chain
.chunks_exact(size_of::<u64>()) .chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) .map(|chunk| {
utils::u64_from_bytes(chunk)
.expect("byte length is correct")
})
.collect() .collect()
}); });
@ -38,7 +45,11 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
Ok(None) Ok(None)
} }
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { fn cache_auth_chain(
&self,
key: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()> {
// Only persist single events in db // Only persist single events in db
if key.len() == 1 { if key.len() == 1 {
self.shorteventid_authchain.insert( self.shorteventid_authchain.insert(
@ -51,10 +62,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
} }
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(()) Ok(())
} }

View file

@ -19,14 +19,18 @@ impl service::rooms::directory::Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn public_rooms<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| { Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse( RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database(
Error::bad_database("Room ID in publicroomids is invalid unicode.") "Room ID in publicroomids is invalid unicode.",
})?, )
) })?)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) .map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid.")
})
})) }))
} }
} }

View file

@ -1,10 +1,13 @@
use std::mem; use std::mem;
use ruma::{ use ruma::{
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject,
OwnedUserId, RoomId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update( fn readreceipt_update(
@ -14,7 +17,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
event: ReceiptEvent, event: ReceiptEvent,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -25,7 +28,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.iter_from(&last_possible_key, true) .iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix)) .take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| { .find(|(key, _)| {
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element") .expect("rsplit always returns an element")
== user_id.as_bytes() == user_id.as_bytes()
@ -36,13 +39,15 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
} }
let mut room_latest_id = prefix; let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); room_latest_id
room_latest_id.push(0xff); .extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes()); room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert( self.readreceiptid_readreceipt.insert(
&room_latest_id, &room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"), &serde_json::to_vec(&event)
.expect("EduEvent::to_string always works"),
)?; )?;
Ok(()) Ok(())
@ -64,7 +69,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
> + 'a, > + 'a,
> { > {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let prefix2 = prefix.clone(); let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone(); let mut first_possible_edu = prefix.clone();
@ -79,21 +84,35 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
let count = utils::u64_from_bytes( let count = utils::u64_from_bytes(
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()], &k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
) )
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; .map_err(|_| {
Error::bad_database(
"Invalid readreceiptid count in db.",
)
})?;
let user_id = UserId::parse( let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..]) utils::string_from_bytes(
.map_err(|_| { &k[prefix.len() + mem::size_of::<u64>() + 1..],
Error::bad_database("Invalid readreceiptid userid bytes in db.") )
})?, .map_err(|_| {
Error::bad_database(
"Invalid readreceiptid userid bytes in db.",
)
})?,
) )
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; .map_err(|_| {
Error::bad_database(
"Invalid readreceiptid userid in db.",
)
})?;
let mut json = let mut json =
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| { serde_json::from_slice::<CanonicalJsonObject>(&v)
Error::bad_database( .map_err(|_| {
"Read receipt in roomlatestid_roomlatest is invalid json.", Error::bad_database(
) "Read receipt in roomlatestid_roomlatest \
})?; is invalid json.",
)
})?;
json.remove("room_id"); json.remove("room_id");
Ok(( Ok((
@ -109,36 +128,46 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { fn private_read_set(
&self,
room_id: &RoomId,
user_id: &UserId,
count: u64,
) -> Result<()> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes()) .insert(&key, &services().globals.next_count()?.to_be_bytes())
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
.get(&key)? Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
.map_or(Ok(None), |v| { Error::bad_database("Invalid private read marker bytes")
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { })?))
Error::bad_database("Invalid private read marker bytes") })
})?))
})
} }
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
Ok(self Ok(self
@ -146,7 +175,9 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") Error::bad_database(
"Count in roomuserid_lastprivatereadupdate is invalid.",
)
}) })
}) })
.transpose()? .transpose()?

View file

@ -11,11 +11,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
ll_user: &UserId, ll_user: &UserId,
) -> Result<bool> { ) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes()); key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some()) Ok(self.lazyloadedids.get(&key)?.is_some())
} }
@ -28,11 +28,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> { ) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes()); prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
for ll_id in confirmed_user_ids { for ll_id in confirmed_user_ids {
let mut key = prefix.clone(); let mut key = prefix.clone();
@ -50,11 +50,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Result<()> { ) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes()); prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) { for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?; self.lazyloadedids.remove(&key)?;

View file

@ -1,6 +1,8 @@
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::metadata::Data for KeyValueDatabase { impl service::rooms::metadata::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -19,14 +21,18 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
.is_some()) .is_some())
} }
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn iter_ids<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse( RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database(
Error::bad_database("Room ID in publicroomids is invalid unicode.") "Room ID in publicroomids is invalid unicode.",
})?, )
) })?)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) .map_err(|_| {
Error::bad_database("Room ID in roomid_shortroomid is invalid.")
})
})) }))
} }

View file

@ -3,24 +3,35 @@ use ruma::{CanonicalJsonObject, EventId};
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
impl service::rooms::outlier::Data for KeyValueDatabase { impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_outlier_pdu_json(
self.eventid_outlierpdu &self,
.get(event_id.as_bytes())? event_id: &EventId,
.map_or(Ok(None), |pdu| { ) -> Result<Option<CanonicalJsonObject>> {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(
}) Ok(None),
|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
},
)
} }
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(
.get(event_id.as_bytes())? Ok(None),
.map_or(Ok(None), |pdu| { |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) serde_json::from_slice(&pdu)
}) .map_err(|_| Error::bad_database("Invalid PDU in db."))
},
)
} }
#[tracing::instrument(skip(self, pdu))] #[tracing::instrument(skip(self, pdu))]
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { fn add_pdu_outlier(
&self,
event_id: &EventId,
pdu: &CanonicalJsonObject,
) -> Result<()> {
self.eventid_outlierpdu.insert( self.eventid_outlierpdu.insert(
event_id.as_bytes(), event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),

View file

@ -22,7 +22,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
shortroomid: u64, shortroomid: u64,
target: u64, target: u64,
until: PduCount, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let prefix = target.to_be_bytes().to_vec(); let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone(); let mut current = prefix.clone();
@ -40,8 +41,12 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
.iter_from(&current, true) .iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| { .map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..]) let from = utils::u64_from_bytes(
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; &tofrom[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database("Invalid count in tofrom_relation.")
})?;
let mut pduid = shortroomid.to_be_bytes().to_vec(); let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes()); pduid.extend_from_slice(&from.to_be_bytes());
@ -50,7 +55,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
.rooms .rooms
.timeline .timeline
.get_pdu_from_id(&pduid)? .get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; .ok_or_else(|| {
Error::bad_database(
"Pdu in tofrom_relation is invalid.",
)
})?;
if pdu.sender != user_id { if pdu.sender != user_id {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
@ -59,7 +68,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
)) ))
} }
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { fn mark_as_referenced(
&self,
room_id: &RoomId,
event_ids: &[Arc<EventId>],
) -> Result<()> {
for prev in event_ids { for prev in event_ids {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes()); key.extend_from_slice(prev.as_bytes());
@ -69,7 +82,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes()); key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some()) Ok(self.referencedevents.get(&key)?.is_some())
@ -80,8 +97,6 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
} }
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
.get(event_id.as_bytes())
.map(|o| o.is_some())
} }
} }

View file

@ -4,7 +4,12 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result};
impl service::rooms::search::Data for KeyValueDatabase { impl service::rooms::search::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { fn index_pdu(
&self,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()> {
let mut batch = message_body let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
@ -13,7 +18,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
.map(|word| { .map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes()); key.extend_from_slice(word.as_bytes());
key.push(0xff); key.push(0xFF);
// TODO: currently we save the room id a second time here // TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id); key.extend_from_slice(pdu_id);
(key, Vec::new()) (key, Vec::new())
@ -28,7 +33,8 @@ impl service::rooms::search::Data for KeyValueDatabase {
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
search_string: &str, search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>> { ) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>
{
let prefix = services() let prefix = services()
.rooms .rooms
.short .short
@ -46,7 +52,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
let iterators = words.clone().into_iter().map(move |word| { let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone(); let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes()); prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xff); prefix2.push(0xFF);
let prefix3 = prefix2.clone(); let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone(); let mut last_possible_id = prefix2.clone();
@ -60,7 +66,9 @@ impl service::rooms::search::Data for KeyValueDatabase {
}); });
// We compare b with a because we reversed the iterator earlier // We compare b with a because we reversed the iterator earlier
let Some(common_elements) = utils::common_elements(iterators, |a, b| b.cmp(a)) else { let Some(common_elements) =
utils::common_elements(iterators, |a, b| b.cmp(a))
else {
return Ok(None); return Ok(None);
}; };

View file

@ -2,26 +2,32 @@ use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::short::Data for KeyValueDatabase { impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { if let Some(short) =
self.eventidshort_cache.lock().unwrap().get_mut(event_id)
{
return Ok(*short); return Ok(*short);
} }
let short = let short = if let Some(shorteventid) =
if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { self.eventid_shorteventid.get(event_id.as_bytes())?
utils::u64_from_bytes(&shorteventid) {
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))? utils::u64_from_bytes(&shorteventid).map_err(|_| {
} else { Error::bad_database("Invalid shorteventid in db.")
let shorteventid = services().globals.next_count()?; })?
self.eventid_shorteventid } else {
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; let shorteventid = services().globals.next_count()?;
self.shorteventid_eventid self.eventid_shorteventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
shorteventid self.shorteventid_eventid
}; .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
self.eventidshort_cache self.eventidshort_cache
.lock() .lock()
@ -46,15 +52,16 @@ impl service::rooms::short::Data for KeyValueDatabase {
} }
let mut db_key = event_type.to_string().as_bytes().to_vec(); let mut db_key = event_type.to_string().as_bytes().to_vec();
db_key.push(0xff); db_key.push(0xFF);
db_key.extend_from_slice(state_key.as_bytes()); db_key.extend_from_slice(state_key.as_bytes());
let short = self let short = self
.statekey_shortstatekey .statekey_shortstatekey
.get(&db_key)? .get(&db_key)?
.map(|shortstatekey| { .map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey) utils::u64_from_bytes(&shortstatekey).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) Error::bad_database("Invalid shortstatekey in db.")
})
}) })
.transpose()?; .transpose()?;
@ -83,12 +90,15 @@ impl service::rooms::short::Data for KeyValueDatabase {
} }
let mut db_key = event_type.to_string().as_bytes().to_vec(); let mut db_key = event_type.to_string().as_bytes().to_vec();
db_key.push(0xff); db_key.push(0xFF);
db_key.extend_from_slice(state_key.as_bytes()); db_key.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&db_key)? { let short = if let Some(shortstatekey) =
utils::u64_from_bytes(&shortstatekey) self.statekey_shortstatekey.get(&db_key)?
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| {
Error::bad_database("Invalid shortstatekey in db.")
})?
} else { } else {
let shortstatekey = services().globals.next_count()?; let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey self.statekey_shortstatekey
@ -106,12 +116,12 @@ impl service::rooms::short::Data for KeyValueDatabase {
Ok(short) Ok(short)
} }
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { fn get_eventid_from_short(
if let Some(id) = self &self,
.shorteventid_cache shorteventid: u64,
.lock() ) -> Result<Arc<EventId>> {
.unwrap() if let Some(id) =
.get_mut(&shorteventid) self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid)
{ {
return Ok(Arc::clone(id)); return Ok(Arc::clone(id));
} }
@ -119,12 +129,20 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self let bytes = self
.shorteventid_eventid .shorteventid_eventid
.get(&shorteventid.to_be_bytes())? .get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; .ok_or_else(|| {
Error::bad_database("Shorteventid does not exist")
})?;
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { let event_id = EventId::parse_arc(
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") utils::string_from_bytes(&bytes).map_err(|_| {
})?) Error::bad_database(
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; "EventID in shorteventid_eventid is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database("EventId in shorteventid_eventid is invalid.")
})?;
self.shorteventid_cache self.shorteventid_cache
.lock() .lock()
@ -134,12 +152,12 @@ impl service::rooms::short::Data for KeyValueDatabase {
Ok(event_id) Ok(event_id)
} }
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { fn get_statekey_from_short(
if let Some(id) = self &self,
.shortstatekey_cache shortstatekey: u64,
.lock() ) -> Result<(StateEventType, String)> {
.unwrap() if let Some(id) =
.get_mut(&shortstatekey) self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey)
{ {
return Ok(id.clone()); return Ok(id.clone());
} }
@ -147,23 +165,32 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self let bytes = self
.shortstatekey_statekey .shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())? .get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; .ok_or_else(|| {
Error::bad_database("Shortstatekey does not exist")
})?;
let mut parts = bytes.splitn(2, |&b| b == 0xff); let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry"); let eventtype_bytes =
let statekey_bytes = parts parts.next().expect("split always returns one entry");
.next() let statekey_bytes = parts.next().ok_or_else(|| {
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; Error::bad_database("Invalid statekey in shortstatekey_statekey.")
let event_type =
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
})?; })?;
let event_type = StateEventType::from(
utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database(
"Event type in shortstatekey_statekey is invalid unicode.",
)
})?,
);
let state_key =
utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database(
"Statekey in shortstatekey_statekey is invalid unicode.",
)
})?;
let result = (event_type, state_key); let result = (event_type, state_key);
self.shortstatekey_cache self.shortstatekey_cache
@ -175,12 +202,18 @@ impl service::rooms::short::Data for KeyValueDatabase {
} }
/// Returns `(shortstatehash, already_existed)` /// Returns `(shortstatehash, already_existed)`
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)> {
Ok( Ok(
if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { if let Some(shortstatehash) =
self.statehash_shortstatehash.get(state_hash)?
{
( (
utils::u64_from_bytes(&shortstatehash) utils::u64_from_bytes(&shortstatehash).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, Error::bad_database("Invalid shortstatehash in db.")
})?,
true, true,
) )
} else { } else {
@ -196,17 +229,21 @@ impl service::rooms::short::Data for KeyValueDatabase {
self.roomid_shortroomid self.roomid_shortroomid
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid shortroomid in db.")) Error::bad_database("Invalid shortroomid in db.")
})
}) })
.transpose() .transpose()
} }
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok( Ok(
if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { if let Some(short) =
utils::u64_from_bytes(&short) self.roomid_shortroomid.get(room_id.as_bytes())?
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))? {
utils::u64_from_bytes(&short).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})?
} else { } else {
let short = services().globals.next_count()?; let short = services().globals.next_count()?;
self.roomid_shortroomid self.roomid_shortroomid

View file

@ -1,20 +1,22 @@
use ruma::{EventId, OwnedEventId, RoomId}; use std::{collections::HashSet, sync::Arc};
use std::collections::HashSet;
use std::sync::Arc; use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard; use tokio::sync::MutexGuard;
use crate::{database::KeyValueDatabase, service, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase { impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(
.get(room_id.as_bytes())? Ok(None),
.map_or(Ok(None), |bytes| { |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") Error::bad_database(
"Invalid shortstatehash in roomid_shortstatehash",
)
})?)) })?))
}) },
)
} }
fn set_room_state( fn set_room_state(
@ -29,23 +31,40 @@ impl service::rooms::state::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { fn set_event_state(
self.shorteventid_shortstatehash &self,
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; shorteventid: u64,
shortstatehash: u64,
) -> Result<()> {
self.shorteventid_shortstatehash.insert(
&shorteventid.to_be_bytes(),
&shortstatehash.to_be_bytes(),
)?;
Ok(()) Ok(())
} }
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { fn get_forward_extremities(
&self,
room_id: &RoomId,
) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
self.roomid_pduleaves self.roomid_pduleaves
.scan_prefix(prefix) .scan_prefix(prefix)
.map(|(_, bytes)| { .map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") |_| {
})?) Error::bad_database(
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) "EventID in roomid_pduleaves is invalid unicode.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"EventId in roomid_pduleaves is invalid.",
)
})
}) })
.collect() .collect()
} }
@ -58,7 +77,7 @@ impl service::rooms::state::Data for KeyValueDatabase {
_mutex_lock: &MutexGuard<'_, ()>, _mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?; self.roomid_pduleaves.remove(&key)?;

View file

@ -1,12 +1,19 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use async_trait::async_trait; use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
#[async_trait] #[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase { impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { async fn state_full_ids(
&self,
shortstatehash: u64,
) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services() let full_state = services()
.rooms .rooms
.state_compressor .state_compressor
@ -56,7 +63,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
pdu.kind.to_string().into(), pdu.kind.to_string().into(),
pdu.state_key pdu.state_key
.as_ref() .as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))? .ok_or_else(|| {
Error::bad_database(
"State event has no state key.",
)
})?
.clone(), .clone(),
), ),
pdu, pdu,
@ -72,17 +83,16 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
Ok(result) Ok(result)
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id( fn state_get_id(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services() let Some(shortstatekey) =
.rooms services().rooms.short.get_shortstatekey(event_type, state_key)?
.short
.get_shortstatekey(event_type, state_key)?
else { else {
return Ok(None); return Ok(None);
}; };
@ -106,7 +116,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
})) }))
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get( fn state_get(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
@ -121,20 +132,22 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(
.get(event_id.as_bytes())? Ok(None),
.map_or(Ok(None), |shorteventid| { |shorteventid| {
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.get(&shorteventid)? .get(&shorteventid)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database( Error::bad_database(
"Invalid shortstatehash bytes in shorteventid_shortstatehash", "Invalid shortstatehash bytes in \
shorteventid_shortstatehash",
) )
}) })
}) })
.transpose() .transpose()
}) },
)
} }
/// Returns the full room state. /// Returns the full room state.
@ -151,7 +164,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
} }
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id( fn room_state_get_id(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
@ -167,7 +181,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
} }
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get( fn room_state_get(
&self, &self,
room_id: &RoomId, room_id: &RoomId,

View file

@ -13,20 +13,24 @@ use crate::{
}; };
impl service::rooms::state_cache::Data for KeyValueDatabase { impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_once_joined(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[]) self.roomuseroncejoinedids.insert(&userroom_id, &[])
} }
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?; self.userroomid_joined.insert(&userroom_id, &[])?;
@ -46,11 +50,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) -> Result<()> { ) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert( self.userroomid_invitestate.insert(
@ -72,11 +76,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
// TODO // TODO
@ -112,7 +116,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
joinedcount += 1; joinedcount += 1;
} }
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { for _invited in
self.room_members_invited(room_id).filter_map(Result::ok)
{
invitedcount += 1; invitedcount += 1;
} }
@ -127,15 +133,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.unwrap() .unwrap()
.insert(room_id.to_owned(), Arc::new(real_users)); .insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { for old_joined_server in
self.room_servers(room_id).filter_map(Result::ok)
{
if !joined_servers.remove(&old_joined_server) { if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore // Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec(); let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff); roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes()); roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec(); let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xff); serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes()); serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?; self.roomserverids.remove(&roomserver_id)?;
@ -146,49 +154,45 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
// Now only new servers are in joined_servers anymore // Now only new servers are in joined_servers anymore
for server in joined_servers { for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec(); let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff); roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes()); roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec(); let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xff); serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes()); serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?; self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?; self.serverroomids.insert(&serverroom_id, &[])?;
} }
self.appservice_in_room_cache self.appservice_in_room_cache.write().unwrap().remove(room_id);
.write()
.unwrap()
.remove(room_id);
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, room_id))] #[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> { fn get_our_real_users(
let maybe = self &self,
.our_real_users_cache room_id: &RoomId,
.read() ) -> Result<Arc<HashSet<OwnedUserId>>> {
.unwrap() let maybe =
.get(room_id) self.our_real_users_cache.read().unwrap().get(room_id).cloned();
.cloned();
if let Some(users) = maybe { if let Some(users) = maybe {
Ok(users) Ok(users)
} else { } else {
self.update_joined_count(room_id)?; self.update_joined_count(room_id)?;
Ok(Arc::clone( Ok(Arc::clone(
self.our_real_users_cache self.our_real_users_cache.read().unwrap().get(room_id).unwrap(),
.read()
.unwrap()
.get(room_id)
.unwrap(),
)) ))
} }
} }
#[tracing::instrument(skip(self, room_id, appservice))] #[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { fn appservice_in_room(
&self,
room_id: &RoomId,
appservice: &RegistrationInfo,
) -> Result<bool> {
let maybe = self let maybe = self
.appservice_in_room_cache .appservice_in_room_cache
.read() .read()
@ -206,11 +210,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
) )
.ok(); .ok();
let in_room = bridge_user_id let in_room = bridge_user_id.map_or(false, |id| {
.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) self.is_joined(&id, room_id).unwrap_or(false)
|| self.room_members(room_id).any(|userid| { }) || self.room_members(room_id).any(|userid| {
userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())) userid.map_or(false, |userid| {
}); appservice.users.is_match(userid.as_str())
})
});
self.appservice_in_room_cache self.appservice_in_room_cache
.write() .write()
@ -227,11 +233,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
@ -247,51 +253,66 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse( ServerName::parse(
utils::string_from_bytes( utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("Server name in roomserverids is invalid unicode.") Error::bad_database(
"Server name in roomserverids is invalid unicode.",
)
})?, })?,
) )
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) .map_err(|_| {
Error::bad_database("Server name in roomserverids is invalid.")
})
})) }))
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { fn server_in_room(
&self,
server: &ServerName,
room_id: &RoomId,
) -> Result<bool> {
let mut key = server.as_bytes().to_vec(); let mut key = server.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some()) self.serverroomids.get(&key).map(|o| o.is_some())
} }
/// Returns an iterator of all rooms a server participates in (as far as we know). /// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn server_rooms<'a>( fn server_rooms<'a>(
&'a self, &'a self,
server: &ServerName, server: &ServerName,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec(); let mut prefix = server.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes( utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, .map_err(|_| {
Error::bad_database(
"RoomId in serverroomids is invalid unicode.",
)
})?,
) )
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) .map_err(|_| {
Error::bad_database("RoomId in serverroomids is invalid.")
})
})) }))
} }
@ -302,20 +323,24 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse( UserId::parse(
utils::string_from_bytes( utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid unicode.") Error::bad_database(
"User ID in roomuserid_joined is invalid unicode.",
)
})?, })?,
) )
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) .map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid.")
})
})) }))
} }
@ -324,8 +349,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomid_joinedcount self.roomid_joinedcount
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|b| { .map(|b| {
utils::u64_from_bytes(&b) utils::u64_from_bytes(&b).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid joinedcount in db.")) Error::bad_database("Invalid joinedcount in db.")
})
}) })
.transpose() .transpose()
} }
@ -335,8 +361,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomid_invitedcount self.roomid_invitedcount
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|b| { .map(|b| {
utils::u64_from_bytes(&b) utils::u64_from_bytes(&b).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid joinedcount in db.")) Error::bad_database("Invalid joinedcount in db.")
})
}) })
.transpose() .transpose()
} }
@ -348,27 +375,30 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new( Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(
self.roomuseroncejoinedids |(key, _)| {
.scan_prefix(prefix) UserId::parse(
.map(|(key, _)| { utils::string_from_bytes(
UserId::parse( key.rsplit(|&b| b == 0xFF)
utils::string_from_bytes( .next()
key.rsplit(|&b| b == 0xff) .expect("rsplit always returns an element"),
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database(
"User ID in room_useroncejoined is invalid unicode.",
)
})?,
) )
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) .map_err(|_| {
}), Error::bad_database(
) "User ID in room_useroncejoined is invalid \
unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"User ID in room_useroncejoined is invalid.",
)
})
},
))
} }
/// Returns an iterator over all invited members of a room. /// Returns an iterator over all invited members of a room.
@ -378,53 +408,64 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new( Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(
self.roomuserid_invitecount |(key, _)| {
.scan_prefix(prefix) UserId::parse(
.map(|(key, _)| { utils::string_from_bytes(
UserId::parse( key.rsplit(|&b| b == 0xFF)
utils::string_from_bytes( .next()
key.rsplit(|&b| b == 0xff) .expect("rsplit always returns an element"),
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
})?,
) )
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) .map_err(|_| {
}), Error::bad_database(
) "User ID in roomuserid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"User ID in roomuserid_invited is invalid.",
)
})
},
))
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn get_invite_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| {
.get(&key)? Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
.map_or(Ok(None), |bytes| { Error::bad_database("Invalid invitecount in db.")
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { })?))
Error::bad_database("Invalid invitecount in db.") })
})?))
})
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn get_left_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount self.roomuserid_leftcount
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid leftcount in db.")) Error::bad_database("Invalid leftcount in db.")
})
}) })
.transpose() .transpose()
} }
@ -441,15 +482,22 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.map(|(key, _)| { .map(|(key, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes( utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF)
.next() .next()
.expect("rsplit always returns an element"), .expect("rsplit always returns an element"),
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("Room ID in userroomid_joined is invalid unicode.") Error::bad_database(
"Room ID in userroomid_joined is invalid \
unicode.",
)
})?, })?,
) )
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) .map_err(|_| {
Error::bad_database(
"Room ID in userroomid_joined is invalid.",
)
})
}), }),
) )
} }
@ -460,35 +508,43 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn rooms_invited<'a>( fn rooms_invited<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a> { ) -> Box<
dyn Iterator<
Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>,
> + 'a,
> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new( Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(
self.userroomid_invitestate |(key, state)| {
.scan_prefix(prefix) let room_id = RoomId::parse(
.map(|(key, state)| { utils::string_from_bytes(
let room_id = RoomId::parse( key.rsplit(|&b| b == 0xFF)
utils::string_from_bytes( .next()
key.rsplit(|&b| b == 0xff) .expect("rsplit always returns an element"),
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid.") Error::bad_database(
})?; "Room ID in userroomid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"Room ID in userroomid_invited is invalid.",
)
})?;
let state = serde_json::from_slice(&state).map_err(|_| { let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database("Invalid state in userroomid_invitestate.") Error::bad_database(
})?; "Invalid state in userroomid_invitestate.",
)
})?;
Ok((room_id, state)) Ok((room_id, state))
}), },
) ))
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -498,14 +554,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate self.userroomid_invitestate
.get(&key)? .get(&key)?
.map(|state| { .map(|state| {
let state = serde_json::from_slice(&state) let state = serde_json::from_slice(&state).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; Error::bad_database(
"Invalid state in userroomid_invitestate.",
)
})?;
Ok(state) Ok(state)
}) })
@ -519,14 +578,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate self.userroomid_leftstate
.get(&key)? .get(&key)?
.map(|state| { .map(|state| {
let state = serde_json::from_slice(&state) let state = serde_json::from_slice(&state).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; Error::bad_database(
"Invalid state in userroomid_leftstate.",
)
})?;
Ok(state) Ok(state)
}) })
@ -539,41 +601,48 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn rooms_left<'a>( fn rooms_left<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a> { ) -> Box<
dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>>
+ 'a,
> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
Box::new( Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(
self.userroomid_leftstate |(key, state)| {
.scan_prefix(prefix) let room_id = RoomId::parse(
.map(|(key, state)| { utils::string_from_bytes(
let room_id = RoomId::parse( key.rsplit(|&b| b == 0xFF)
utils::string_from_bytes( .next()
key.rsplit(|&b| b == 0xff) .expect("rsplit always returns an element"),
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
) )
.map_err(|_| { .map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid.") Error::bad_database(
})?; "Room ID in userroomid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"Room ID in userroomid_invited is invalid.",
)
})?;
let state = serde_json::from_slice(&state).map_err(|_| { let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database("Invalid state in userroomid_leftstate.") Error::bad_database(
})?; "Invalid state in userroomid_leftstate.",
)
})?;
Ok((room_id, state)) Ok((room_id, state))
}), },
) ))
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
@ -582,7 +651,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
@ -591,7 +660,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
@ -600,7 +669,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())

View file

@ -12,8 +12,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
.shortstatehash_statediff .shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())? .get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?; .ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); .expect("bytes have right length");
let parent = (parent != 0).then_some(parent); let parent = (parent != 0).then_some(parent);
let mut add_mode = true; let mut add_mode = true;
@ -30,7 +30,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
if add_mode { if add_mode {
added.insert(v.try_into().expect("we checked the size above")); added.insert(v.try_into().expect("we checked the size above"));
} else { } else {
removed.insert(v.try_into().expect("we checked the size above")); removed
.insert(v.try_into().expect("we checked the size above"));
} }
i += 2 * size_of::<u64>(); i += 2 * size_of::<u64>();
} }
@ -42,7 +43,11 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
}) })
} }
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { fn save_statediff(
&self,
shortstatehash: u64,
diff: StateDiff,
) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() { for new in diff.added.iter() {
value.extend_from_slice(&new[..]); value.extend_from_slice(&new[..]);

View file

@ -1,8 +1,14 @@
use std::mem; use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use ruma::{
api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId,
UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
impl service::rooms::threads::Data for KeyValueDatabase { impl service::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>( fn threads_until<'a>(
@ -28,14 +34,22 @@ impl service::rooms::threads::Data for KeyValueDatabase {
.iter_from(&current, true) .iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| { .map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..]) let count = utils::u64_from_bytes(
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; &pduid[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database(
"Invalid pduid in threadid_userids.",
)
})?;
let mut pdu = services() let mut pdu = services()
.rooms .rooms
.timeline .timeline
.get_pdu_from_id(&pduid)? .get_pdu_from_id(&pduid)?
.ok_or_else(|| { .ok_or_else(|| {
Error::bad_database("Invalid pduid reference in threadid_userids") Error::bad_database(
"Invalid pduid reference in threadid_userids",
)
})?; })?;
if pdu.sender != user_id { if pdu.sender != user_id {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
@ -45,28 +59,43 @@ impl service::rooms::threads::Data for KeyValueDatabase {
)) ))
} }
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { fn update_participants(
&self,
root_id: &[u8],
participants: &[OwnedUserId],
) -> Result<()> {
let users = participants let users = participants
.iter() .iter()
.map(|user| user.as_bytes()) .map(|user| user.as_bytes())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(&[0xff][..]); .join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?; self.threadid_userids.insert(root_id, &users)?;
Ok(()) Ok(())
} }
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> { fn get_participants(
&self,
root_id: &[u8],
) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? { if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some( Ok(Some(
users users
.split(|b| *b == 0xff) .split(|b| *b == 0xFF)
.map(|bytes| { .map(|bytes| {
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| { UserId::parse(utils::string_from_bytes(bytes).map_err(
Error::bad_database("Invalid UserId bytes in threadid_userids.") |_| {
})?) Error::bad_database(
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) "Invalid UserId bytes in threadid_userids.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"Invalid UserId in threadid_userids.",
)
})
}) })
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(), .collect(),

View file

@ -1,16 +1,23 @@
use std::{collections::hash_map, mem::size_of, sync::Arc}; use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{ use ruma::{
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId,
RoomId, UserId,
}; };
use service::rooms::timeline::PduCount;
use tracing::error; use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
use service::rooms::timeline::PduCount; Result,
};
impl service::rooms::timeline::Data for KeyValueDatabase { impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { fn last_timeline_count(
&self,
sender_user: &UserId,
room_id: &RoomId,
) -> Result<PduCount> {
match self match self
.lasttimelinecount_cache .lasttimelinecount_cache
.lock() .lock()
@ -45,14 +52,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else( self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| { || {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pdu| { .map(|pdu| {
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid PDU in db.")) Error::bad_database("Invalid PDU in db.")
})
}) })
.transpose() .transpose()
}, },
@ -61,17 +72,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
.get(&pduid)? Error::bad_database("Invalid pduid in eventid_pduid.")
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) })
}) })
.transpose()? .transpose()?
.map(|pdu| { .map(|pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
}) })
.transpose() .transpose()
} }
@ -82,17 +97,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
/// Returns the pdu. /// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_non_outlier_pdu(
&self,
event_id: &EventId,
) -> Result<Option<PduEvent>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
.get(&pduid)? Error::bad_database("Invalid pduid in eventid_pduid.")
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) })
}) })
.transpose()? .transpose()?
.map(|pdu| { .map(|pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
}) })
.transpose() .transpose()
} }
@ -112,8 +131,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pdu| { .map(|pdu| {
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid PDU in db.")) Error::bad_database("Invalid PDU in db.")
})
}) })
.transpose() .transpose()
}, },
@ -144,7 +164,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json_from_id(
&self,
pdu_id: &[u8],
) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu)
@ -162,7 +185,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?; )?;
self.lasttimelinecount_cache self.lasttimelinecount_cache
@ -184,7 +208,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?; )?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
@ -203,7 +228,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
if self.pduid_pdu.get(pdu_id)?.is_some() { if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(pdu_json)
.expect("CanonicalJsonObject is always a valid"),
)?; )?;
} else { } else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -212,22 +238,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
)); ));
} }
self.pdu_cache self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
.lock()
.unwrap()
.remove(&(*pdu.event_id).to_owned());
Ok(()) Ok(())
} }
/// Returns an iterator over all events and their tokens in a room that happened before the /// Returns an iterator over all events and their tokens in a room that
/// event with id `until` in reverse-chronological order. /// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>( fn pdus_until<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
until: PduCount, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let (prefix, current) = count_to_id(room_id, until, 1, true)?; let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
@ -238,7 +263,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| { .map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; .map_err(|_| {
Error::bad_database("PDU in db is invalid.")
})?;
if pdu.sender != user_id { if pdu.sender != user_id {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
@ -254,7 +281,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
from: PduCount, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let (prefix, current) = count_to_id(room_id, from, 1, false)?; let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
@ -265,7 +293,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| { .map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; .map_err(|_| {
Error::bad_database("PDU in db is invalid.")
})?;
if pdu.sender != user_id { if pdu.sender != user_id {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
@ -286,13 +316,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let mut highlights_batch = Vec::new(); let mut highlights_batch = Vec::new();
for user in notifies { for user in notifies {
let mut userroom_id = user.as_bytes().to_vec(); let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id); notifies_batch.push(userroom_id);
} }
for user in highlights { for user in highlights {
let mut userroom_id = user.as_bytes().to_vec(); let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id); highlights_batch.push(userroom_id);
} }
@ -307,10 +337,12 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) let last_u64 =
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 = utils::u64_from_bytes( let second_last_u64 = utils::u64_from_bytes(
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()], &pdu_id[pdu_id.len() - 2 * size_of::<u64>()
..pdu_id.len() - size_of::<u64>()],
); );
if matches!(second_last_u64, Ok(0)) { if matches!(second_last_u64, Ok(0)) {
@ -330,7 +362,9 @@ fn count_to_id(
.rooms .rooms
.short .short
.get_shortroomid(room_id)? .get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? .ok_or_else(|| {
Error::bad_database("Looked for bad shortroomid in timeline")
})?
.to_be_bytes() .to_be_bytes()
.to_vec(); .to_vec();
let mut pdu_id = prefix.clone(); let mut pdu_id = prefix.clone();

View file

@ -1,14 +1,20 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::user::Data for KeyValueDatabase { impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn reset_notification_counts(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount self.userroomid_notificationcount
@ -24,35 +30,51 @@ impl service::rooms::user::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn notification_count(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount self.userroomid_notificationcount.get(&userroom_id)?.map_or(
.get(&userroom_id)? Ok(0),
.map_or(Ok(0), |bytes| { |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid notification count in db.")) Error::bad_database("Invalid notification count in db.")
}) })
},
)
} }
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn highlight_count(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount self.userroomid_highlightcount.get(&userroom_id)?.map_or(
.get(&userroom_id)? Ok(0),
.map_or(Ok(0), |bytes| { |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid highlight count in db.")) Error::bad_database("Invalid highlight count in db.")
}) })
},
)
} }
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn last_notification_read(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
Ok(self Ok(self
@ -60,7 +82,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") Error::bad_database(
"Count in roomuserid_lastprivatereadupdate is invalid.",
)
}) })
}) })
.transpose()? .transpose()?
@ -86,7 +110,11 @@ impl service::rooms::user::Data for KeyValueDatabase {
.insert(&key, &shortstatehash.to_be_bytes()) .insert(&key, &shortstatehash.to_be_bytes())
} }
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<Option<u64>> {
let shortroomid = services() let shortroomid = services()
.rooms .rooms
.short .short
@ -100,7 +128,10 @@ impl service::rooms::user::Data for KeyValueDatabase {
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") Error::bad_database(
"Invalid shortstatehash in \
roomsynctoken_shortstatehash",
)
}) })
}) })
.transpose() .transpose()
@ -112,7 +143,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| { let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
self.userroomid_joined self.userroomid_joined
.scan_prefix(prefix) .scan_prefix(prefix)
@ -121,8 +152,12 @@ impl service::rooms::user::Data for KeyValueDatabase {
let roomid_index = key let roomid_index = key
.iter() .iter()
.enumerate() .enumerate()
.find(|(_, &b)| b == 0xff) .find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? .ok_or_else(|| {
Error::bad_database(
"Invalid userroomid_joined in db.",
)
})?
.0 .0
+ 1; + 1;
@ -133,15 +168,24 @@ impl service::rooms::user::Data for KeyValueDatabase {
.filter_map(Result::ok) .filter_map(Result::ok)
}); });
// We use the default compare function because keys are sorted correctly (not reversed) // We use the default compare function because keys are sorted correctly
// (not reversed)
Ok(Box::new( Ok(Box::new(
utils::common_elements(iterators, Ord::cmp) utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty") .expect("users is not empty")
.map(|bytes| { .map(|bytes| {
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { RoomId::parse(utils::string_from_bytes(&bytes).map_err(
Error::bad_database("Invalid RoomId bytes in userroomid_joined") |_| {
})?) Error::bad_database(
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) "Invalid RoomId bytes in userroomid_joined",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"Invalid RoomId in userroomid_joined.",
)
})
}), }),
)) ))
} }

View file

@ -12,31 +12,34 @@ use crate::{
impl service::sending::Data for KeyValueDatabase { impl service::sending::Data for KeyValueDatabase {
fn active_requests<'a>( fn active_requests<'a>(
&'a self, &'a self,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> { ) -> Box<
Box::new( dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>>
self.servercurrentevent_data + 'a,
.iter() > {
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), Box::new(self.servercurrentevent_data.iter().map(|(key, v)| {
) parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))
}))
} }
fn active_requests_for<'a>( fn active_requests_for<'a>(
&'a self, &'a self,
outgoing_kind: &OutgoingKind, outgoing_kind: &OutgoingKind,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> { ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a>
{
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
Box::new( Box::new(self.servercurrentevent_data.scan_prefix(prefix).map(
self.servercurrentevent_data |(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e)),
.scan_prefix(prefix) ))
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
} }
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
self.servercurrentevent_data.remove(&key) self.servercurrentevent_data.remove(&key)
} }
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { fn delete_all_active_requests_for(
&self,
outgoing_kind: &OutgoingKind,
) -> Result<()> {
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?; self.servercurrentevent_data.remove(&key)?;
@ -45,9 +48,13 @@ impl service::sending::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { fn delete_all_requests_for(
&self,
outgoing_kind: &OutgoingKind,
) -> Result<()> {
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone())
{
self.servercurrentevent_data.remove(&key).unwrap(); self.servercurrentevent_data.remove(&key).unwrap();
} }
@ -69,7 +76,9 @@ impl service::sending::Data for KeyValueDatabase {
if let SendingEventType::Pdu(value) = &event { if let SendingEventType::Pdu(value) = &event {
key.extend_from_slice(value); key.extend_from_slice(value);
} else { } else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); key.extend_from_slice(
&services().globals.next_count()?.to_be_bytes(),
);
} }
let value = if let SendingEventType::Edu(value) = &event { let value = if let SendingEventType::Edu(value) = &event {
&**value &**value
@ -79,24 +88,25 @@ impl service::sending::Data for KeyValueDatabase {
batch.push((key.clone(), value.to_owned())); batch.push((key.clone(), value.to_owned()));
keys.push(key); keys.push(key);
} }
self.servernameevent_data self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
.insert_batch(&mut batch.into_iter())?;
Ok(keys) Ok(keys)
} }
fn queued_requests<'a>( fn queued_requests<'a>(
&'a self, &'a self,
outgoing_kind: &OutgoingKind, outgoing_kind: &OutgoingKind,
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> { ) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a>
{
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
return Box::new( return Box::new(self.servernameevent_data.scan_prefix(prefix).map(
self.servernameevent_data |(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k)),
.scan_prefix(prefix) ));
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
} }
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> { fn mark_as_active(
&self,
events: &[(SendingEventType, Vec<u8>)],
) -> Result<()> {
for (e, key) in events { for (e, key) in events {
let value = if let SendingEventType::Edu(value) = &e { let value = if let SendingEventType::Edu(value) = &e {
&**value &**value
@ -110,18 +120,24 @@ impl service::sending::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { fn set_latest_educount(
&self,
server_name: &ServerName,
last_count: u64,
) -> Result<()> {
self.servername_educount self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes()) .insert(server_name.as_bytes(), &last_count.to_be_bytes())
} }
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount self.servername_educount.get(server_name.as_bytes())?.map_or(
.get(server_name.as_bytes())? Ok(0),
.map_or(Ok(0), |bytes| { |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) Error::bad_database("Invalid u64 in servername_educount.")
}) })
},
)
} }
} }
@ -132,15 +148,17 @@ fn parse_servercurrentevent(
) -> Result<(OutgoingKind, SendingEventType)> { ) -> Result<(OutgoingKind, SendingEventType)> {
// Appservices start with a plus // Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") { Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xff); let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts let event = parts.next().ok_or_else(|| {
.next() Error::bad_database("Invalid bytes in servercurrentpdus.")
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; })?;
let server = utils::string_from_bytes(server).map_err(|_| { let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction") Error::bad_database(
"Invalid server bytes in server_currenttransaction",
)
})?; })?;
( (
@ -152,23 +170,27 @@ fn parse_servercurrentevent(
}, },
) )
} else if key.starts_with(b"$") { } else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xff); let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element"); let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user) let user_string = utils::string_from_bytes(user).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; Error::bad_database("Invalid user string in servercurrentevent")
let user_id = UserId::parse(user_string) })?;
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; let user_id = UserId::parse(user_string).map_err(|_| {
Error::bad_database("Invalid user id in servercurrentevent")
})?;
let pushkey = parts let pushkey = parts.next().ok_or_else(|| {
.next() Error::bad_database("Invalid bytes in servercurrentpdus.")
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; })?;
let pushkey_string = utils::string_from_bytes(pushkey) let pushkey_string =
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; utils::string_from_bytes(pushkey).map_err(|_| {
Error::bad_database("Invalid pushkey in servercurrentevent")
})?;
let event = parts let event = parts.next().ok_or_else(|| {
.next() Error::bad_database("Invalid bytes in servercurrentpdus.")
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; })?;
( (
OutgoingKind::Push(user_id, pushkey_string), OutgoingKind::Push(user_id, pushkey_string),
@ -180,20 +202,24 @@ fn parse_servercurrentevent(
}, },
) )
} else { } else {
let mut parts = key.splitn(2, |&b| b == 0xff); let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts let event = parts.next().ok_or_else(|| {
.next() Error::bad_database("Invalid bytes in servercurrentpdus.")
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; })?;
let server = utils::string_from_bytes(server).map_err(|_| { let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction") Error::bad_database(
"Invalid server bytes in server_currenttransaction",
)
})?; })?;
( (
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction") Error::bad_database(
"Invalid server string in server_currenttransaction",
)
})?), })?),
if value.is_empty() { if value.is_empty() {
SendingEventType::Pdu(event.to_vec()) SendingEventType::Pdu(event.to_vec())

View file

@ -11,9 +11,11 @@ impl service::transaction_ids::Data for KeyValueDatabase {
data: &[u8], data: &[u8],
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); key.extend_from_slice(
key.push(0xff); device_id.map(DeviceId::as_bytes).unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes()); key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?; self.userdevicetxnid_response.insert(&key, data)?;
@ -28,9 +30,11 @@ impl service::transaction_ids::Data for KeyValueDatabase {
txn_id: &TransactionId, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); key.extend_from_slice(
key.push(0xff); device_id.map(DeviceId::as_bytes).unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes()); key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction // If there's no entry, this is a new transaction

View file

@ -13,13 +13,10 @@ impl service::uiaa::Data for KeyValueDatabase {
session: &str, session: &str,
request: &CanonicalJsonValue, request: &CanonicalJsonValue,
) -> Result<()> { ) -> Result<()> {
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest.write().unwrap().insert(
.write() (user_id.to_owned(), device_id.to_owned(), session.to_owned()),
.unwrap() request.to_owned(),
.insert( );
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(()) Ok(())
} }
@ -33,7 +30,11 @@ impl service::uiaa::Data for KeyValueDatabase {
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest
.read() .read()
.unwrap() .unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) .get(&(
user_id.to_owned(),
device_id.to_owned(),
session.to_owned(),
))
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
} }
@ -45,19 +46,19 @@ impl service::uiaa::Data for KeyValueDatabase {
uiaainfo: Option<&UiaaInfo>, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> { ) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec(); let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff); userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes()); userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff); userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes()); userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo { if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert( self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid, &userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), &serde_json::to_vec(&uiaainfo)
.expect("UiaaInfo::to_vec always works"),
)?; )?;
} else { } else {
self.userdevicesessionid_uiaainfo self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
.remove(&userdevicesessionid)?;
} }
Ok(()) Ok(())
@ -70,9 +71,9 @@ impl service::uiaa::Data for KeyValueDatabase {
session: &str, session: &str,
) -> Result<UiaaInfo> { ) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec(); let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff); userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes()); userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff); userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes()); userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice( serde_json::from_slice(
@ -84,6 +85,8 @@ impl service::uiaa::Data for KeyValueDatabase {
"UIAA session does not exist.", "UIAA session does not exist.",
))?, ))?,
) )
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) .map_err(|_| {
Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")
})
} }
} }

View file

@ -5,8 +5,8 @@ use ruma::{
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType}, events::{AnyToDeviceEvent, StateEventType},
serde::Raw, serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch,
OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
}; };
use tracing::warn; use tracing::warn;
@ -40,67 +40,100 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { fn find_from_token(
self.token_userdeviceid &self,
.get(token.as_bytes())? token: &str,
.map_or(Ok(None), |bytes| { ) -> Result<Option<(OwnedUserId, String)>> {
let mut parts = bytes.split(|&b| b == 0xff); self.token_userdeviceid.get(token.as_bytes())?.map_or(
Ok(None),
|bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts.next().ok_or_else(|| { let user_bytes = parts.next().ok_or_else(|| {
Error::bad_database("User ID in token_userdeviceid is invalid.") Error::bad_database(
"User ID in token_userdeviceid is invalid.",
)
})?; })?;
let device_bytes = parts.next().ok_or_else(|| { let device_bytes = parts.next().ok_or_else(|| {
Error::bad_database("Device ID in token_userdeviceid is invalid.") Error::bad_database(
"Device ID in token_userdeviceid is invalid.",
)
})?; })?;
Ok(Some(( Ok(Some((
UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { UserId::parse(
Error::bad_database("User ID in token_userdeviceid is invalid unicode.") utils::string_from_bytes(user_bytes).map_err(|_| {
})?) Error::bad_database(
"User ID in token_userdeviceid is invalid \
unicode.",
)
})?,
)
.map_err(|_| { .map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid.") Error::bad_database(
"User ID in token_userdeviceid is invalid.",
)
})?, })?,
utils::string_from_bytes(device_bytes).map_err(|_| { utils::string_from_bytes(device_bytes).map_err(|_| {
Error::bad_database("Device ID in token_userdeviceid is invalid.") Error::bad_database(
"Device ID in token_userdeviceid is invalid.",
)
})?, })?,
))) )))
}) },
)
} }
/// Returns an iterator over all users on this homeserver. /// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { fn iter<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| { Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.") Error::bad_database(
"User ID in userid_password is invalid unicode.",
)
})?) })?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) .map_err(|_| {
Error::bad_database("User ID in userid_password is invalid.")
})
})) }))
} }
/// Returns a list of local users as list of usernames. /// Returns a list of local users as list of usernames.
/// ///
/// A user account is considered `local` if the length of it's password is greater then zero. /// A user account is considered `local` if the length of it's password is
/// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>> { fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self let users: Vec<String> = self
.userid_password .userid_password
.iter() .iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) .filter_map(|(username, pw)| {
get_username_with_valid_password(&username, &pw)
})
.collect(); .collect();
Ok(users) Ok(users)
} }
/// Returns the password hash for the given user. /// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password self.userid_password.get(user_id.as_bytes())?.map_or(
.get(user_id.as_bytes())? Ok(None),
.map_or(Ok(None), |bytes| { |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.") Error::bad_database(
"Password hash in db is not valid string.",
)
})?)) })?))
}) },
)
} }
/// Hash and set the user's password to the Argon2 hash /// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { fn set_password(
&self,
user_id: &UserId,
password: Option<&str>,
) -> Result<()> {
if let Some(password) = password { if let Some(password) = password {
if let Ok(hash) = utils::calculate_password_hash(password) { if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password self.userid_password
@ -120,17 +153,23 @@ impl service::users::Data for KeyValueDatabase {
/// Returns the `displayname` of a user on this homeserver. /// Returns the `displayname` of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname.get(user_id.as_bytes())?.map_or(
.get(user_id.as_bytes())? Ok(None),
.map_or(Ok(None), |bytes| { |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Displayname in db is invalid.") Error::bad_database("Displayname in db is invalid.")
})?)) })?))
}) },
)
} }
/// Sets a new `displayname` or removes it if `displayname` is `None`. You still need to nofify all rooms of this change. /// Sets a new `displayname` or removes it if `displayname` is `None`. You
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { /// still need to nofify all rooms of this change.
fn set_displayname(
&self,
user_id: &UserId,
displayname: Option<String>,
) -> Result<()> {
if let Some(displayname) = displayname { if let Some(displayname) = displayname {
self.userid_displayname self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?; .insert(user_id.as_bytes(), displayname.as_bytes())?;
@ -147,17 +186,25 @@ impl service::users::Data for KeyValueDatabase {
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) .map_err(|_| {
Error::bad_database("Avatar URL in db is invalid.")
})
.map(Into::into) .map(Into::into)
}) })
.transpose() .transpose()
} }
/// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { fn set_avatar_url(
&self,
user_id: &UserId,
avatar_url: Option<OwnedMxcUri>,
) -> Result<()> {
if let Some(avatar_url) = avatar_url { if let Some(avatar_url) = avatar_url {
self.userid_avatarurl self.userid_avatarurl.insert(
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; user_id.as_bytes(),
avatar_url.to_string().as_bytes(),
)?;
} else { } else {
self.userid_avatarurl.remove(user_id.as_bytes())?; self.userid_avatarurl.remove(user_id.as_bytes())?;
} }
@ -170,8 +217,9 @@ impl service::users::Data for KeyValueDatabase {
self.userid_blurhash self.userid_blurhash
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
let s = utils::string_from_bytes(&bytes) let s = utils::string_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; Error::bad_database("Avatar URL in db is invalid.")
})?;
Ok(s) Ok(s)
}) })
@ -179,7 +227,11 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Sets a new `avatar_url` or removes it if `avatar_url` is `None`. /// Sets a new `avatar_url` or removes it if `avatar_url` is `None`.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { fn set_blurhash(
&self,
user_id: &UserId,
blurhash: Option<String>,
) -> Result<()> {
if let Some(blurhash) = blurhash { if let Some(blurhash) = blurhash {
self.userid_blurhash self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?; .insert(user_id.as_bytes(), blurhash.as_bytes())?;
@ -204,11 +256,10 @@ impl service::users::Data for KeyValueDatabase {
); );
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion self.userid_devicelistversion.increment(user_id.as_bytes())?;
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,
@ -228,9 +279,13 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Removes a device from a user. /// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { fn remove_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens // Remove tokens
@ -241,7 +296,7 @@ impl service::users::Data for KeyValueDatabase {
// Remove todevice events // Remove todevice events
let mut prefix = userdeviceid.clone(); let mut prefix = userdeviceid.clone();
prefix.push(0xff); prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) { for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?; self.todeviceid_events.remove(&key)?;
@ -249,8 +304,7 @@ impl service::users::Data for KeyValueDatabase {
// TODO: Remove onetimekeys // TODO: Remove onetimekeys
self.userid_devicelistversion self.userid_devicelistversion.increment(user_id.as_bytes())?;
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?; self.userdeviceid_metadata.remove(&userdeviceid)?;
@ -263,29 +317,34 @@ impl service::users::Data for KeyValueDatabase {
user_id: &UserId, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
// All devices have metadata // All devices have metadata
Box::new( Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(
self.userdeviceid_metadata |(bytes, _)| {
.scan_prefix(prefix) Ok(utils::string_from_bytes(
.map(|(bytes, _)| { bytes.rsplit(|&b| b == 0xFF).next().ok_or_else(|| {
Ok(utils::string_from_bytes( Error::bad_database("UserDevice ID in db is invalid.")
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { })?,
Error::bad_database("UserDevice ID in db is invalid.") )
})?, .map_err(|_| {
Error::bad_database(
"Device ID in userdeviceid_metadata is invalid.",
) )
.map_err(|_| { })?
Error::bad_database("Device ID in userdeviceid_metadata is invalid.") .into())
})? },
.into()) ))
}),
)
} }
/// Replaces the access token of one device. /// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { fn set_token(
&self,
user_id: &UserId,
device_id: &DeviceId,
token: &str,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
assert!( assert!(
@ -300,10 +359,8 @@ impl service::users::Data for KeyValueDatabase {
} }
// Assign token to user device combination // Assign token to user device combination
self.userdeviceid_token self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?;
.insert(&userdeviceid, token.as_bytes())?; self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(()) Ok(())
} }
@ -316,17 +373,19 @@ impl service::users::Data for KeyValueDatabase {
one_time_key_value: &Raw<OneTimeKey>, one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
assert!( assert!(
self.userdeviceid_metadata.get(&key)?.is_some(), self.userdeviceid_metadata.get(&key)?.is_some(),
"devices should have metadata and this method should only be called with existing devices" "devices should have metadata and this method should only be \
called with existing devices"
); );
key.push(0xff); key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything, // TODO: Use DeviceKeyId::to_string when it's available (and update
// because there are no wrapping quotation marks anymore) // everything, because there are no wrapping quotation marks
// anymore)
key.extend_from_slice( key.extend_from_slice(
serde_json::to_string(one_time_key_key) serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works") .expect("DeviceKeyId::to_string always works")
@ -335,7 +394,8 @@ impl service::users::Data for KeyValueDatabase {
self.onetimekeyid_onetimekeys.insert( self.onetimekeyid_onetimekeys.insert(
&key, &key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), &serde_json::to_vec(&one_time_key_value)
.expect("OneTimeKey::to_vec always works"),
)?; )?;
self.userid_lastonetimekeyupdate.insert( self.userid_lastonetimekeyupdate.insert(
@ -347,13 +407,16 @@ impl service::users::Data for KeyValueDatabase {
} }
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate.get(user_id.as_bytes())?.map_or(
.get(user_id.as_bytes())? Ok(0),
.map_or(Ok(0), |bytes| { |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") Error::bad_database(
"Count in roomid_lastroomactiveupdate is invalid.",
)
}) })
}) },
)
} }
fn take_one_time_key( fn take_one_time_key(
@ -363,9 +426,9 @@ impl service::users::Data for KeyValueDatabase {
key_algorithm: &DeviceKeyAlgorithm, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
// Annoying quotation mark // Annoying quotation mark
prefix.push(b'"'); prefix.push(b'"');
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
@ -384,13 +447,18 @@ impl service::users::Data for KeyValueDatabase {
Ok(( Ok((
serde_json::from_slice( serde_json::from_slice(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xFF).next().ok_or_else(|| {
.next() Error::bad_database(
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, "OneTimeKeyId in db is invalid.",
)
})?,
) )
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, .map_err(|_| {
serde_json::from_slice(&value) Error::bad_database("OneTimeKeyId in db is invalid.")
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, })?,
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("OneTimeKeys in db are invalid.")
})?,
)) ))
}) })
.transpose() .transpose()
@ -402,25 +470,31 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> { ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new(); let mut counts = BTreeMap::new();
for algorithm in for algorithm in self
self.onetimekeyid_onetimekeys .onetimekeyid_onetimekeys
.scan_prefix(userdeviceid) .scan_prefix(userdeviceid)
.map(|(bytes, _)| { .map(|(bytes, _)| {
Ok::<_, Error>( Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>( serde_json::from_slice::<OwnedDeviceKeyId>(
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { bytes.rsplit(|&b| b == 0xFF).next().ok_or_else(
Error::bad_database("OneTimeKey ID in db is invalid.") || {
})?, Error::bad_database(
) "OneTimeKey ID in db is invalid.",
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? )
.algorithm(), },
)?,
) )
}) .map_err(|_| {
Error::bad_database("DeviceKeyId in db is invalid.")
})?
.algorithm(),
)
})
{ {
*counts.entry(algorithm?).or_default() += UInt::from(1_u32); *counts.entry(algorithm?).or_default() += UInt::from(1_u32);
} }
@ -435,12 +509,13 @@ impl service::users::Data for KeyValueDatabase {
device_keys: &Raw<DeviceKeys>, device_keys: &Raw<DeviceKeys>,
) -> Result<()> { ) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert( self.keyid_key.insert(
&userdeviceid, &userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), &serde_json::to_vec(&device_keys)
.expect("DeviceKeys::to_vec always works"),
)?; )?;
self.mark_device_key_update(user_id)?; self.mark_device_key_update(user_id)?;
@ -458,30 +533,33 @@ impl service::users::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
// TODO: Check signatures // TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?; .insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?;
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key // Self-signing key
if let Some(self_signing_key) = self_signing_key { if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key let mut self_signing_key_ids = self_signing_key
.deserialize() .deserialize()
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid self signing key",
)
})? })?
.keys .keys
.into_values(); .into_values();
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( let self_signing_key_id =
ErrorKind::InvalidParam, self_signing_key_ids.next().ok_or(Error::BadRequest(
"Self signing key contained no key.", ErrorKind::InvalidParam,
))?; "Self signing key contained no key.",
))?;
if self_signing_key_ids.next().is_some() { if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -491,7 +569,8 @@ impl service::users::Data for KeyValueDatabase {
} }
let mut self_signing_key_key = prefix.clone(); let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); self_signing_key_key
.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key.insert( self.keyid_key.insert(
&self_signing_key_key, &self_signing_key_key,
@ -507,15 +586,19 @@ impl service::users::Data for KeyValueDatabase {
let mut user_signing_key_ids = user_signing_key let mut user_signing_key_ids = user_signing_key
.deserialize() .deserialize()
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid user signing key",
)
})? })?
.keys .keys
.into_values(); .into_values();
let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( let user_signing_key_id =
ErrorKind::InvalidParam, user_signing_key_ids.next().ok_or(Error::BadRequest(
"User signing key contained no key.", ErrorKind::InvalidParam,
))?; "User signing key contained no key.",
))?;
if user_signing_key_ids.next().is_some() { if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -525,7 +608,8 @@ impl service::users::Data for KeyValueDatabase {
} }
let mut user_signing_key_key = prefix; let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); user_signing_key_key
.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key.insert( self.keyid_key.insert(
&user_signing_key_key, &user_signing_key_key,
@ -551,32 +635,44 @@ impl service::users::Data for KeyValueDatabase {
sender_id: &UserId, sender_id: &UserId,
) -> Result<()> { ) -> Result<()> {
let mut key = target_id.as_bytes().to_vec(); let mut key = target_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(key_id.as_bytes()); key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value = let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( &self.keyid_key.get(&key)?.ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Tried to sign nonexistent key.", "Tried to sign nonexistent key.",
))?) ))?,
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; )
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key let signatures = cross_signing_key
.get_mut("signatures") .get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? .ok_or_else(|| {
Error::bad_database("key in keyid_key has no signatures field.")
})?
.as_object_mut() .as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? .ok_or_else(|| {
Error::bad_database(
"key in keyid_key has invalid signatures field.",
)
})?
.entry(sender_id.to_string()) .entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into()); .or_insert_with(|| serde_json::Map::new().into());
signatures signatures
.as_object_mut() .as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? .ok_or_else(|| {
Error::bad_database(
"signatures in keyid_key for a user is invalid.",
)
})?
.insert(signature.0, signature.1.into()); .insert(signature.0, signature.1.into());
self.keyid_key.insert( self.keyid_key.insert(
&key, &key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), &serde_json::to_vec(&cross_signing_key)
.expect("CrossSigningKey::to_vec always works"),
)?; )?;
self.mark_device_key_update(target_id)?; self.mark_device_key_update(target_id)?;
@ -591,7 +687,7 @@ impl service::users::Data for KeyValueDatabase {
to: Option<u64>, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec(); let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let mut start = prefix.clone(); let mut start = prefix.clone();
start.extend_from_slice(&(from + 1).to_be_bytes()); start.extend_from_slice(&(from + 1).to_be_bytes());
@ -603,26 +699,39 @@ impl service::users::Data for KeyValueDatabase {
.iter_from(&start, false) .iter_from(&start, false)
.take_while(move |(k, _)| { .take_while(move |(k, _)| {
k.starts_with(&prefix) k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { && if let Some(current) =
k.splitn(2, |&b| b == 0xFF).nth(1)
{
if let Ok(c) = utils::u64_from_bytes(current) { if let Ok(c) = utils::u64_from_bytes(current) {
c <= to c <= to
} else { } else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes"); warn!(
"BadDatabase: Could not parse \
keychangeid_userid bytes"
);
false false
} }
} else { } else {
warn!("BadDatabase: Could not parse keychangeid_userid"); warn!(
"BadDatabase: Could not parse \
keychangeid_userid"
);
false false
} }
}) })
.map(|(_, bytes)| { .map(|(_, bytes)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { UserId::parse(utils::string_from_bytes(&bytes).map_err(
Error::bad_database( |_| {
"User ID in devicekeychangeid_userid is invalid unicode.", Error::bad_database(
) "User ID in devicekeychangeid_userid is \
})?) invalid unicode.",
)
},
)?)
.map_err(|_| { .map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid.") Error::bad_database(
"User ID in devicekeychangeid_userid is invalid.",
)
}) })
}), }),
) )
@ -647,14 +756,14 @@ impl service::users::Data for KeyValueDatabase {
} }
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(&count); key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?; self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
} }
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(&count); key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?; self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
@ -667,7 +776,7 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<Raw<DeviceKeys>>> { ) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
@ -683,11 +792,11 @@ impl service::users::Data for KeyValueDatabase {
master_key: &Raw<CrossSigningKey>, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> { ) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let master_key = master_key let master_key = master_key.deserialize().map_err(|_| {
.deserialize() Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key")
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; })?;
let mut master_key_ids = master_key.keys.values(); let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -712,8 +821,12 @@ impl service::users::Data for KeyValueDatabase {
allowed_signatures: &dyn Fn(&UserId) -> bool, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes) let mut cross_signing_key = serde_json::from_slice::<
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; serde_json::Value,
>(&bytes)
.map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.")
})?;
clean_signatures( clean_signatures(
&mut cross_signing_key, &mut cross_signing_key,
sender_user, sender_user,
@ -754,16 +867,20 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { fn get_user_signing_key(
self.userid_usersigningkeyid &self,
.get(user_id.as_bytes())? user_id: &UserId,
.map_or(Ok(None), |key| { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(
Ok(None),
|key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.") Error::bad_database("CrossSigningKey in db is invalid.")
})?)) })?))
}) })
}) },
)
} }
fn add_to_device_event( fn add_to_device_event(
@ -775,9 +892,9 @@ impl service::users::Data for KeyValueDatabase {
content: serde_json::Value, content: serde_json::Value,
) -> Result<()> { ) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec(); let mut key = target_user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes()); key.extend_from_slice(target_device_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new(); let mut json = serde_json::Map::new();
@ -785,7 +902,8 @@ impl service::users::Data for KeyValueDatabase {
json.insert("sender".to_owned(), sender.to_string().into()); json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content); json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); let value =
serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?; self.todeviceid_events.insert(&key, &value)?;
@ -800,15 +918,14 @@ impl service::users::Data for KeyValueDatabase {
let mut events = Vec::new(); let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) { for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push( events.push(serde_json::from_slice(&value).map_err(|_| {
serde_json::from_slice(&value) Error::bad_database("Event in todeviceid_events is invalid.")
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, })?);
);
} }
Ok(events) Ok(events)
@ -821,9 +938,9 @@ impl service::users::Data for KeyValueDatabase {
until: u64, until: u64,
) -> Result<()> { ) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xFF);
let mut last = prefix.clone(); let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes()); last.extend_from_slice(&until.to_be_bytes());
@ -836,8 +953,14 @@ impl service::users::Data for KeyValueDatabase {
.map(|(key, _)| { .map(|(key, _)| {
Ok::<_, Error>(( Ok::<_, Error>((
key.clone(), key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()]) utils::u64_from_bytes(
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, &key[key.len() - size_of::<u64>()..key.len()],
)
.map_err(|_| {
Error::bad_database(
"ToDeviceId has invalid count bytes.",
)
})?,
)) ))
}) })
.filter_map(Result::ok) .filter_map(Result::ok)
@ -856,7 +979,7 @@ impl service::users::Data for KeyValueDatabase {
device: &Device, device: &Device,
) -> Result<()> { ) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
assert!( assert!(
@ -864,12 +987,12 @@ impl service::users::Data for KeyValueDatabase {
"this method should only be called with existing devices" "this method should only be called with existing devices"
); );
self.userid_devicelistversion self.userid_devicelistversion.increment(user_id.as_bytes())?;
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"), &serde_json::to_vec(device)
.expect("Device::to_string always works"),
)?; )?;
Ok(()) Ok(())
@ -882,26 +1005,32 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata self.userdeviceid_metadata.get(&userdeviceid)?.map_or(
.get(&userdeviceid)? Ok(None),
.map_or(Ok(None), |bytes| { |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.") Error::bad_database(
"Metadata in userdeviceid_metadata is invalid.",
)
})?)) })?))
}) },
)
} }
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(
.get(user_id.as_bytes())? Ok(None),
.map_or(Ok(None), |bytes| { |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) .map_err(|_| {
Error::bad_database("Invalid devicelistversion in db.")
})
.map(Some) .map(Some)
}) },
)
} }
fn all_devices_metadata<'a>( fn all_devices_metadata<'a>(
@ -909,25 +1038,29 @@ impl service::users::Data for KeyValueDatabase {
user_id: &UserId, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> { ) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
Box::new( Box::new(self.userdeviceid_metadata.scan_prefix(key).map(
self.userdeviceid_metadata |(_, bytes)| {
.scan_prefix(key) serde_json::from_slice::<Device>(&bytes).map_err(|_| {
.map(|(_, bytes)| { Error::bad_database(
serde_json::from_slice::<Device>(&bytes).map_err(|_| { "Device in userdeviceid_metadata is invalid.",
Error::bad_database("Device in userdeviceid_metadata is invalid.") )
}) })
}), },
) ))
} }
/// Creates a new sync filter. Returns the filter id. /// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> { fn create_filter(
&self,
user_id: &UserId,
filter: &FilterDefinition,
) -> Result<String> {
let filter_id = utils::random_string(4); let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes()); key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter.insert( self.userfilterid_filter.insert(
@ -938,9 +1071,13 @@ impl service::users::Data for KeyValueDatabase {
Ok(filter_id) Ok(filter_id)
} }
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> { fn get_filter(
&self,
user_id: &UserId,
filter_id: &str,
) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes()); key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?; let raw = self.userfilterid_filter.get(&key)?;
@ -956,9 +1093,12 @@ impl service::users::Data for KeyValueDatabase {
/// Will only return with Some(username) if the password was not empty and the /// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed. /// username could be successfully parsed.
/// If [`utils::string_from_bytes`] returns an error that username will be skipped /// If [`utils::string_from_bytes`] returns an error that username will be
/// and the error will be logged. /// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> { fn get_username_with_valid_password(
username: &[u8],
password: &[u8],
) -> Option<String> {
// A valid password is not empty // A valid password is not empty
if password.is_empty() { if password.is_empty() {
None None
@ -967,7 +1107,8 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<
Ok(u) => Some(u), Ok(u) => Some(u),
Err(e) => { Err(e) => {
warn!( warn!(
"Failed to parse username while calling get_local_users(): {}", "Failed to parse username while calling \
get_local_users(): {}",
e.to_string() e.to_string()
); );
None None

View file

@ -12,7 +12,9 @@ use axum::{
routing::{any, get, on, MethodFilter}, routing::{any, get, on, MethodFilter},
Router, Router,
}; };
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
};
use figment::{ use figment::{
providers::{Env, Format, Toml}, providers::{Env, Format, Toml},
Figment, Figment,
@ -50,16 +52,16 @@ use api::{client_server, server_server};
pub(crate) use config::Config; pub(crate) use config::Config;
pub(crate) use database::KeyValueDatabase; pub(crate) use database::KeyValueDatabase;
pub(crate) use service::{pdu::PduEvent, Services}; pub(crate) use service::{pdu::PduEvent, Services};
pub(crate) use utils::error::{Error, Result};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
use tikv_jemallocator::Jemalloc; use tikv_jemallocator::Jemalloc;
pub(crate) use utils::error::{Error, Result};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
#[global_allocator] #[global_allocator]
static GLOBAL: Jemalloc = Jemalloc; static GLOBAL: Jemalloc = Jemalloc;
pub(crate) static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None); pub(crate) static SERVICES: RwLock<Option<&'static Services>> =
RwLock::new(None);
/// Convenient access to the global [`Services`] instance /// Convenient access to the global [`Services`] instance
pub(crate) fn services() -> &'static Services { pub(crate) fn services() -> &'static Services {
@ -71,9 +73,9 @@ pub(crate) fn services() -> &'static Services {
/// Returns the current version of the crate with extra info if supplied /// Returns the current version of the crate with extra info if supplied
/// ///
/// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string to /// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string
/// include it in parenthesis after the SemVer version. A common value are git /// to include it in parenthesis after the SemVer version. A common value are
/// commit hashes. /// git commit hashes.
fn version() -> String { fn version() -> String {
let cargo_pkg_version = env!("CARGO_PKG_VERSION"); let cargo_pkg_version = env!("CARGO_PKG_VERSION");
@ -91,7 +93,8 @@ async fn main() {
let raw_config = Figment::new() let raw_config = Figment::new()
.merge( .merge(
Toml::file(Env::var("GRAPEVINE_CONFIG").expect( Toml::file(Env::var("GRAPEVINE_CONFIG").expect(
"The GRAPEVINE_CONFIG env var needs to be set. Example: /etc/grapevine.toml", "The GRAPEVINE_CONFIG env var needs to be set. Example: \
/etc/grapevine.toml",
)) ))
.nested(), .nested(),
) )
@ -100,7 +103,10 @@ async fn main() {
let config = match raw_config.extract::<Config>() { let config = match raw_config.extract::<Config>() {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
eprintln!("It looks like your config is invalid. The following error occurred: {e}"); eprintln!(
"It looks like your config is invalid. The following error \
occurred: {e}"
);
std::process::exit(1); std::process::exit(1);
} }
}; };
@ -108,7 +114,9 @@ async fn main() {
config.warn_deprecated(); config.warn_deprecated();
if config.allow_jaeger { if config.allow_jaeger {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); opentelemetry::global::set_text_map_propagator(
opentelemetry_jaeger::Propagator::new(),
);
let tracer = opentelemetry_jaeger::new_agent_pipeline() let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_auto_split_batch(true) .with_auto_split_batch(true)
.with_service_name("grapevine") .with_service_name("grapevine")
@ -120,7 +128,8 @@ async fn main() {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
eprintln!( eprintln!(
"It looks like your log config is invalid. The following error occurred: {e}" "It looks like your log config is invalid. The following \
error occurred: {e}"
); );
EnvFilter::try_new("warn").unwrap() EnvFilter::try_new("warn").unwrap()
} }
@ -146,7 +155,10 @@ async fn main() {
let filter_layer = match EnvFilter::try_new(&config.log) { let filter_layer = match EnvFilter::try_new(&config.log) {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); eprintln!(
"It looks like your config is invalid. The following \
error occured while parsing it: {e}"
);
EnvFilter::try_new("warn").unwrap() EnvFilter::try_new("warn").unwrap()
} }
}; };
@ -163,7 +175,8 @@ async fn main() {
// * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6
// * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741
#[cfg(unix)] #[cfg(unix)]
maximize_fd_limit().expect("should be able to increase the soft limit to the hard limit"); maximize_fd_limit()
.expect("should be able to increase the soft limit to the hard limit");
info!("Loading database"); info!("Loading database");
if let Err(error) = KeyValueDatabase::load_or_create(config).await { if let Err(error) = KeyValueDatabase::load_or_create(config).await {
@ -190,17 +203,19 @@ async fn run_server() -> io::Result<()> {
let middlewares = ServiceBuilder::new() let middlewares = ServiceBuilder::new()
.sensitive_headers([header::AUTHORIZATION]) .sensitive_headers([header::AUTHORIZATION])
.layer(axum::middleware::from_fn(spawn_task)) .layer(axum::middleware::from_fn(spawn_task))
.layer( .layer(TraceLayer::new_for_http().make_span_with(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| { |request: &http::Request<_>| {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() { let path = if let Some(path) =
request.extensions().get::<MatchedPath>()
{
path.as_str() path.as_str()
} else { } else {
request.uri().path() request.uri().path()
}; };
tracing::info_span!("http_request", %path) tracing::info_span!("http_request", %path)
}), },
) ))
.layer(axum::middleware::from_fn(unrecognized_method)) .layer(axum::middleware::from_fn(unrecognized_method))
.layer( .layer(
CorsLayer::new() CorsLayer::new()
@ -235,7 +250,8 @@ async fn run_server() -> io::Result<()> {
match &config.tls { match &config.tls {
Some(tls) => { Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; let conf =
RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app); let server = bind_rustls(addr, conf).handle(handle).serve(app);
#[cfg(feature = "systemd")] #[cfg(feature = "systemd")]
@ -411,9 +427,10 @@ fn routes(config: &Config) -> Router {
.ruma_route(c2s::get_relating_events_route) .ruma_route(c2s::get_relating_events_route)
.ruma_route(c2s::get_hierarchy_route); .ruma_route(c2s::get_hierarchy_route);
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // Ruma doesn't have support for multiple paths for a single endpoint yet,
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route. // and these routes share one Ruma request / response type pair with
// These two endpoints also allow trailing slashes. // {get,send}_state_event_for_key_route. These two endpoints also allow
// trailing slashes.
let router = router let router = router
.route( .route(
"/_matrix/client/r0/rooms/:room_id/state/:event_type", "/_matrix/client/r0/rooms/:room_id/state/:event_type",
@ -483,9 +500,7 @@ fn routes(config: &Config) -> Router {
async fn shutdown_signal(handle: ServerHandle) { async fn shutdown_signal(handle: ServerHandle) {
let ctrl_c = async { let ctrl_c = async {
signal::ctrl_c() signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
.await
.expect("failed to install Ctrl+C handler");
}; };
#[cfg(unix)] #[cfg(unix)]
@ -554,9 +569,9 @@ impl RouterExt for Router {
} }
pub(crate) trait RumaHandler<T> { pub(crate) trait RumaHandler<T> {
// Can't transform to a handler without boxing or relying on the nightly-only // Can't transform to a handler without boxing or relying on the
// impl-trait-in-traits feature. Moving a small amount of extra logic into the trait // nightly-only impl-trait-in-traits feature. Moving a small amount of
// allows bypassing both. // extra logic into the trait allows bypassing both.
fn add_to_router(self, router: Router) -> Router; fn add_to_router(self, router: Router) -> Router;
} }

View file

@ -4,10 +4,9 @@ use std::{
}; };
use lru_cache::LruCache; use lru_cache::LruCache;
use tokio::sync::{broadcast, Mutex}; use tokio::sync::{broadcast, Mutex, RwLock};
use crate::{Config, Result}; use crate::{Config, Result};
use tokio::sync::RwLock;
pub(crate) mod account_data; pub(crate) mod account_data;
pub(crate) mod admin; pub(crate) mod admin;
@ -58,10 +57,14 @@ impl Services {
) -> Result<Self> { ) -> Result<Self> {
Ok(Self { Ok(Self {
appservice: appservice::Service::build(db)?, appservice: appservice::Service::build(db)?,
pusher: pusher::Service { db }, pusher: pusher::Service {
db,
},
rooms: rooms::Service { rooms: rooms::Service {
alias: db, alias: db,
auth_chain: rooms::auth_chain::Service { db }, auth_chain: rooms::auth_chain::Service {
db,
},
directory: db, directory: db,
edus: rooms::edus::Service { edus: rooms::edus::Service {
read_receipt: db, read_receipt: db,
@ -78,10 +81,14 @@ impl Services {
}, },
metadata: db, metadata: db,
outlier: db, outlier: db,
pdu_metadata: rooms::pdu_metadata::Service { db }, pdu_metadata: rooms::pdu_metadata::Service {
db,
},
search: db, search: db,
short: db, short: db,
state: rooms::state::Service { db }, state: rooms::state::Service {
db,
},
state_accessor: rooms::state_accessor::Service { state_accessor: rooms::state_accessor::Service {
db, db,
#[allow( #[allow(
@ -101,7 +108,9 @@ impl Services {
(100.0 * config.cache_capacity_modifier) as usize, (100.0 * config.cache_capacity_modifier) as usize,
)), )),
}, },
state_cache: rooms::state_cache::Service { db }, state_cache: rooms::state_cache::Service {
db,
},
state_compressor: rooms::state_compressor::Service { state_compressor: rooms::state_compressor::Service {
db, db,
#[allow( #[allow(
@ -117,14 +126,18 @@ impl Services {
db, db,
lasttimelinecount_cache: Mutex::new(HashMap::new()), lasttimelinecount_cache: Mutex::new(HashMap::new()),
}, },
threads: rooms::threads::Service { db }, threads: rooms::threads::Service {
db,
},
spaces: rooms::spaces::Service { spaces: rooms::spaces::Service {
roomid_spacechunk_cache: Mutex::new(LruCache::new(200)), roomid_spacechunk_cache: Mutex::new(LruCache::new(200)),
}, },
user: db, user: db,
}, },
transaction_ids: db, transaction_ids: db,
uiaa: uiaa::Service { db }, uiaa: uiaa::Service {
db,
},
users: users::Service { users: users::Service {
db, db,
connections: StdMutex::new(BTreeMap::new()), connections: StdMutex::new(BTreeMap::new()),
@ -132,14 +145,18 @@ impl Services {
account_data: db, account_data: db,
admin: admin::Service::build(), admin: admin::Service::build(),
key_backups: db, key_backups: db,
media: media::Service { db }, media: media::Service {
db,
},
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
}) })
} }
async fn memory_usage(&self) -> String { async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); let lazy_load_waiting =
self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self let server_visibility_cache = self
.rooms .rooms
.state_accessor .state_accessor
@ -154,21 +171,12 @@ impl Services {
.lock() .lock()
.unwrap() .unwrap()
.len(); .len();
let stateinfo_cache = self let stateinfo_cache =
.rooms self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
.state_compressor let lasttimelinecount_cache =
.stateinfo_cache self.rooms.timeline.lasttimelinecount_cache.lock().await.len();
.lock() let roomid_spacechunk_cache =
.unwrap() self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
.len();
let lasttimelinecount_cache = self
.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
format!( format!(
"\ "\
@ -177,18 +185,13 @@ server_visibility_cache: {server_visibility_cache}
user_visibility_cache: {user_visibility_cache} user_visibility_cache: {user_visibility_cache}
stateinfo_cache: {stateinfo_cache} stateinfo_cache: {stateinfo_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} lasttimelinecount_cache: {lasttimelinecount_cache}
roomid_spacechunk_cache: {roomid_spacechunk_cache}\ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
"
) )
} }
async fn clear_caches(&self, amount: u32) { async fn clear_caches(&self, amount: u32) {
if amount > 0 { if amount > 0 {
self.rooms self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear();
.lazy_loading
.lazy_load_waiting
.lock()
.await
.clear();
} }
if amount > 1 { if amount > 1 {
self.rooms self.rooms
@ -207,28 +210,13 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\
.clear(); .clear();
} }
if amount > 3 { if amount > 3 {
self.rooms self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.clear();
} }
if amount > 4 { if amount > 4 {
self.rooms self.rooms.timeline.lasttimelinecount_cache.lock().await.clear();
.timeline
.lasttimelinecount_cache
.lock()
.await
.clear();
} }
if amount > 5 { if amount > 5 {
self.rooms self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear();
.spaces
.roomid_spacechunk_cache
.lock()
.await
.clear();
} }
} }
} }

View file

@ -1,14 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::Result;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the
/// previous entry.
fn update( fn update(
&self, &self,
room_id: Option<&RoomId>, room_id: Option<&RoomId>,

View file

@ -16,7 +16,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent, canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent, create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent}, guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent}, member::{MembershipState, RoomMemberEventContent},
message::RoomMessageEventContent, message::RoomMessageEventContent,
@ -26,19 +28,19 @@ use ruma::{
}, },
TimelineEventType, TimelineEventType,
}, },
EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId,
ServerName, UserId,
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::warn; use tracing::warn;
use super::pdu::PduBuilder;
use crate::{ use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
services, utils, Error, PduEvent, Result, services, utils, Error, PduEvent, Result,
}; };
use super::pdu::PduBuilder;
#[cfg_attr(test, derive(Debug))] #[cfg_attr(test, derive(Debug))]
#[derive(Parser)] #[derive(Parser)]
#[command(name = "@grapevine:server.name:", version = env!("CARGO_PKG_VERSION"))] #[command(name = "@grapevine:server.name:", version = env!("CARGO_PKG_VERSION"))]
@ -46,11 +48,12 @@ enum AdminCommand {
#[command(verbatim_doc_comment)] #[command(verbatim_doc_comment)]
/// Register an appservice using its registration YAML /// Register an appservice using its registration YAML
/// ///
/// This command needs a YAML generated by an appservice (such as a bridge), /// This command needs a YAML generated by an appservice (such as a
/// which must be provided in a Markdown code-block below the command. /// bridge), which must be provided in a Markdown code-block below the
/// command.
/// ///
/// Registering a new bridge using the ID of an existing bridge will replace /// Registering a new bridge using the ID of an existing bridge will
/// the old one. /// replace the old one.
/// ///
/// [commandbody]() /// [commandbody]()
/// # ``` /// # ```
@ -95,8 +98,9 @@ enum AdminCommand {
/// ///
/// Users will not be removed from joined rooms by default. /// Users will not be removed from joined rooms by default.
/// Can be overridden with --leave-rooms flag. /// Can be overridden with --leave-rooms flag.
/// Removing a mass amount of users from a room may cause a significant amount of leave events. /// Removing a mass amount of users from a room may cause a significant
/// The time to leave rooms may depend significantly on joined rooms and servers. /// amount of leave events. The time to leave rooms may depend
/// significantly on joined rooms and servers.
/// ///
/// [commandbody]() /// [commandbody]()
/// # ``` /// # ```
@ -138,11 +142,17 @@ enum AdminCommand {
/// Print database memory usage statistics /// Print database memory usage statistics
MemoryUsage, MemoryUsage,
/// Clears all of Grapevine's database caches with index smaller than the amount /// Clears all of Grapevine's database caches with index smaller than the
ClearDatabaseCaches { amount: u32 }, /// amount
ClearDatabaseCaches {
amount: u32,
},
/// Clears all of Grapevine's service caches with index smaller than the amount /// Clears all of Grapevine's service caches with index smaller than the
ClearServiceCaches { amount: u32 }, /// amount
ClearServiceCaches {
amount: u32,
},
/// Show configuration values /// Show configuration values
ShowConfig, ShowConfig,
@ -162,9 +172,13 @@ enum AdminCommand {
}, },
/// Disables incoming federation handling for a room. /// Disables incoming federation handling for a room.
DisableRoom { room_id: Box<RoomId> }, DisableRoom {
room_id: Box<RoomId>,
},
/// Enables incoming federation handling for a room again. /// Enables incoming federation handling for a room again.
EnableRoom { room_id: Box<RoomId> }, EnableRoom {
room_id: Box<RoomId>,
},
/// Verify json signatures /// Verify json signatures
/// [commandbody]() /// [commandbody]()
@ -267,31 +281,38 @@ impl Service {
} }
pub(crate) fn process_message(&self, room_message: String) { pub(crate) fn process_message(&self, room_message: String) {
self.sender self.sender.send(AdminRoomEvent::ProcessMessage(room_message)).unwrap();
.send(AdminRoomEvent::ProcessMessage(room_message))
.unwrap();
} }
pub(crate) fn send_message(&self, message_content: RoomMessageEventContent) { pub(crate) fn send_message(
self.sender &self,
.send(AdminRoomEvent::SendMessage(message_content)) message_content: RoomMessageEventContent,
.unwrap(); ) {
self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap();
} }
// Parse and process a message from the admin room // Parse and process a message from the admin room
async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { async fn process_admin_message(
&self,
room_message: String,
) -> RoomMessageEventContent {
let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); let mut lines = room_message.lines().filter(|l| !l.trim().is_empty());
let command_line = lines.next().expect("each string has at least one line"); let command_line =
lines.next().expect("each string has at least one line");
let body: Vec<_> = lines.collect(); let body: Vec<_> = lines.collect();
let admin_command = match Self::parse_admin_command(command_line) { let admin_command = match Self::parse_admin_command(command_line) {
Ok(command) => command, Ok(command) => command,
Err(error) => { Err(error) => {
let server_name = services().globals.server_name(); let server_name = services().globals.server_name();
let message = error.replace("server.name", server_name.as_str()); let message =
error.replace("server.name", server_name.as_str());
let html_message = Self::usage_to_html(&message, server_name); let html_message = Self::usage_to_html(&message, server_name);
return RoomMessageEventContent::text_html(message, html_message); return RoomMessageEventContent::text_html(
message,
html_message,
);
} }
}; };
@ -299,22 +320,28 @@ impl Service {
Ok(reply_message) => reply_message, Ok(reply_message) => reply_message,
Err(error) => { Err(error) => {
let markdown_message = format!( let markdown_message = format!(
"Encountered an error while handling the command:\n\ "Encountered an error while handling the \
```\n{error}\n```", command:\n```\n{error}\n```",
); );
let html_message = format!( let html_message = format!(
"Encountered an error while handling the command:\n\ "Encountered an error while handling the \
<pre>\n{error}\n</pre>", command:\n<pre>\n{error}\n</pre>",
); );
RoomMessageEventContent::text_html(markdown_message, html_message) RoomMessageEventContent::text_html(
markdown_message,
html_message,
)
} }
} }
} }
// Parse chat messages from the admin room into an AdminCommand object // Parse chat messages from the admin room into an AdminCommand object
fn parse_admin_command(command_line: &str) -> std::result::Result<AdminCommand, String> { fn parse_admin_command(
// Note: argv[0] is `@grapevine:servername:`, which is treated as the main command command_line: &str,
) -> std::result::Result<AdminCommand, String> {
// Note: argv[0] is `@grapevine:servername:`, which is treated as the
// main command
let mut argv: Vec<_> = command_line.split_whitespace().collect(); let mut argv: Vec<_> = command_line.split_whitespace().collect();
// Replace `help command` with `command --help` // Replace `help command` with `command --help`
@ -342,18 +369,26 @@ impl Service {
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
let reply_message_content = match command { let reply_message_content = match command {
AdminCommand::RegisterAppservice => { AdminCommand::RegisterAppservice => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{ {
let appservice_config = body[1..body.len() - 1].join("\n"); let appservice_config = body[1..body.len() - 1].join("\n");
let parsed_config = serde_yaml::from_str::<Registration>(&appservice_config); let parsed_config = serde_yaml::from_str::<Registration>(
&appservice_config,
);
match parsed_config { match parsed_config {
Ok(yaml) => match services().appservice.register_appservice(yaml).await { Ok(yaml) => match services()
Ok(id) => RoomMessageEventContent::text_plain(format!( .appservice
"Appservice registered with ID: {id}." .register_appservice(yaml)
)), .await
Err(e) => RoomMessageEventContent::text_plain(format!( {
"Failed to register appservice: {e}" Ok(id) => RoomMessageEventContent::text_plain(
)), format!("Appservice registered with ID: {id}."),
),
Err(e) => RoomMessageEventContent::text_plain(
format!("Failed to register appservice: {e}"),
),
}, },
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
"Could not parse appservice config: {e}" "Could not parse appservice config: {e}"
@ -361,7 +396,8 @@ impl Service {
} }
} else { } else {
RoomMessageEventContent::text_plain( RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.", "Expected code block in command body. Add --help for \
details.",
) )
} }
} }
@ -372,7 +408,9 @@ impl Service {
.unregister_appservice(&appservice_identifier) .unregister_appservice(&appservice_identifier)
.await .await
{ {
Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), Ok(()) => RoomMessageEventContent::text_plain(
"Appservice unregistered.",
),
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
"Failed to unregister appservice: {e}" "Failed to unregister appservice: {e}"
)), )),
@ -407,17 +445,25 @@ impl Service {
); );
RoomMessageEventContent::text_plain(output) RoomMessageEventContent::text_plain(output)
} }
AdminCommand::ListLocalUsers => match services().users.list_local_users() { AdminCommand::ListLocalUsers => match services()
.users
.list_local_users()
{
Ok(users) => { Ok(users) => {
let mut msg: String = format!("Found {} local user account(s):\n", users.len()); let mut msg: String = format!(
"Found {} local user account(s):\n",
users.len()
);
msg += &users.join("\n"); msg += &users.join("\n");
RoomMessageEventContent::text_plain(&msg) RoomMessageEventContent::text_plain(&msg)
} }
Err(e) => RoomMessageEventContent::text_plain(e.to_string()), Err(e) => RoomMessageEventContent::text_plain(e.to_string()),
}, },
AdminCommand::IncomingFederation => { AdminCommand::IncomingFederation => {
let map = services().globals.roomid_federationhandletime.read().await; let map =
let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); services().globals.roomid_federationhandletime.read().await;
let mut msg: String =
format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() { for (r, (e, i)) in map.iter() {
let elapsed = i.elapsed(); let elapsed = i.elapsed();
@ -431,17 +477,26 @@ impl Service {
} }
RoomMessageEventContent::text_plain(&msg) RoomMessageEventContent::text_plain(&msg)
} }
AdminCommand::GetAuthChain { event_id } => { AdminCommand::GetAuthChain {
event_id,
} => {
let event_id = Arc::<EventId>::from(event_id); let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { if let Some(event) =
services().rooms.timeline.get_pdu_json(&event_id)?
{
let room_id_str = event let room_id_str = event
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database"))?; .ok_or_else(|| {
Error::bad_database("Invalid event in database")
})?;
let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { let room_id =
Error::bad_database("Invalid room id field in event in database") <&RoomId>::try_from(room_id_str).map_err(|_| {
})?; Error::bad_database(
"Invalid room id field in event in database",
)
})?;
let start = Instant::now(); let start = Instant::now();
let count = services() let count = services()
.rooms .rooms
@ -458,29 +513,47 @@ impl Service {
} }
} }
AdminCommand::ParsePdu => { AdminCommand::ParsePdu => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{ {
let string = body[1..body.len() - 1].join("\n"); let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) { match serde_json::from_str(&string) {
Ok(value) => { Ok(value) => {
match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { match ruma::signatures::reference_hash(
&value,
&RoomVersionId::V6,
) {
Ok(hash) => { Ok(hash) => {
let event_id = EventId::parse(format!("${hash}")); let event_id =
EventId::parse(format!("${hash}"));
match serde_json::from_value::<PduEvent>( match serde_json::from_value::<PduEvent>(
serde_json::to_value(value).expect("value is json"), serde_json::to_value(value)
.expect("value is json"),
) { ) {
Ok(pdu) => RoomMessageEventContent::text_plain(format!( Ok(pdu) => {
"EventId: {event_id:?}\n{pdu:#?}" RoomMessageEventContent::text_plain(
)), format!(
Err(e) => RoomMessageEventContent::text_plain(format!( "EventId: {event_id:?}\\
"EventId: {event_id:?}\nCould not parse event: {e}" n{pdu:#?}"
)), ),
)
}
Err(e) => {
RoomMessageEventContent::text_plain(
format!(
"EventId: {event_id:?}\\
nCould not parse event: \
{e}"
),
)
}
} }
} }
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(
"Could not parse PDU JSON: {e:?}" format!("Could not parse PDU JSON: {e:?}"),
)), ),
} }
} }
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
@ -488,10 +561,14 @@ impl Service {
)), )),
} }
} else { } else {
RoomMessageEventContent::text_plain("Expected code block in command body.") RoomMessageEventContent::text_plain(
"Expected code block in command body.",
)
} }
} }
AdminCommand::GetPdu { event_id } => { AdminCommand::GetPdu {
event_id,
} => {
let mut outlier = false; let mut outlier = false;
let mut pdu_json = services() let mut pdu_json = services()
.rooms .rooms
@ -499,7 +576,8 @@ impl Service {
.get_non_outlier_pdu_json(&event_id)?; .get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() { if pdu_json.is_none() {
outlier = true; outlier = true;
pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; pdu_json =
services().rooms.timeline.get_pdu_json(&event_id)?;
} }
match pdu_json { match pdu_json {
Some(json) => { Some(json) => {
@ -516,7 +594,8 @@ impl Service {
json_text json_text
), ),
format!( format!(
"<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n", "<p>{}</p>\n<pre><code \
class=\"language-json\">{}\n</code></pre>\n",
if outlier { if outlier {
"PDU is outlier" "PDU is outlier"
} else { } else {
@ -526,7 +605,9 @@ impl Service {
), ),
) )
} }
None => RoomMessageEventContent::text_plain("PDU not found."), None => {
RoomMessageEventContent::text_plain("PDU not found.")
}
} }
} }
AdminCommand::MemoryUsage => { AdminCommand::MemoryUsage => {
@ -537,30 +618,42 @@ impl Service {
"Services:\n{response1}\n\nDatabase:\n{response2}" "Services:\n{response1}\n\nDatabase:\n{response2}"
)) ))
} }
AdminCommand::ClearDatabaseCaches { amount } => { AdminCommand::ClearDatabaseCaches {
amount,
} => {
services().globals.db.clear_caches(amount); services().globals.db.clear_caches(amount);
RoomMessageEventContent::text_plain("Done.") RoomMessageEventContent::text_plain("Done.")
} }
AdminCommand::ClearServiceCaches { amount } => { AdminCommand::ClearServiceCaches {
amount,
} => {
services().clear_caches(amount).await; services().clear_caches(amount).await;
RoomMessageEventContent::text_plain("Done.") RoomMessageEventContent::text_plain("Done.")
} }
AdminCommand::ShowConfig => { AdminCommand::ShowConfig => {
// Construct and send the response // Construct and send the response
RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) RoomMessageEventContent::text_plain(format!(
"{}",
services().globals.config
))
} }
AdminCommand::ResetPassword { username } => { AdminCommand::ResetPassword {
username,
} => {
let user_id = match UserId::parse_with_server_name( let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(), username.as_str().to_lowercase(),
services().globals.server_name(), services().globals.server_name(),
) { ) {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!( return Ok(RoomMessageEventContent::text_plain(
"The supplied username is not a valid username: {e}" format!(
))) "The supplied username is not a valid \
username: {e}"
),
))
} }
}; };
@ -589,23 +682,29 @@ impl Service {
)); ));
} }
let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); let new_password =
utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
match services() match services()
.users .users
.set_password(&user_id, Some(new_password.as_str())) .set_password(&user_id, Some(new_password.as_str()))
{ {
Ok(()) => RoomMessageEventContent::text_plain(format!( Ok(()) => RoomMessageEventContent::text_plain(format!(
"Successfully reset the password for user {user_id}: {new_password}" "Successfully reset the password for user {user_id}: \
{new_password}"
)), )),
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
"Couldn't reset the password for user {user_id}: {e}" "Couldn't reset the password for user {user_id}: {e}"
)), )),
} }
} }
AdminCommand::CreateUser { username, password } => { AdminCommand::CreateUser {
let password = username,
password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); password,
} => {
let password = password.unwrap_or_else(|| {
utils::random_string(AUTO_GEN_PASSWORD_LENGTH)
});
// Validate user id // Validate user id
let user_id = match UserId::parse_with_server_name( let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(), username.as_str().to_lowercase(),
@ -613,9 +712,12 @@ impl Service {
) { ) {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!( return Ok(RoomMessageEventContent::text_plain(
"The supplied username is not a valid username: {e}" format!(
))) "The supplied username is not a valid \
username: {e}"
),
))
} }
}; };
if user_id.is_historical() { if user_id.is_historical() {
@ -647,24 +749,32 @@ impl Service {
.into(), .into(),
&serde_json::to_value(PushRulesEvent { &serde_json::to_value(PushRulesEvent {
content: PushRulesEventContent { content: PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id), global: ruma::push::Ruleset::server_default(
&user_id,
),
}, },
}) })
.expect("to json value always works"), .expect("to json value always works"),
)?; )?;
// we dont add a device since we're not the user, just the creator // we dont add a device since we're not the user, just the
// creator
// Inhibit login does not work for guests // Inhibit login does not work for guests
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"Created user with user_id: {user_id} and password: {password}" "Created user with user_id: {user_id} and password: \
{password}"
)) ))
} }
AdminCommand::DisableRoom { room_id } => { AdminCommand::DisableRoom {
room_id,
} => {
services().rooms.metadata.disable_room(&room_id, true)?; services().rooms.metadata.disable_room(&room_id, true)?;
RoomMessageEventContent::text_plain("Room disabled.") RoomMessageEventContent::text_plain("Room disabled.")
} }
AdminCommand::EnableRoom { room_id } => { AdminCommand::EnableRoom {
room_id,
} => {
services().rooms.metadata.disable_room(&room_id, false)?; services().rooms.metadata.disable_room(&room_id, false)?;
RoomMessageEventContent::text_plain("Room enabled.") RoomMessageEventContent::text_plain("Room enabled.")
} }
@ -677,13 +787,16 @@ impl Service {
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"User {user_id} doesn't exist on this server" "User {user_id} doesn't exist on this server"
)) ))
} else if user_id.server_name() != services().globals.server_name() { } else if user_id.server_name()
!= services().globals.server_name()
{
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"User {user_id} is not from this server" "User {user_id} is not from this server"
)) ))
} else { } else {
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"Making {user_id} leave all rooms before deactivation..." "Making {user_id} leave all rooms before \
deactivation..."
)); ));
services().users.deactivate_account(&user_id)?; services().users.deactivate_account(&user_id)?;
@ -697,10 +810,18 @@ impl Service {
)) ))
} }
} }
AdminCommand::DeactivateAll { leave_rooms, force } => { AdminCommand::DeactivateAll {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" leave_rooms,
force,
} => {
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{ {
let users = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); let users = body
.clone()
.drain(1..body.len() - 1)
.collect::<Vec<_>>();
let mut user_ids = Vec::new(); let mut user_ids = Vec::new();
let mut remote_ids = Vec::new(); let mut remote_ids = Vec::new();
@ -710,7 +831,9 @@ impl Service {
for &user in &users { for &user in &users {
match <&UserId>::try_from(user) { match <&UserId>::try_from(user) {
Ok(user_id) => { Ok(user_id) => {
if user_id.server_name() != services().globals.server_name() { if user_id.server_name()
!= services().globals.server_name()
{
remote_ids.push(user_id); remote_ids.push(user_id);
} else if !services().users.exists(user_id)? { } else if !services().users.exists(user_id)? {
non_existant_ids.push(user_id); non_existant_ids.push(user_id);
@ -727,39 +850,59 @@ impl Service {
let mut markdown_message = String::new(); let mut markdown_message = String::new();
let mut html_message = String::new(); let mut html_message = String::new();
if !invalid_users.is_empty() { if !invalid_users.is_empty() {
markdown_message.push_str("The following user ids are not valid:\n```\n"); markdown_message.push_str(
html_message.push_str("The following user ids are not valid:\n<pre>\n"); "The following user ids are not valid:\n```\n",
);
html_message.push_str(
"The following user ids are not valid:\n<pre>\n",
);
for invalid_user in invalid_users { for invalid_user in invalid_users {
writeln!(markdown_message, "{invalid_user}") writeln!(markdown_message, "{invalid_user}")
.expect("write to in-memory buffer should succeed"); .expect(
writeln!(html_message, "{invalid_user}") "write to in-memory buffer should succeed",
.expect("write to in-memory buffer should succeed"); );
writeln!(html_message, "{invalid_user}").expect(
"write to in-memory buffer should succeed",
);
} }
markdown_message.push_str("```\n\n"); markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n"); html_message.push_str("</pre>\n\n");
} }
if !remote_ids.is_empty() { if !remote_ids.is_empty() {
markdown_message markdown_message.push_str(
.push_str("The following users are not from this server:\n```\n"); "The following users are not from this \
html_message server:\n```\n",
.push_str("The following users are not from this server:\n<pre>\n"); );
html_message.push_str(
"The following users are not from this \
server:\n<pre>\n",
);
for remote_id in remote_ids { for remote_id in remote_ids {
writeln!(markdown_message, "{remote_id}") writeln!(markdown_message, "{remote_id}").expect(
.expect("write to in-memory buffer should succeed"); "write to in-memory buffer should succeed",
writeln!(html_message, "{remote_id}") );
.expect("write to in-memory buffer should succeed"); writeln!(html_message, "{remote_id}").expect(
"write to in-memory buffer should succeed",
);
} }
markdown_message.push_str("```\n\n"); markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n"); html_message.push_str("</pre>\n\n");
} }
if !non_existant_ids.is_empty() { if !non_existant_ids.is_empty() {
markdown_message.push_str("The following users do not exist:\n```\n"); markdown_message.push_str(
html_message.push_str("The following users do not exist:\n<pre>\n"); "The following users do not exist:\n```\n",
);
html_message.push_str(
"The following users do not exist:\n<pre>\n",
);
for non_existant_id in non_existant_ids { for non_existant_id in non_existant_ids {
writeln!(markdown_message, "{non_existant_id}") writeln!(markdown_message, "{non_existant_id}")
.expect("write to in-memory buffer should succeed"); .expect(
writeln!(html_message, "{non_existant_id}") "write to in-memory buffer should succeed",
.expect("write to in-memory buffer should succeed"); );
writeln!(html_message, "{non_existant_id}").expect(
"write to in-memory buffer should succeed",
);
} }
markdown_message.push_str("```\n\n"); markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n"); html_message.push_str("</pre>\n\n");
@ -775,21 +918,24 @@ impl Service {
let mut admins = Vec::new(); let mut admins = Vec::new();
if !force { if !force {
user_ids.retain(|&user_id| match services().users.is_admin(user_id) { user_ids.retain(|&user_id| {
Ok(is_admin) => { match services().users.is_admin(user_id) {
if is_admin { Ok(is_admin) => {
admins.push(user_id.localpart()); if is_admin {
false admins.push(user_id.localpart());
} else { false
true } else {
true
}
} }
Err(_) => false,
} }
Err(_) => false,
}); });
} }
for &user_id in &user_ids { for &user_id in &user_ids {
if services().users.deactivate_account(user_id).is_ok() { if services().users.deactivate_account(user_id).is_ok()
{
deactivation_count += 1; deactivation_count += 1;
} }
} }
@ -807,16 +953,25 @@ impl Service {
"Deactivated {deactivation_count} accounts." "Deactivated {deactivation_count} accounts."
)) ))
} else { } else {
RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) RoomMessageEventContent::text_plain(format!(
"Deactivated {} accounts.\nSkipped admin \
accounts: {:?}. Use --force to deactivate admin \
accounts",
deactivation_count,
admins.join(", ")
))
} }
} else { } else {
RoomMessageEventContent::text_plain( RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.", "Expected code block in command body. Add --help for \
details.",
) )
} }
} }
AdminCommand::SignJson => { AdminCommand::SignJson => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{ {
let string = body[1..body.len() - 1].join("\n"); let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) { match serde_json::from_str(&string) {
@ -827,20 +982,26 @@ impl Service {
&mut value, &mut value,
) )
.expect("our request json is what ruma expects"); .expect("our request json is what ruma expects");
let json_text = serde_json::to_string_pretty(&value) let json_text =
.expect("canonical json is valid json"); serde_json::to_string_pretty(&value)
.expect("canonical json is valid json");
RoomMessageEventContent::text_plain(json_text) RoomMessageEventContent::text_plain(json_text)
} }
Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), Err(e) => RoomMessageEventContent::text_plain(format!(
"Invalid json: {e}"
)),
} }
} else { } else {
RoomMessageEventContent::text_plain( RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.", "Expected code block in command body. Add --help for \
details.",
) )
} }
} }
AdminCommand::VerifyJson => { AdminCommand::VerifyJson => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{ {
let string = body[1..body.len() - 1].join("\n"); let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) { match serde_json::from_str(&string) {
@ -850,22 +1011,35 @@ impl Service {
services() services()
.rooms .rooms
.event_handler .event_handler
.fetch_required_signing_keys(&value, &pub_key_map) .fetch_required_signing_keys(
&value,
&pub_key_map,
)
.await?; .await?;
let pub_key_map = pub_key_map.read().await; let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) { match ruma::signatures::verify_json(
Ok(()) => RoomMessageEventContent::text_plain("Signature correct"), &pub_key_map,
Err(e) => RoomMessageEventContent::text_plain(format!( &value,
"Signature verification failed: {e}" ) {
)), Ok(()) => RoomMessageEventContent::text_plain(
"Signature correct",
),
Err(e) => RoomMessageEventContent::text_plain(
format!(
"Signature verification failed: {e}"
),
),
} }
} }
Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), Err(e) => RoomMessageEventContent::text_plain(format!(
"Invalid json: {e}"
)),
} }
} else { } else {
RoomMessageEventContent::text_plain( RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.", "Expected code block in command body. Add --help for \
details.",
) )
} }
} }
@ -876,7 +1050,8 @@ impl Service {
// Utility to turn clap's `--help` text to HTML. // Utility to turn clap's `--help` text to HTML.
fn usage_to_html(text: &str, server_name: &ServerName) -> String { fn usage_to_html(text: &str, server_name: &ServerName) -> String {
// Replace `@grapevine:servername:-subcmdname` with `@grapevine:servername: subcmdname` // Replace `@grapevine:servername:-subcmdname` with
// `@grapevine:servername: subcmdname`
let localpart = if services().globals.config.conduit_compat { let localpart = if services().globals.config.conduit_compat {
"conduit" "conduit"
} else { } else {
@ -892,11 +1067,13 @@ impl Service {
let text = text.replace("SUBCOMMAND", "COMMAND"); let text = text.replace("SUBCOMMAND", "COMMAND");
let text = text.replace("subcommand", "command"); let text = text.replace("subcommand", "command");
// Escape option names (e.g. `<element-id>`) since they look like HTML tags // Escape option names (e.g. `<element-id>`) since they look like HTML
// tags
let text = text.replace('<', "&lt;").replace('>', "&gt;"); let text = text.replace('<', "&lt;").replace('>', "&gt;");
// Italicize the first line (command name and version text) // Italicize the first line (command name and version text)
let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); let re =
Regex::new("^(.*?)\n").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<em>$1</em>\n"); let text = re.replace_all(&text, "<em>$1</em>\n");
// Unmerge wrapped lines // Unmerge wrapped lines
@ -911,8 +1088,8 @@ impl Service {
.expect("Regex compilation should not fail"); .expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<code>$1</code>: $4"); let text = re.replace_all(&text, "<code>$1</code>: $4");
// Look for a `[commandbody]()` tag. If it exists, use all lines below it that // Look for a `[commandbody]()` tag. If it exists, use all lines below
// start with a `#` in the USAGE section. // it that start with a `#` in the USAGE section.
let mut text_lines: Vec<&str> = text.lines().collect(); let mut text_lines: Vec<&str> = text.lines().collect();
let command_body = text_lines let command_body = text_lines
.iter() .iter()
@ -936,8 +1113,11 @@ impl Service {
// This makes the usage of e.g. `register-appservice` more accurate // This makes the usage of e.g. `register-appservice` more accurate
let re = Regex::new("(?m)^USAGE:\n (.*?)\n\n") let re = Regex::new("(?m)^USAGE:\n (.*?)\n\n")
.expect("Regex compilation should not fail"); .expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>") re.replace_all(
.replace("[commandbodyblock]", &command_body) &text,
"USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>",
)
.replace("[commandbodyblock]", &command_body)
}; };
// Add HTML line-breaks // Add HTML line-breaks
@ -949,8 +1129,9 @@ impl Service {
/// Create the admin room. /// Create the admin room.
/// ///
/// Users in this room are considered admins by grapevine, and the room can be /// Users in this room are considered admins by grapevine, and the room can
/// used to issue admin commands by talking to the server user inside it. /// be used to issue admin commands by talking to the server user inside
/// it.
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn create_admin_room(&self) -> Result<()> { pub(crate) async fn create_admin_room(&self) -> Result<()> {
let room_id = RoomId::new(services().globals.server_name()); let room_id = RoomId::new(services().globals.server_name());
@ -993,7 +1174,9 @@ impl Service {
| RoomVersionId::V7 | RoomVersionId::V7
| RoomVersionId::V8 | RoomVersionId::V8
| RoomVersionId::V9 | RoomVersionId::V9
| RoomVersionId::V10 => RoomCreateEventContent::new_v1(grapevine_user.clone()), | RoomVersionId::V10 => {
RoomCreateEventContent::new_v1(grapevine_user.clone())
}
RoomVersionId::V11 => RoomCreateEventContent::new_v11(), RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => unreachable!("Validity of room version already checked"), _ => unreachable!("Validity of room version already checked"),
}; };
@ -1008,7 +1191,8 @@ impl Service {
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomCreate, event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"), content: to_raw_value(&content)
.expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
redacts: None, redacts: None,
@ -1079,8 +1263,10 @@ impl Service {
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomJoinRules, event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) content: to_raw_value(&RoomJoinRulesEventContent::new(
.expect("event is valid, we just created it"), JoinRule::Invite,
))
.expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
redacts: None, redacts: None,
@ -1098,9 +1284,11 @@ impl Service {
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomHistoryVisibility, event_type: TimelineEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new( content: to_raw_value(
HistoryVisibility::Shared, &RoomHistoryVisibilityEventContent::new(
)) HistoryVisibility::Shared,
),
)
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
@ -1134,15 +1322,18 @@ impl Service {
.await?; .await?;
// 5. Events implied by name and topic // 5. Events implied by name and topic
let room_name = format!("{} Admin Room", services().globals.server_name()); let room_name =
format!("{} Admin Room", services().globals.server_name());
services() services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomName, event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(room_name)) content: to_raw_value(&RoomNameEventContent::new(
.expect("event is valid, we just created it"), room_name,
))
.expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(String::new()), state_key: Some(String::new()),
redacts: None, redacts: None,
@ -1160,7 +1351,10 @@ impl Service {
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomTopic, event_type: TimelineEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent { content: to_raw_value(&RoomTopicEventContent {
topic: format!("Manage {}", services().globals.server_name()), topic: format!(
"Manage {}",
services().globals.server_name()
),
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -1174,9 +1368,10 @@ impl Service {
.await?; .await?;
// 6. Room alias // 6. Room alias
let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name()) let alias: OwnedRoomAliasId =
.try_into() format!("#admins:{}", services().globals.server_name())
.expect("#admins:server_name is a valid alias name"); .try_into()
.expect("#admins:server_name is a valid alias name");
services() services()
.rooms .rooms
@ -1206,7 +1401,8 @@ impl Service {
/// Gets the room ID of the admin room /// Gets the room ID of the admin room
/// ///
/// Errors are propagated from the database, and will have None if there is no admin room /// Errors are propagated from the database, and will have None if there is
/// no admin room
// Allowed because this function uses `services()` // Allowed because this function uses `services()`
#[allow(clippy::unused_self)] #[allow(clippy::unused_self)]
pub(crate) fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> { pub(crate) fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> {
@ -1215,10 +1411,7 @@ impl Service {
.try_into() .try_into()
.expect("#admins:server_name is a valid alias name"); .expect("#admins:server_name is a valid alias name");
services() services().rooms.alias.resolve_local_alias(&admin_room_alias)
.rooms
.alias
.resolve_local_alias(&admin_room_alias)
} }
/// Invite the user to the grapevine admin room. /// Invite the user to the grapevine admin room.
@ -1356,11 +1549,13 @@ mod test {
} }
fn get_help_inner(input: &str) { fn get_help_inner(input: &str) {
let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) let error =
.unwrap_err() AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.to_string(); .unwrap_err()
.to_string();
// Search for a handful of keywords that suggest the help printed properly // Search for a handful of keywords that suggest the help printed
// properly
assert!(error.contains("Usage:")); assert!(error.contains("Usage:"));
assert!(error.contains("Commands:")); assert!(error.contains("Commands:"));
assert!(error.contains("Options:")); assert!(error.contains("Options:"));

View file

@ -3,7 +3,6 @@ mod data;
use std::collections::BTreeMap; use std::collections::BTreeMap;
pub(crate) use data::Data; pub(crate) use data::Data;
use futures_util::Future; use futures_util::Future;
use regex::RegexSet; use regex::RegexSet;
use ruma::{ use ruma::{
@ -48,6 +47,8 @@ impl NamespaceRegex {
} }
impl TryFrom<Vec<Namespace>> for NamespaceRegex { impl TryFrom<Vec<Namespace>> for NamespaceRegex {
type Error = regex::Error;
fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> {
let mut exclusive = vec![]; let mut exclusive = vec![];
let mut non_exclusive = vec![]; let mut non_exclusive = vec![];
@ -73,8 +74,6 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
}, },
}) })
} }
type Error = regex::Error;
} }
/// Appservice registration combined with its compiled regular expressions. /// Appservice registration combined with its compiled regular expressions.
@ -99,6 +98,8 @@ impl RegistrationInfo {
} }
impl TryFrom<Registration> for RegistrationInfo { impl TryFrom<Registration> for RegistrationInfo {
type Error = regex::Error;
fn try_from(value: Registration) -> Result<RegistrationInfo, regex::Error> { fn try_from(value: Registration) -> Result<RegistrationInfo, regex::Error> {
Ok(RegistrationInfo { Ok(RegistrationInfo {
users: value.namespaces.users.clone().try_into()?, users: value.namespaces.users.clone().try_into()?,
@ -107,8 +108,6 @@ impl TryFrom<Registration> for RegistrationInfo {
registration: value, registration: value,
}) })
} }
type Error = regex::Error;
} }
pub(crate) struct Service { pub(crate) struct Service {
@ -135,8 +134,12 @@ impl Service {
registration_info: RwLock::new(registration_info), registration_info: RwLock::new(registration_info),
}) })
} }
/// Registers an appservice and returns the ID to the caller. /// Registers an appservice and returns the ID to the caller.
pub(crate) async fn register_appservice(&self, yaml: Registration) -> Result<String> { pub(crate) async fn register_appservice(
&self,
yaml: Registration,
) -> Result<String> {
//TODO: Check for collisions between exclusive appservice namespaces //TODO: Check for collisions between exclusive appservice namespaces
self.registration_info self.registration_info
.write() .write()
@ -151,19 +154,27 @@ impl Service {
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
pub(crate) async fn unregister_appservice(&self, service_name: &str) -> Result<()> { pub(crate) async fn unregister_appservice(
&self,
service_name: &str,
) -> Result<()> {
services() services()
.appservice .appservice
.registration_info .registration_info
.write() .write()
.await .await
.remove(service_name) .remove(service_name)
.ok_or_else(|| crate::Error::AdminCommand("Appservice not found"))?; .ok_or_else(|| {
crate::Error::AdminCommand("Appservice not found")
})?;
self.db.unregister_appservice(service_name) self.db.unregister_appservice(service_name)
} }
pub(crate) async fn get_registration(&self, id: &str) -> Option<Registration> { pub(crate) async fn get_registration(
&self,
id: &str,
) -> Option<Registration> {
self.registration_info self.registration_info
.read() .read()
.await .await
@ -173,15 +184,13 @@ impl Service {
} }
pub(crate) async fn iter_ids(&self) -> Vec<String> { pub(crate) async fn iter_ids(&self) -> Vec<String> {
self.registration_info self.registration_info.read().await.keys().cloned().collect()
.read()
.await
.keys()
.cloned()
.collect()
} }
pub(crate) async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> { pub(crate) async fn find_from_token(
&self,
token: &str,
) -> Option<RegistrationInfo> {
self.read() self.read()
.await .await
.values() .values()
@ -207,8 +216,12 @@ impl Service {
pub(crate) fn read( pub(crate) fn read(
&self, &self,
) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> ) -> impl Future<
{ Output = tokio::sync::RwLockReadGuard<
'_,
BTreeMap<String, RegistrationInfo>,
>,
> {
self.registration_info.read() self.registration_info.read()
} }
} }

View file

@ -15,7 +15,9 @@ pub(crate) trait Data: Send + Sync {
fn get_registration(&self, id: &str) -> Result<Option<Registration>>; fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>; fn iter_ids<'a>(
&'a self,
) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
fn all(&self) -> Result<Vec<(String, Registration)>>; fn all(&self) -> Result<Vec<(String, Registration)>>;
} }

View file

@ -1,26 +1,4 @@
mod data; mod data;
pub(crate) use data::Data;
use ruma::{
serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId,
};
use crate::api::server_server::FedDest;
use crate::{services, Config, Error, Result};
use futures_util::FutureExt;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service as HyperService,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
use ruma::{
api::{
client::sync::sync_events,
federation::discovery::{ServerSigningKeys, VerifyKey},
},
DeviceId, RoomVersionId, ServerName, UserId,
};
use std::{ use std::{
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap},
error::Error as StdError, error::Error as StdError,
@ -35,11 +13,29 @@ use std::{
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use base64::{engine::general_purpose, Engine as _};
pub(crate) use data::Data;
use futures_util::FutureExt;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service as HyperService,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
use ruma::{
api::{
client::sync::sync_events,
federation::discovery::{ServerSigningKeys, VerifyKey},
},
serde::Base64,
DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId, RoomVersionId, ServerName, UserId,
};
use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore}; use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
use tracing::{error, info}; use tracing::{error, info};
use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::TokioAsyncResolver;
use base64::{engine::general_purpose, Engine as _}; use crate::{api::server_server::FedDest, services, Config, Error, Result};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>; type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
@ -66,27 +62,40 @@ pub(crate) struct Service {
default_client: reqwest::Client, default_client: reqwest::Client,
pub(crate) stable_room_versions: Vec<RoomVersionId>, pub(crate) stable_room_versions: Vec<RoomVersionId>,
pub(crate) unstable_room_versions: Vec<RoomVersionId>, pub(crate) unstable_room_versions: Vec<RoomVersionId>,
pub(crate) bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>, pub(crate) bad_event_ratelimiter:
pub(crate) bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>, Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub(crate) bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>, pub(crate) bad_signature_ratelimiter:
pub(crate) servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>, Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub(crate) sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>, pub(crate) bad_query_ratelimiter:
pub(crate) roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub(crate) servername_ratelimiter:
Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub(crate) sync_receivers:
RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub(crate) roomid_mutex_insert:
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
// this lock will be held longer // this lock will be held longer
pub(crate) roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, pub(crate) roomid_mutex_federation:
pub(crate) roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>, RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_federationhandletime:
RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub(crate) stateres_mutex: Arc<Mutex<()>>, pub(crate) stateres_mutex: Arc<Mutex<()>>,
pub(crate) rotate: RotationHandler, pub(crate) rotate: RotationHandler,
pub(crate) shutdown: AtomicBool, pub(crate) shutdown: AtomicBool,
} }
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. /// Handles "rotation" of long-polling requests. "Rotation" in this context is
/// similar to "rotation" of log files and the like.
/// ///
/// This is utilized to have sync workers return early and release read locks on the database. /// This is utilized to have sync workers return early and release read locks on
pub(crate) struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>); /// the database.
pub(crate) struct RotationHandler(
broadcast::Sender<()>,
broadcast::Receiver<()>,
);
impl RotationHandler { impl RotationHandler {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
@ -136,7 +145,10 @@ impl Resolve for Resolver {
.and_then(|(override_name, port)| { .and_then(|(override_name, port)| {
override_name.first().map(|first_name| { override_name.first().map(|first_name| {
let x: Box<dyn Iterator<Item = SocketAddr> + Send> = let x: Box<dyn Iterator<Item = SocketAddr> + Send> =
Box::new(iter::once(SocketAddr::new(*first_name, *port))); Box::new(iter::once(SocketAddr::new(
*first_name,
*port,
)));
let x: Resolving = Box::pin(future::ready(Ok(x))); let x: Resolving = Box::pin(future::ready(Ok(x)));
x x
}) })
@ -144,9 +156,11 @@ impl Resolve for Resolver {
.unwrap_or_else(|| { .unwrap_or_else(|| {
let this = &mut self.inner.clone(); let this = &mut self.inner.clone();
Box::pin(HyperService::<Name>::call(this, name).map(|result| { Box::pin(HyperService::<Name>::call(this, name).map(|result| {
result result.map(|addrs| -> Addrs { Box::new(addrs) }).map_err(
.map(|addrs| -> Addrs { Box::new(addrs) }) |err| -> Box<dyn StdError + Send + Sync> {
.map_err(|err| -> Box<dyn StdError + Send + Sync> { Box::new(err) }) Box::new(err)
},
)
})) }))
}) })
} }
@ -167,10 +181,9 @@ impl Service {
let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new())); let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new()));
let jwt_decoding_key = config let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| {
.jwt_secret jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())
.as_ref() });
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let default_client = reqwest_client_builder(&config)?.build()?; let default_client = reqwest_client_builder(&config)?.build()?;
let federation_client = reqwest_client_builder(&config)? let federation_client = reqwest_client_builder(&config)?
@ -187,20 +200,28 @@ impl Service {
RoomVersionId::V11, RoomVersionId::V11,
]; ];
// Experimental, partially supported room versions // Experimental, partially supported room versions
let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; let unstable_room_versions =
vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let mut s = Self { let mut s = Self {
db, db,
config, config,
keypair: Arc::new(keypair), keypair: Arc::new(keypair),
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { dns_resolver: TokioAsyncResolver::tokio_from_system_conf()
error!( .map_err(|e| {
"Failed to set up trust dns resolver with system config: {}", error!(
e "Failed to set up trust dns resolver with system \
); config: {}",
Error::bad_config("Failed to set up trust dns resolver with system config.") e
})?, );
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), Error::bad_config(
"Failed to set up trust dns resolver with system \
config.",
)
})?,
actual_destination_cache: Arc::new(
RwLock::new(WellKnownMap::new()),
),
tls_name_override, tls_name_override,
federation_client, federation_client,
default_client, default_client,
@ -223,12 +244,11 @@ impl Service {
fs::create_dir_all(s.get_media_folder())?; fs::create_dir_all(s.get_media_folder())?;
if !s if !s.supported_room_versions().contains(&s.config.default_room_version)
.supported_room_versions()
.contains(&s.config.default_room_version)
{ {
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = crate::config::default_default_room_version(); s.config.default_room_version =
crate::config::default_default_room_version();
}; };
Ok(s) Ok(s)
@ -261,7 +281,11 @@ impl Service {
self.db.current_count() self.db.current_count()
} }
pub(crate) async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { pub(crate) async fn watch(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
self.db.watch(user_id, device_id).await self.db.watch(user_id, device_id).await
} }
@ -313,7 +337,9 @@ impl Service {
&self.dns_resolver &self.dns_resolver
} }
pub(crate) fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { pub(crate) fn jwt_decoding_key(
&self,
) -> Option<&jsonwebtoken::DecodingKey> {
self.jwt_decoding_key.as_ref() self.jwt_decoding_key.as_ref()
} }
@ -353,7 +379,8 @@ impl Service {
/// TODO: the key valid until timestamp is only honored in room version > 4 /// TODO: the key valid until timestamp is only honored in room version > 4
/// Remove the outdated keys and insert the new ones. /// Remove the outdated keys and insert the new ones.
/// ///
/// This doesn't actually check that the keys provided are newer than the old set. /// This doesn't actually check that the keys provided are newer than the
/// old set.
pub(crate) fn add_signing_key( pub(crate) fn add_signing_key(
&self, &self,
origin: &ServerName, origin: &ServerName,
@ -362,7 +389,8 @@ impl Service {
self.db.add_signing_key(origin, new_keys) self.db.add_signing_key(origin, new_keys)
} }
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub(crate) fn signing_keys_for( pub(crate) fn signing_keys_for(
&self, &self,
origin: &ServerName, origin: &ServerName,

View file

@ -13,7 +13,8 @@ use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>; fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>; fn current_count(&self) -> Result<u64>;
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; async fn watch(&self, user_id: &UserId, device_id: &DeviceId)
-> Result<()>;
fn cleanup(&self) -> Result<()>; fn cleanup(&self) -> Result<()>;
fn memory_usage(&self) -> String; fn memory_usage(&self) -> String;
fn clear_caches(&self, amount: u32); fn clear_caches(&self, amount: u32);
@ -25,7 +26,8 @@ pub(crate) trait Data: Send + Sync {
new_keys: ServerSigningKeys, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>; ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for( fn signing_keys_for(
&self, &self,
origin: &ServerName, origin: &ServerName,

View file

@ -1,12 +1,13 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::Result;
use ruma::{ use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw, serde::Raw,
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn create_backup( fn create_backup(
&self, &self,
@ -23,12 +24,21 @@ pub(crate) trait Data: Send + Sync {
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String>; ) -> Result<String>;
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>; fn get_latest_backup_version(
&self,
user_id: &UserId,
) -> Result<Option<String>>;
fn get_latest_backup(&self, user_id: &UserId) fn get_latest_backup(
-> Result<Option<(String, Raw<BackupAlgorithm>)>>; &self,
user_id: &UserId,
) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>; fn get_backup(
&self,
user_id: &UserId,
version: &str,
) -> Result<Option<Raw<BackupAlgorithm>>>;
fn add_key( fn add_key(
&self, &self,
@ -66,7 +76,12 @@ pub(crate) trait Data: Send + Sync {
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; fn delete_room_keys(
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
) -> Result<()>;
fn delete_room_key( fn delete_room_key(
&self, &self,

View file

@ -2,15 +2,14 @@ mod data;
use std::io::Cursor; use std::io::Cursor;
pub(crate) use data::Data; pub(crate) use data::Data;
use crate::{services, Result};
use image::imageops::FilterType; use image::imageops::FilterType;
use tokio::{ use tokio::{
fs::File, fs::File,
io::{AsyncReadExt, AsyncWriteExt, BufReader}, io::{AsyncReadExt, AsyncWriteExt, BufReader},
}; };
use crate::{services, Result};
pub(crate) struct FileMeta { pub(crate) struct FileMeta {
pub(crate) content_disposition: Option<String>, pub(crate) content_disposition: Option<String>,
pub(crate) content_type: Option<String>, pub(crate) content_type: Option<String>,
@ -31,9 +30,13 @@ impl Service {
file: &[u8], file: &[u8],
) -> Result<()> { ) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail // Width, Height = 0 if it's not a thumbnail
let key = self let key = self.db.create_file_metadata(
.db mxc,
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; 0,
0,
content_disposition,
content_type,
)?;
let path = services().globals.get_media_file(&key); let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?; let mut f = File::create(path).await?;
@ -52,9 +55,13 @@ impl Service {
height: u32, height: u32,
file: &[u8], file: &[u8],
) -> Result<()> { ) -> Result<()> {
let key = let key = self.db.create_file_metadata(
self.db mxc,
.create_file_metadata(mxc, width, height, content_disposition, content_type)?; width,
height,
content_disposition,
content_type,
)?;
let path = services().globals.get_media_file(&key); let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?; let mut f = File::create(path).await?;
@ -84,9 +91,12 @@ impl Service {
} }
} }
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when /// Returns width, height of the thumbnail and whether it should be cropped.
/// the server should send the original file. /// Returns None when the server should send the original file.
fn thumbnail_properties(width: u32, height: u32) -> Option<(u32, u32, bool)> { fn thumbnail_properties(
width: u32,
height: u32,
) -> Option<(u32, u32, bool)> {
match (width, height) { match (width, height) {
(0..=32, 0..=32) => Some((32, 32, true)), (0..=32, 0..=32) => Some((32, 32, true)),
(0..=96, 0..=96) => Some((96, 96, true)), (0..=96, 0..=96) => Some((96, 96, true)),
@ -102,11 +112,15 @@ impl Service {
/// Here's an example on how it works: /// Here's an example on how it works:
/// ///
/// - Client requests an image with width=567, height=567 /// - Client requests an image with width=567, height=567
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails /// - Server rounds that up to (800, 600), so it doesn't have to save too
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) /// many thumbnails
/// - Server rounds that up again to (958, 600) to fix the aspect ratio
/// (only for width,height>96)
/// - Server creates the thumbnail and sends it to the user /// - Server creates the thumbnail and sends it to the user
/// ///
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. /// For width,height <= 96 the server uses another thumbnailing algorithm
/// which crops the image afterwards.
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_thumbnail( pub(crate) async fn get_thumbnail(
&self, &self,
mxc: String, mxc: String,
@ -154,7 +168,8 @@ impl Service {
} else { } else {
let (exact_width, exact_height) = { let (exact_width, exact_height) = {
// Copied from image::dynimage::resize_dimensions // Copied from image::dynimage::resize_dimensions
let use_width = (u64::from(width) * u64::from(original_height)) let use_width = (u64::from(width)
* u64::from(original_height))
<= (u64::from(original_width) * u64::from(height)); <= (u64::from(original_width) * u64::from(height));
let intermediate = if use_width { let intermediate = if use_width {
u64::from(original_height) * u64::from(width) u64::from(original_height) * u64::from(width)
@ -165,21 +180,31 @@ impl Service {
}; };
if use_width { if use_width {
if intermediate <= u64::from(::std::u32::MAX) { if intermediate <= u64::from(::std::u32::MAX) {
(width, intermediate.try_into().unwrap_or(u32::MAX)) (
width,
intermediate.try_into().unwrap_or(u32::MAX),
)
} else { } else {
( (
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate) (u64::from(width)
* u64::from(::std::u32::MAX)
/ intermediate)
.try_into() .try_into()
.unwrap_or(u32::MAX), .unwrap_or(u32::MAX),
::std::u32::MAX, ::std::u32::MAX,
) )
} }
} else if intermediate <= u64::from(::std::u32::MAX) { } else if intermediate <= u64::from(::std::u32::MAX) {
(intermediate.try_into().unwrap_or(u32::MAX), height) (
intermediate.try_into().unwrap_or(u32::MAX),
height,
)
} else { } else {
( (
::std::u32::MAX, ::std::u32::MAX,
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate) (u64::from(height)
* u64::from(::std::u32::MAX)
/ intermediate)
.try_into() .try_into()
.unwrap_or(u32::MAX), .unwrap_or(u32::MAX),
) )
@ -195,7 +220,8 @@ impl Service {
image::ImageOutputFormat::Png, image::ImageOutputFormat::Png,
)?; )?;
// Save thumbnail in database so we don't have to generate it again next time // Save thumbnail in database so we don't have to generate it
// again next time
let thumbnail_key = self.db.create_file_metadata( let thumbnail_key = self.db.create_file_metadata(
mxc, mxc,
width, width,

View file

@ -1,24 +1,31 @@
use crate::Error; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use ruma::{ use ruma::{
canonical_json::redact_content_in_place, canonical_json::redact_content_in_place,
events::{ events::{
room::{member::RoomMemberEventContent, redaction::RoomRedactionEventContent}, room::{
member::RoomMemberEventContent,
redaction::RoomRedactionEventContent,
},
space::child::HierarchySpaceChildEvent, space::child::HierarchySpaceChildEvent,
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent,
AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
AnyTimelineEvent, StateEvent, TimelineEventType,
}, },
serde::Raw, serde::Raw,
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, state_res, CanonicalJsonObject, CanonicalJsonValue, EventId,
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
RoomVersionId, UInt, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{ use serde_json::{
json, json,
value::{to_raw_value, RawValue as RawJsonValue}, value::{to_raw_value, RawValue as RawJsonValue},
}; };
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use tracing::warn; use tracing::warn;
use crate::Error;
/// Content hashes of a PDU. /// Content hashes of a PDU.
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct EventHash { pub(crate) struct EventHash {
@ -61,10 +68,18 @@ impl PduEvent {
) -> crate::Result<()> { ) -> crate::Result<()> {
self.unsigned = None; self.unsigned = None;
let mut content = serde_json::from_str(self.content.get()) let mut content =
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?; serde_json::from_str(self.content.get()).map_err(|_| {
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) Error::bad_database("PDU in db has invalid content.")
.map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; })?;
redact_content_in_place(
&mut content,
&room_version_id,
self.kind.to_string(),
)
.map_err(|e| {
Error::Redaction(self.sender.server_name().to_owned(), e)
})?;
self.unsigned = Some(to_raw_value(&json!({ self.unsigned = Some(to_raw_value(&json!({
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
@ -78,10 +93,12 @@ impl PduEvent {
pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> { pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> {
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
serde_json::from_str(unsigned.get()) serde_json::from_str(unsigned.get()).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; Error::bad_database("Invalid unsigned in pdu event")
})?;
unsigned.remove("transaction_id"); unsigned.remove("transaction_id");
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); self.unsigned =
Some(to_raw_value(&unsigned).expect("unsigned is valid"));
} }
Ok(()) Ok(())
@ -91,31 +108,45 @@ impl PduEvent {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned .unsigned
.as_ref() .as_ref()
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) .map_or_else(
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; || Ok(BTreeMap::new()),
|u| serde_json::from_str(u.get()),
)
.map_err(|_| {
Error::bad_database("Invalid unsigned in pdu event")
})?;
unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap());
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); self.unsigned =
Some(to_raw_value(&unsigned).expect("unsigned is valid"));
Ok(()) Ok(())
} }
/// Copies the `redacts` property of the event to the `content` dict and vice-versa. /// Copies the `redacts` property of the event to the `content` dict and
/// vice-versa.
/// ///
/// This follows the specification's /// This follows the specification's
/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): /// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property):
/// ///
/// > For backwards-compatibility with older clients, servers should add a redacts /// > For backwards-compatibility with older clients, servers should add a
/// > property to the top level of m.room.redaction events in when serving such events /// > redacts
/// > property to the top level of m.room.redaction events in when serving
/// > such events
/// > over the Client-Server API. /// > over the Client-Server API.
/// > /// >
/// > For improved compatibility with newer clients, servers should add a redacts property /// > For improved compatibility with newer clients, servers should add a
/// > to the content of m.room.redaction events in older room versions when serving /// > redacts property
/// > to the content of m.room.redaction events in older room versions when
/// > serving
/// > such events over the Client-Server API. /// > such events over the Client-Server API.
pub(crate) fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) { pub(crate) fn copy_redacts(
&self,
) -> (Option<Arc<EventId>>, Box<RawJsonValue>) {
if self.kind == TimelineEventType::RoomRedaction { if self.kind == TimelineEventType::RoomRedaction {
if let Ok(mut content) = if let Ok(mut content) = serde_json::from_str::<
serde_json::from_str::<RoomRedactionEventContent>(self.content.get()) RoomRedactionEventContent,
>(self.content.get())
{ {
if let Some(redacts) = content.redacts { if let Some(redacts) = content.redacts {
return (Some(redacts.into()), self.content.clone()); return (Some(redacts.into()), self.content.clone());
@ -123,7 +154,9 @@ impl PduEvent {
content.redacts = Some(redacts.into()); content.redacts = Some(redacts.into());
return ( return (
self.redacts.clone(), self.redacts.clone(),
to_raw_value(&content).expect("Must be valid, we only added redacts field"), to_raw_value(&content).expect(
"Must be valid, we only added redacts field",
),
); );
} }
} }
@ -281,7 +314,9 @@ impl PduEvent {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub(crate) fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> { pub(crate) fn to_stripped_spacechild_state_event(
&self,
) -> Raw<HierarchySpaceChildEvent> {
let json = json!({ let json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
@ -294,7 +329,9 @@ impl PduEvent {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub(crate) fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { pub(crate) fn to_member_event(
&self,
) -> Raw<StateEvent<RoomMemberEventContent>> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
@ -318,16 +355,16 @@ impl PduEvent {
pub(crate) fn convert_to_outgoing_federation_event( pub(crate) fn convert_to_outgoing_federation_event(
mut pdu_json: CanonicalJsonObject, mut pdu_json: CanonicalJsonObject,
) -> Box<RawJsonValue> { ) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json if let Some(unsigned) =
.get_mut("unsigned") pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut())
.and_then(|val| val.as_object_mut())
{ {
unsigned.remove("transaction_id"); unsigned.remove("transaction_id");
} }
pdu_json.remove("event_id"); pdu_json.remove("event_id");
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") to_raw_value(&pdu_json)
.expect("CanonicalJson is valid serde_json::Value")
} }
pub(crate) fn from_id_val( pub(crate) fn from_id_val(
@ -374,11 +411,15 @@ impl state_res::Event for PduEvent {
self.state_key.as_deref() self.state_key.as_deref()
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn prev_events(
&self,
) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
Box::new(self.prev_events.iter()) Box::new(self.prev_events.iter())
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn auth_events(
&self,
) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
Box::new(self.auth_events.iter()) Box::new(self.auth_events.iter())
} }
@ -408,15 +449,17 @@ impl Ord for PduEvent {
/// Generates a correct eventId for the incoming pdu. /// Generates a correct eventId for the incoming pdu.
/// ///
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
/// CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json( pub(crate) fn gen_event_id_canonical_json(
pdu: &RawJsonValue, pdu: &RawJsonValue,
room_version_id: &RoomVersionId, room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { ) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { let value: CanonicalJsonObject =
warn!("Error parsing incoming event {:?}: {:?}", pdu, e); serde_json::from_str(pdu.get()).map_err(|e| {
Error::BadServerResponse("Invalid PDU in server response") warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
})?; Error::BadServerResponse("Invalid PDU in server response")
})?;
let event_id = format!( let event_id = format!(
"${}", "${}",

View file

@ -1,27 +1,34 @@
mod data; mod data;
pub(crate) use data::Data; use std::{fmt::Debug, mem};
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
use crate::{services, Error, PduEvent, Result};
use bytes::BytesMut; use bytes::BytesMut;
pub(crate) use data::Data;
use ruma::{ use ruma::{
api::{ api::{
client::push::{set_pusher, Pusher, PusherKind}, client::push::{set_pusher, Pusher, PusherKind},
push_gateway::send_event_notification::{ push_gateway::send_event_notification::{
self, self,
v1::{Device, Notification, NotificationCounts, NotificationPriority}, v1::{
Device, Notification, NotificationCounts, NotificationPriority,
},
}, },
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
}, },
events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType}, events::{
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent,
StateEventType, TimelineEventType,
},
push::{
Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat,
Ruleset, Tweak,
},
serde::Raw, serde::Raw,
uint, RoomId, UInt, UserId, uint, RoomId, UInt, UserId,
}; };
use std::{fmt::Debug, mem};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{services, Error, PduEvent, Result};
pub(crate) struct Service { pub(crate) struct Service {
pub(crate) db: &'static dyn Data, pub(crate) db: &'static dyn Data,
} }
@ -35,7 +42,11 @@ impl Service {
self.db.set_pusher(sender, pusher) self.db.set_pusher(sender, pusher)
} }
pub(crate) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { pub(crate) fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>> {
self.db.get_pusher(sender, pushkey) self.db.get_pusher(sender, pushkey)
} }
@ -43,7 +54,10 @@ impl Service {
self.db.get_pushers(sender) self.db.get_pushers(sender)
} }
pub(crate) fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> { pub(crate) fn get_pushkeys(
&self,
sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>>> {
self.db.get_pushkeys(sender) self.db.get_pushkeys(sender)
} }
@ -73,11 +87,8 @@ impl Service {
let reqwest_request = reqwest::Request::try_from(http_request)?; let reqwest_request = reqwest::Request::try_from(http_request)?;
let url = reqwest_request.url().clone(); let url = reqwest_request.url().clone();
let response = services() let response =
.globals services().globals.default_client().execute(reqwest_request).await;
.default_client()
.execute(reqwest_request)
.await;
match response { match response {
Ok(mut response) => { Ok(mut response) => {
@ -119,11 +130,16 @@ impl Service {
"Push gateway returned invalid response bytes {}\n{}", "Push gateway returned invalid response bytes {}\n{}",
destination, url destination, url
); );
Error::BadServerResponse("Push gateway returned bad response.") Error::BadServerResponse(
"Push gateway returned bad response.",
)
}) })
} }
Err(e) => { Err(e) => {
warn!("Could not send request to pusher {}: {}", destination, e); warn!(
"Could not send request to pusher {}: {}",
destination, e
);
Err(e.into()) Err(e.into())
} }
} }
@ -146,8 +162,9 @@ impl Service {
.state_accessor .state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
serde_json::from_str(ev.content.get()) serde_json::from_str(ev.content.get()).map_err(|_| {
.map_err(|_| Error::bad_database("invalid m.room.power_levels event")) Error::bad_database("invalid m.room.power_levels event")
})
}) })
.transpose()? .transpose()?
.unwrap_or_default(); .unwrap_or_default();
@ -228,11 +245,16 @@ impl Service {
PusherKind::Http(http) => { PusherKind::Http(http) => {
// TODO: // TODO:
// Two problems with this // Two problems with this
// 1. if "event_id_only" is the only format kind it seems we should never add more info // 1. if "event_id_only" is the only format kind it seems we
// should never add more info
// 2. can pusher/devices have conflicting formats // 2. can pusher/devices have conflicting formats
let event_id_only = http.format == Some(PushFormat::EventIdOnly); let event_id_only =
http.format == Some(PushFormat::EventIdOnly);
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); let mut device = Device::new(
pusher.ids.app_id.clone(),
pusher.ids.pushkey.clone(),
);
device.data.default_payload = http.default_payload.clone(); device.data.default_payload = http.default_payload.clone();
device.data.format = http.format.clone(); device.data.format = http.format.clone();
@ -251,32 +273,43 @@ impl Service {
notifi.counts = NotificationCounts::new(unread, uint!(0)); notifi.counts = NotificationCounts::new(unread, uint!(0));
if event.kind == TimelineEventType::RoomEncrypted if event.kind == TimelineEventType::RoomEncrypted
|| tweaks || tweaks.iter().any(|t| {
.iter() matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) })
{ {
notifi.prio = NotificationPriority::High; notifi.prio = NotificationPriority::High;
} }
if event_id_only { if event_id_only {
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) self.send_request(
.await?; &http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
} else { } else {
notifi.sender = Some(event.sender.clone()); notifi.sender = Some(event.sender.clone());
notifi.event_type = Some(event.kind.clone()); notifi.event_type = Some(event.kind.clone());
notifi.content = serde_json::value::to_raw_value(&event.content).ok(); notifi.content =
serde_json::value::to_raw_value(&event.content).ok();
if event.kind == TimelineEventType::RoomMember { if event.kind == TimelineEventType::RoomMember {
notifi.user_is_target = notifi.user_is_target = event.state_key.as_deref()
event.state_key.as_deref() == Some(event.sender.as_str()); == Some(event.sender.as_str());
} }
notifi.sender_display_name = services().users.displayname(&event.sender)?; notifi.sender_display_name =
services().users.displayname(&event.sender)?;
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; notifi.room_name = services()
.rooms
.state_accessor
.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) self.send_request(
.await?; &http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
} }
Ok(()) Ok(())

View file

@ -1,16 +1,27 @@
use crate::Result;
use ruma::{ use ruma::{
api::client::push::{set_pusher, Pusher}, api::client::push::{set_pusher, Pusher},
UserId, UserId,
}; };
pub(crate) trait Data: Send + Sync { use crate::Result;
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>; pub(crate) trait Data: Send + Sync {
fn set_pusher(
&self,
sender: &UserId,
pusher: set_pusher::v3::PusherAction,
) -> Result<()>;
fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>>;
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>; fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
fn get_pushkeys<'a>(&'a self, sender: &UserId) fn get_pushkeys<'a>(
-> Box<dyn Iterator<Item = Result<String>> + 'a>; &'a self,
sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
} }

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Creates or updates the alias to the given room id. /// Creates or updates the alias to the given room id.
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
@ -9,7 +10,10 @@ pub(crate) trait Data: Send + Sync {
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
/// Looks up the roomid for the given alias. /// Looks up the roomid for the given alias.
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>; fn resolve_local_alias(
&self,
alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>>;
/// Returns all local aliases that point to the given room /// Returns all local aliases that point to the given room
fn local_aliases_for_room<'a>( fn local_aliases_for_room<'a>(

View file

@ -43,7 +43,8 @@ impl Service {
let mut i = 0; let mut i = 0;
for id in starting_events { for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?; let short =
services().rooms.short.get_or_create_shorteventid(&id)?;
// I'm afraid to change this in case there is accidental reliance on // I'm afraid to change this in case there is accidental reliance on
// the truncation // the truncation
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)] #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
@ -64,7 +65,8 @@ impl Service {
continue; continue;
} }
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); let chunk_key: Vec<u64> =
chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services() if let Some(cached) = services()
.rooms .rooms
.auth_chain .auth_chain
@ -90,11 +92,13 @@ impl Service {
chunk_cache.extend(cached.iter().copied()); chunk_cache.extend(cached.iter().copied());
} else { } else {
misses2 += 1; misses2 += 1;
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); let auth_chain = Arc::new(
services() self.get_auth_chain_inner(room_id, &event_id)?,
.rooms );
.auth_chain services().rooms.auth_chain.cache_auth_chain(
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; vec![sevent_id],
Arc::clone(&auth_chain),
)?;
debug!( debug!(
event_id = ?event_id, event_id = ?event_id,
chain_length = ?auth_chain.len(), chain_length = ?auth_chain.len(),
@ -129,13 +133,17 @@ impl Service {
"Auth chain stats", "Auth chain stats",
); );
Ok(full_auth_chain Ok(full_auth_chain.into_iter().filter_map(move |sid| {
.into_iter() services().rooms.short.get_eventid_from_short(sid).ok()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) }))
} }
#[tracing::instrument(skip(self, event_id))] #[tracing::instrument(skip(self, event_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> { fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)]; let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new(); let mut found = HashSet::new();
@ -143,7 +151,10 @@ impl Service {
match services().rooms.timeline.get_pdu(&event_id) { match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => { Ok(Some(pdu)) => {
if pdu.room_id != room_id { if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Evil event in db",
));
} }
for auth_event in &pdu.auth_events { for auth_event in &pdu.auth_events {
let sauthevent = services() let sauthevent = services()
@ -158,10 +169,17 @@ impl Service {
} }
} }
Ok(None) => { Ok(None) => {
warn!(?event_id, "Could not find pdu mentioned in auth events"); warn!(
?event_id,
"Could not find pdu mentioned in auth events"
);
} }
Err(error) => { Err(error) => {
error!(?event_id, ?error, "Could not load event in auth chain"); error!(
?event_id,
?error,
"Could not load event in auth chain"
);
} }
} }
} }

View file

@ -1,11 +1,15 @@
use crate::Result;
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn get_cached_eventid_authchain( fn get_cached_eventid_authchain(
&self, &self,
shorteventid: &[u64], shorteventid: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>>; ) -> Result<Option<Arc<HashSet<u64>>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) fn cache_auth_chain(
-> Result<()>; &self,
shorteventid: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()>;
} }

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Adds the room to the public room directory /// Adds the room to the public room directory
fn set_public(&self, room_id: &RoomId) -> Result<()>; fn set_public(&self, room_id: &RoomId) -> Result<()>;
@ -12,5 +13,7 @@ pub(crate) trait Data: Send + Sync {
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>; fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
/// Returns the unsorted public room directory /// Returns the unsorted public room directory
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; fn public_rooms<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
} }

View file

@ -1,5 +1,8 @@
use ruma::{
events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId,
};
use crate::Result; use crate::Result;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
@ -10,7 +13,8 @@ pub(crate) trait Data: Send + Sync {
event: ReceiptEvent, event: ReceiptEvent,
) -> Result<()>; ) -> Result<()>;
/// Returns an iterator over the most recent read receipts in a room that happened after the event with id `since`. /// Returns an iterator over the most recent read receipts in a room that
/// happened after the event with id `since`.
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn readreceipts_since<'a>( fn readreceipts_since<'a>(
&'a self, &'a self,
@ -27,11 +31,24 @@ pub(crate) trait Data: Send + Sync {
>; >;
/// Sets a private read marker at `count`. /// Sets a private read marker at `count`.
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; fn private_read_set(
&self,
room_id: &RoomId,
user_id: &UserId,
count: u64,
) -> Result<()>;
/// Returns the private read marker. /// Returns the private read marker.
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>>;
/// Returns the count of the last typing update in this room. /// Returns the count of the last typing update in this room.
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64>;
} }

View file

@ -1,8 +1,9 @@
use std::collections::BTreeMap;
use ruma::{ use ruma::{
events::{typing::TypingEventContent, SyncEphemeralRoomEvent}, events::{typing::TypingEventContent, SyncEphemeralRoomEvent},
OwnedRoomId, OwnedUserId, RoomId, UserId, OwnedRoomId, OwnedUserId, RoomId, UserId,
}; };
use std::collections::BTreeMap;
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, RwLock};
use tracing::trace; use tracing::trace;
@ -10,15 +11,16 @@ use crate::{services, utils, Result};
pub(crate) struct Service { pub(crate) struct Service {
// u64 is unix timestamp of timeout // u64 is unix timestamp of timeout
pub(crate) typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>, pub(crate) typing:
RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>,
// timestamp of the last change to typing users // timestamp of the last change to typing users
pub(crate) last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>, pub(crate) last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>,
pub(crate) typing_update_sender: broadcast::Sender<OwnedRoomId>, pub(crate) typing_update_sender: broadcast::Sender<OwnedRoomId>,
} }
impl Service { impl Service {
/// Sets a user as typing until the timeout timestamp is reached or `roomtyping_remove` is /// Sets a user as typing until the timeout timestamp is reached or
/// called. /// `roomtyping_remove` is called.
pub(crate) async fn typing_add( pub(crate) async fn typing_add(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -36,13 +38,20 @@ impl Service {
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!(
"receiver found what it was looking for and is no longer \
interested"
);
} }
Ok(()) Ok(())
} }
/// Removes a user from typing before the timeout is reached. /// Removes a user from typing before the timeout is reached.
pub(crate) async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub(crate) async fn typing_remove(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
self.typing self.typing
.write() .write()
.await .await
@ -54,7 +63,10 @@ impl Service {
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!(
"receiver found what it was looking for and is no longer \
interested"
);
} }
Ok(()) Ok(())
} }
@ -97,14 +109,20 @@ impl Service {
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!(
"receiver found what it was looking for and is no longer \
interested"
);
} }
} }
Ok(()) Ok(())
} }
/// Returns the count of the last typing update in this room. /// Returns the count of the last typing update in this room.
pub(crate) async fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { pub(crate) async fn last_typing_update(
&self,
room_id: &RoomId,
) -> Result<u64> {
self.typings_maintain(room_id).await?; self.typings_maintain(room_id).await?;
Ok(self Ok(self
.last_typing_update .last_typing_update

File diff suppressed because it is too large Load diff

View file

@ -5,16 +5,19 @@ pub(crate) use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::Result;
use super::timeline::PduCount; use super::timeline::PduCount;
use crate::Result;
pub(crate) struct Service { pub(crate) struct Service {
pub(crate) db: &'static dyn Data, pub(crate) db: &'static dyn Data,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub(crate) lazy_load_waiting: pub(crate) lazy_load_waiting: Mutex<
Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>, HashMap<
(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount),
HashSet<OwnedUserId>,
>,
>,
} }
impl Service { impl Service {
@ -26,8 +29,7 @@ impl Service {
room_id: &RoomId, room_id: &RoomId,
ll_user: &UserId, ll_user: &UserId,
) -> Result<bool> { ) -> Result<bool> {
self.db self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{DeviceId, RoomId, UserId}; use ruma::{DeviceId, RoomId, UserId};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn lazy_load_was_sent_before( fn lazy_load_was_sent_before(
&self, &self,

View file

@ -1,10 +1,13 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Checks if a room exists. /// Checks if a room exists.
fn exists(&self, room_id: &RoomId) -> Result<bool>; fn exists(&self, room_id: &RoomId) -> Result<bool>;
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; fn iter_ids<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
fn is_disabled(&self, room_id: &RoomId) -> Result<bool>; fn is_disabled(&self, room_id: &RoomId) -> Result<bool>;
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>;
} }

View file

@ -4,8 +4,15 @@ use crate::{PduEvent, Result};
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
/// Returns the pdu from the outlier tree. /// Returns the pdu from the outlier tree.
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; fn get_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
/// Append the PDU as an outlier. /// Append the PDU as an outlier.
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; fn add_pdu_outlier(
&self,
event_id: &EventId,
pdu: &CanonicalJsonObject,
) -> Result<()>;
} }

View file

@ -9,9 +9,8 @@ use ruma::{
}; };
use serde::Deserialize; use serde::Deserialize;
use crate::{services, PduEvent, Result};
use super::timeline::PduCount; use super::timeline::PduCount;
use crate::{services, PduEvent, Result};
pub(crate) struct Service { pub(crate) struct Service {
pub(crate) db: &'static dyn Data, pub(crate) db: &'static dyn Data,
@ -29,9 +28,15 @@ struct ExtractRelatesToEventId {
impl Service { impl Service {
#[tracing::instrument(skip(self, from, to))] #[tracing::instrument(skip(self, from, to))]
pub(crate) fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { pub(crate) fn add_relation(
&self,
from: PduCount,
to: PduCount,
) -> Result<()> {
match (from, to) { match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), (PduCount::Normal(f), PduCount::Normal(t)) => {
self.db.add_relation(f, t)
}
_ => { _ => {
// TODO: Relations with backfilled pdus // TODO: Relations with backfilled pdus
@ -42,6 +47,7 @@ impl Service {
#[allow( #[allow(
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::too_many_lines,
// Allowed because this function uses `services()` // Allowed because this function uses `services()`
clippy::unused_self, clippy::unused_self,
)] )]
@ -68,15 +74,17 @@ impl Service {
.relations_until(sender_user, room_id, target, from)? .relations_until(sender_user, room_id, target, from)?
.filter(|r| { .filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| { r.as_ref().map_or(true, |(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t) filter_event_type
&& if let Ok(content) = .as_ref()
serde_json::from_str::<ExtractRelatesToEventId>( .map_or(true, |t| &&pdu.kind == t)
pdu.content.get(), && if let Ok(content) = serde_json::from_str::<
) ExtractRelatesToEventId,
{ >(
filter_rel_type pdu.content.get()
.as_ref() ) {
.map_or(true, |r| &&content.relates_to.rel_type == r) filter_rel_type.as_ref().map_or(true, |r| {
&&content.relates_to.rel_type == r
})
} else { } else {
false false
} }
@ -88,13 +96,18 @@ impl Service {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(
sender_user,
room_id,
&pdu.event_id,
)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) .take_while(|&(k, _)| Some(k) != to)
.collect(); .collect();
next_token = events_after.last().map(|(count, _)| count).copied(); next_token =
events_after.last().map(|(count, _)| count).copied();
// Reversed because relations are always most recent first // Reversed because relations are always most recent first
let events_after: Vec<_> = events_after let events_after: Vec<_> = events_after
@ -116,15 +129,17 @@ impl Service {
.relations_until(sender_user, room_id, target, from)? .relations_until(sender_user, room_id, target, from)?
.filter(|r| { .filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| { r.as_ref().map_or(true, |(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t) filter_event_type
&& if let Ok(content) = .as_ref()
serde_json::from_str::<ExtractRelatesToEventId>( .map_or(true, |t| &&pdu.kind == t)
pdu.content.get(), && if let Ok(content) = serde_json::from_str::<
) ExtractRelatesToEventId,
{ >(
filter_rel_type pdu.content.get()
.as_ref() ) {
.map_or(true, |r| &&content.relates_to.rel_type == r) filter_rel_type.as_ref().map_or(true, |r| {
&&content.relates_to.rel_type == r
})
} else { } else {
false false
} }
@ -136,13 +151,18 @@ impl Service {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(
sender_user,
room_id,
&pdu.event_id,
)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) .take_while(|&(k, _)| Some(k) != to)
.collect(); .collect();
next_token = events_before.last().map(|(count, _)| count).copied(); next_token =
events_before.last().map(|(count, _)| count).copied();
let events_before: Vec<_> = events_before let events_before: Vec<_> = events_before
.into_iter() .into_iter()
@ -165,7 +185,8 @@ impl Service {
target: &'a EventId, target: &'a EventId,
until: PduCount, until: PduCount,
) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; let room_id =
services().rooms.short.get_or_create_shortroomid(room_id)?;
let target = match services().rooms.timeline.get_pdu_count(target)? { let target = match services().rooms.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c, Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations // TODO: Support backfilled relations
@ -185,17 +206,27 @@ impl Service {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub(crate) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { pub(crate) fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
self.db.is_event_referenced(room_id, event_id) self.db.is_event_referenced(room_id, event_id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub(crate) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { pub(crate) fn mark_event_soft_failed(
&self,
event_id: &EventId,
) -> Result<()> {
self.db.mark_event_soft_failed(event_id) self.db.mark_event_soft_failed(event_id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub(crate) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { pub(crate) fn is_event_soft_failed(
&self,
event_id: &EventId,
) -> Result<bool> {
self.db.is_event_soft_failed(event_id) self.db.is_event_soft_failed(event_id)
} }
} }

View file

@ -1,8 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
use ruma::{EventId, RoomId, UserId}; use ruma::{EventId, RoomId, UserId};
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn add_relation(&self, from: u64, to: u64) -> Result<()>; fn add_relation(&self, from: u64, to: u64) -> Result<()>;
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
@ -13,8 +14,16 @@ pub(crate) trait Data: Send + Sync {
target: u64, target: u64,
until: PduCount, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>; ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>; fn mark_as_referenced(
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>; &self,
room_id: &RoomId,
event_ids: &[Arc<EventId>],
) -> Result<()>;
fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool>;
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>;
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>; fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>;
} }

View file

@ -1,8 +1,14 @@
use crate::Result;
use ruma::RoomId; use ruma::RoomId;
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; fn index_pdu(
&self,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()>;
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn search_pdus<'a>( fn search_pdus<'a>(

View file

@ -1,8 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use crate::Result;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync { pub(crate) trait Data: Send + Sync {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>; fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
@ -18,12 +19,19 @@ pub(crate) trait Data: Send + Sync {
state_key: &str, state_key: &str,
) -> Result<u64>; ) -> Result<u64>;
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>; fn get_eventid_from_short(&self, shorteventid: u64)
-> Result<Arc<EventId>>;
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; fn get_statekey_from_short(
&self,
shortstatekey: u64,
) -> Result<(StateEventType, String)>;
/// Returns `(shortstatehash, already_existed)` /// Returns `(shortstatehash, already_existed)`
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>; fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;

Some files were not shown because too many files have changed in this diff Show more