change rustfmt configuration

This change is fully automated, except the `rustfmt.toml` changes and
a few clippy directives to allow specific functions with too many lines
because they are longer now.
This commit is contained in:
Charles Hall 2024-05-16 01:19:04 -07:00
parent 40d6ce230d
commit 0afc1d2f50
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
123 changed files with 7881 additions and 4687 deletions

View file

@ -1,2 +1,17 @@
unstable_features = true
imports_granularity="Crate"
edition = "2021"
condense_wildcard_suffixes = true
format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
format_strings = true
group_imports = "StdExternalCrate"
hex_literal_case = "Upper"
imports_granularity = "Crate"
max_width = 80
newline_style = "Unix"
reorder_impl_items = true
use_field_init_shorthand = true
use_small_heuristics = "Off"
use_try_shorthand = true
wrap_comments = true

View file

@ -1,14 +1,18 @@
use crate::{services, utils, Error, Result};
use std::{fmt::Debug, mem, time::Duration};
use bytes::BytesMut;
use ruma::api::{
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken,
};
use std::{fmt::Debug, mem, time::Duration};
use tracing::warn;
use crate::{services, utils, Error, Result};
/// Sends a request to an appservice
///
/// Only returns None if there is no url specified in the appservice registration file
/// Only returns None if there is no url specified in the appservice
/// registration file
#[tracing::instrument(skip(request))]
pub(crate) async fn send_request<T: OutgoingRequest>(
registration: Registration,
@ -45,7 +49,8 @@ where
.parse()
.unwrap(),
);
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
*http_request.uri_mut() =
parts.try_into().expect("our manipulation is always valid");
let mut reqwest_request = reqwest::Request::try_from(http_request)?;
@ -70,9 +75,8 @@ where
// reqwest::Response -> http::Response conversion
let status = response.status();
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
let mut http_response_builder =
http::Response::builder().status(status).version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder

View file

@ -1,22 +1,25 @@
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};
use register::RegistrationKind;
use ruma::{
api::client::{
account::{
change_password, deactivate, get_3pids, get_username_availability,
register::{self, LoginType},
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn,
whoami, ThirdPartyIdRemovalStatus,
request_3pid_management_token_via_email,
request_3pid_management_token_via_msisdn, whoami,
ThirdPartyIdRemovalStatus,
},
error::ErrorKind,
uiaa::{AuthFlow, AuthType, UiaaInfo},
},
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
events::{
room::message::RoomMessageEventContent, GlobalAccountDataEventType,
},
push, UserId,
};
use tracing::{info, warn};
use register::RegistrationKind;
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};
const RANDOM_USER_ID_LENGTH: usize = 10;
@ -29,7 +32,8 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
/// - The server name of the user id matches this server
/// - No user or appservice on this server already claimed this username
///
/// Note: This will not reserve the username, so the username might become invalid when trying to register
/// Note: This will not reserve the username, so the username might become
/// invalid when trying to register
pub(crate) async fn get_register_available_route(
body: Ruma<get_username_availability::v3::Request>,
) -> Result<get_username_availability::v3::Response> {
@ -40,7 +44,8 @@ pub(crate) async fn get_register_available_route(
)
.ok()
.filter(|user_id| {
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
!user_id.is_historical()
&& user_id.server_name() == services().globals.server_name()
})
.ok_or(Error::BadRequest(
ErrorKind::InvalidUsername,
@ -58,27 +63,35 @@ pub(crate) async fn get_register_available_route(
// TODO add check for appservice namespaces
// If no if check is true we have an username that's available to be used.
Ok(get_username_availability::v3::Response { available: true })
Ok(get_username_availability::v3::Response {
available: true,
})
}
/// # `POST /_matrix/client/r0/register`
///
/// Register an account on this homeserver.
///
/// You can use [`GET /_matrix/client/r0/register/available`](get_register_available_route)
/// You can use [`GET
/// /_matrix/client/r0/register/available`](get_register_available_route)
/// to check if the user id is valid and available.
///
/// - Only works if registration is enabled
/// - If type is guest: ignores all parameters except `initial_device_display_name`
/// - If type is guest: ignores all parameters except
/// `initial_device_display_name`
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
/// - If type is not guest and no username is given: Always fails after UIAA check
/// - If type is not guest and no username is given: Always fails after UIAA
/// check
/// - Creates a new account and populates it with default account data
/// - If `inhibit_login` is false: Creates a device and returns `device_id` and `access_token`
/// - If `inhibit_login` is false: Creates a device and returns `device_id` and
/// `access_token`
#[allow(clippy::too_many_lines)]
pub(crate) async fn register_route(
body: Ruma<register::v3::Request>,
) -> Result<register::v3::Response> {
if !services().globals.allow_registration() && body.appservice_info.is_none() {
if !services().globals.allow_registration()
&& body.appservice_info.is_none()
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Registration has been disabled.",
@ -158,7 +171,8 @@ pub(crate) async fn register_route(
};
body.appservice_info.is_some()
} else {
// No registration token necessary, but clients must still go through the flow
// No registration token necessary, but clients must still go through
// the flow
uiaainfo = UiaaInfo {
flows: vec![AuthFlow {
stages: vec![AuthType::Dummy],
@ -174,8 +188,11 @@ pub(crate) async fn register_route(
if !skip_auth {
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(
&UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid"),
&UserId::parse_with_server_name(
"",
services().globals.server_name(),
)
.expect("we know this is valid"),
"".into(),
auth,
&uiaainfo,
@ -187,8 +204,11 @@ pub(crate) async fn register_route(
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(
&UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid"),
&UserId::parse_with_server_name(
"",
services().globals.server_name(),
)
.expect("we know this is valid"),
"".into(),
&uiaainfo,
&json,
@ -211,9 +231,7 @@ pub(crate) async fn register_route(
// Default to pretty displayname
let displayname = user_id.localpart().to_owned();
services()
.users
.set_displayname(&user_id, Some(displayname.clone()))?;
services().users.set_displayname(&user_id, Some(displayname.clone()))?;
// Initial account data
services().account_data.update(
@ -260,29 +278,24 @@ pub(crate) async fn register_route(
info!("New user {} registered on this server.", user_id);
if body.appservice_info.is_none() && !is_guest {
services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"New user {user_id} registered on this server."
)));
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("New user {user_id} registered on this server."),
));
}
// If this is the first real user, grant them admin privileges
// Note: the server user, @grapevine:servername, is generated first
if !is_guest {
if let Some(admin_room) = services().admin.get_admin_room()? {
if services()
.rooms
.state_cache
.room_joined_count(&admin_room)?
if services().rooms.state_cache.room_joined_count(&admin_room)?
== Some(1)
{
services()
.admin
.make_user_admin(&user_id, displayname)
.await?;
services().admin.make_user_admin(&user_id, displayname).await?;
warn!("Granting {} admin privileges as the first user", user_id);
warn!(
"Granting {} admin privileges as the first user",
user_id
);
}
}
}
@ -302,19 +315,23 @@ pub(crate) async fn register_route(
///
/// - Requires UIAA to verify user password
/// - Changes the password of the sender user
/// - The password hash is calculated using argon2 with 32 character salt, the plain password is
/// - The password hash is calculated using argon2 with 32 character salt, the
/// plain password is
/// not saved
///
/// If `logout_devices` is true it does the following for each device except the sender device:
/// If `logout_devices` is true it does the following for each device except the
/// sender device:
/// - Invalidates access token
/// - Deletes device metadata (device ID, device display name, last seen IP, last seen timestamp)
/// - Deletes device metadata (device ID, device display name, last seen IP,
/// last seen timestamp)
/// - Forgets to-device events
/// - Triggers device list updates
pub(crate) async fn change_password_route(
body: Ruma<change_password::v3::Request>,
) -> Result<change_password::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow {
@ -327,27 +344,25 @@ pub(crate) async fn change_password_route(
};
if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user,
sender_device,
auth,
&uiaainfo,
)?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
services()
.users
.set_password(sender_user, Some(&body.new_password))?;
services().users.set_password(sender_user, Some(&body.new_password))?;
if body.logout_devices {
// Logout all devices except the current one
@ -362,11 +377,9 @@ pub(crate) async fn change_password_route(
}
info!("User {} changed their password.", sender_user);
services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} changed their password."
)));
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("User {sender_user} changed their password."),
));
Ok(change_password::v3::Response {})
}
@ -376,14 +389,17 @@ pub(crate) async fn change_password_route(
/// Get `user_id` of the sender user.
///
/// Note: Also works for Application Services
pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> {
pub(crate) async fn whoami_route(
body: Ruma<whoami::v3::Request>,
) -> Result<whoami::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device_id = body.sender_device.as_ref().cloned();
Ok(whoami::v3::Response {
user_id: sender_user.clone(),
device_id,
is_guest: services().users.is_deactivated(sender_user)? && body.appservice_info.is_none(),
is_guest: services().users.is_deactivated(sender_user)?
&& body.appservice_info.is_none(),
})
}
@ -393,7 +409,8 @@ pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoa
///
/// - Leaves all rooms and rejects all invitations
/// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
/// - Deletes all device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets all to-device events
/// - Triggers device list updates
/// - Removes ability to log in again
@ -401,7 +418,8 @@ pub(crate) async fn deactivate_route(
body: Ruma<deactivate::v3::Request>,
) -> Result<deactivate::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow {
@ -414,19 +432,19 @@ pub(crate) async fn deactivate_route(
};
if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user,
sender_device,
auth,
&uiaainfo,
)?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -439,11 +457,9 @@ pub(crate) async fn deactivate_route(
services().users.deactivate_account(sender_user)?;
info!("User {} deactivated their account.", sender_user);
services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} deactivated their account."
)));
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("User {sender_user} deactivated their account."),
));
Ok(deactivate::v3::Response {
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
@ -458,16 +474,19 @@ pub(crate) async fn deactivate_route(
pub(crate) async fn third_party_route(
body: Ruma<get_3pids::v3::Request>,
) -> Result<get_3pids::v3::Response> {
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
let _sender_user =
body.sender_user.as_ref().expect("user is authenticated");
Ok(get_3pids::v3::Response::new(Vec::new()))
}
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
///
/// "This API should be used to request validation tokens when adding an email address to an account"
/// "This API should be used to request validation tokens when adding an email
/// address to an account"
///
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
/// - 403 signals that The homeserver does not allow the third party identifier
/// as a contact option.
pub(crate) async fn request_3pid_management_token_via_email_route(
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
) -> Result<request_3pid_management_token_via_email::v3::Response> {
@ -479,9 +498,11 @@ pub(crate) async fn request_3pid_management_token_via_email_route(
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
///
/// "This API should be used to request validation tokens when adding an phone number to an account"
/// "This API should be used to request validation tokens when adding an phone
/// number to an account"
///
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
/// - 403 signals that The homeserver does not allow the third party identifier
/// as a contact option.
pub(crate) async fn request_3pid_management_token_via_msisdn_route(
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {

View file

@ -1,4 +1,3 @@
use crate::{services, Error, Result, Ruma};
use rand::seq::SliceRandom;
use ruma::{
api::{
@ -12,6 +11,8 @@ use ruma::{
OwnedRoomAliasId,
};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/directory/room/{roomAlias}`
///
/// Creates a new room alias on this server.
@ -32,30 +33,18 @@ pub(crate) async fn create_alias_route(
"Room alias is not in namespace.",
));
}
} else if services()
.appservice
.is_exclusive_alias(&body.room_alias)
.await
{
} else if services().appservice.is_exclusive_alias(&body.room_alias).await {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias reserved by appservice.",
));
}
if services()
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.is_some()
{
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
return Err(Error::Conflict("Alias already exists."));
}
services()
.rooms
.alias
.set_alias(&body.room_alias, &body.room_id)?;
services().rooms.alias.set_alias(&body.room_alias, &body.room_id)?;
Ok(create_alias::v3::Response::new())
}
@ -83,11 +72,7 @@ pub(crate) async fn delete_alias_route(
"Room alias is not in namespace.",
));
}
} else if services()
.appservice
.is_exclusive_alias(&body.room_alias)
.await
{
} else if services().appservice.is_exclusive_alias(&body.room_alias).await {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias reserved by appservice.",
@ -157,7 +142,10 @@ pub(crate) async fn get_alias_helper(
.alias
.resolve_local_alias(&room_alias)?
.ok_or_else(|| {
Error::bad_config("Appservice lied to us. Room does not exist.")
Error::bad_config(
"Appservice lied to us. Room does not \
exist.",
)
})?,
);
break;

View file

@ -1,15 +1,16 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{
backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
create_backup_version, delete_backup_keys, delete_backup_keys_for_room,
delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys,
get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info,
update_backup_version,
delete_backup_keys_for_session, delete_backup_version, get_backup_info,
get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
get_latest_backup_info, update_backup_version,
},
error::ErrorKind,
};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/room_keys/version`
///
/// Creates a new backup.
@ -17,23 +18,27 @@ pub(crate) async fn create_backup_version_route(
body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = services()
.key_backups
.create_backup(sender_user, &body.algorithm)?;
let version =
services().key_backups.create_backup(sender_user, &body.algorithm)?;
Ok(create_backup_version::v3::Response { version })
Ok(create_backup_version::v3::Response {
version,
})
}
/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
///
/// Update information about an existing backup. Only `auth_data` can be modified.
/// Update information about an existing backup. Only `auth_data` can be
/// modified.
pub(crate) async fn update_backup_version_route(
body: Ruma<update_backup_version::v3::Request>,
) -> Result<update_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?;
services().key_backups.update_backup(
sender_user,
&body.version,
&body.algorithm,
)?;
Ok(update_backup_version::v3::Response {})
}
@ -88,9 +93,7 @@ pub(crate) async fn get_backup_info_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
version: body.version.clone(),
})
}
@ -99,15 +102,14 @@ pub(crate) async fn get_backup_info_route(
///
/// Delete an existing key backup.
///
/// - Deletes both information about the backup, as well as all key data related to the backup
/// - Deletes both information about the backup, as well as all key data related
/// to the backup
pub(crate) async fn delete_backup_version_route(
body: Ruma<delete_backup_version::v3::Request>,
) -> Result<delete_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.key_backups
.delete_backup(sender_user, &body.version)?;
services().key_backups.delete_backup(sender_user, &body.version)?;
Ok(delete_backup_version::v3::Response {})
}
@ -116,7 +118,8 @@ pub(crate) async fn delete_backup_version_route(
///
/// Add the received backup keys to the database.
///
/// - Only manipulating the most recently created version of the backup is allowed
/// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_route(
@ -132,7 +135,8 @@ pub(crate) async fn add_backup_keys_route(
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
"You may only manipulate the most recently created version of the \
backup.",
));
}
@ -154,9 +158,7 @@ pub(crate) async fn add_backup_keys_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}
@ -164,7 +166,8 @@ pub(crate) async fn add_backup_keys_route(
///
/// Add the received backup keys to the database.
///
/// - Only manipulating the most recently created version of the backup is allowed
/// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_room_route(
@ -180,7 +183,8 @@ pub(crate) async fn add_backup_keys_for_room_route(
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
"You may only manipulate the most recently created version of the \
backup.",
));
}
@ -200,9 +204,7 @@ pub(crate) async fn add_backup_keys_for_room_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}
@ -210,7 +212,8 @@ pub(crate) async fn add_backup_keys_for_room_route(
///
/// Add the received backup key to the database.
///
/// - Only manipulating the most recently created version of the backup is allowed
/// - Only manipulating the most recently created version of the backup is
/// allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_session_route(
@ -226,7 +229,8 @@ pub(crate) async fn add_backup_keys_for_session_route(
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
"You may only manipulate the most recently created version of the \
backup.",
));
}
@ -244,9 +248,7 @@ pub(crate) async fn add_backup_keys_for_session_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}
@ -260,7 +262,9 @@ pub(crate) async fn get_backup_keys_route(
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
Ok(get_backup_keys::v3::Response { rooms })
Ok(get_backup_keys::v3::Response {
rooms,
})
}
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
@ -271,11 +275,15 @@ pub(crate) async fn get_backup_keys_for_room_route(
) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = services()
.key_backups
.get_room(sender_user, &body.version, &body.room_id)?;
let sessions = services().key_backups.get_room(
sender_user,
&body.version,
&body.room_id,
)?;
Ok(get_backup_keys_for_room::v3::Response { sessions })
Ok(get_backup_keys_for_room::v3::Response {
sessions,
})
}
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
@ -288,13 +296,20 @@ pub(crate) async fn get_backup_keys_for_session_route(
let key_data = services()
.key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
.get_session(
sender_user,
&body.version,
&body.room_id,
&body.session_id,
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Backup key not found for this user's session.",
))?;
Ok(get_backup_keys_for_session::v3::Response { key_data })
Ok(get_backup_keys_for_session::v3::Response {
key_data,
})
}
/// # `DELETE /_matrix/client/r0/room_keys/keys`
@ -305,9 +320,7 @@ pub(crate) async fn delete_backup_keys_route(
) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.key_backups
.delete_all_keys(sender_user, &body.version)?;
services().key_backups.delete_all_keys(sender_user, &body.version)?;
Ok(delete_backup_keys::v3::Response {
count: services()
@ -315,9 +328,7 @@ pub(crate) async fn delete_backup_keys_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}
@ -329,9 +340,11 @@ pub(crate) async fn delete_backup_keys_for_room_route(
) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
services().key_backups.delete_room_keys(
sender_user,
&body.version,
&body.room_id,
)?;
Ok(delete_backup_keys_for_room::v3::Response {
count: services()
@ -339,9 +352,7 @@ pub(crate) async fn delete_backup_keys_for_room_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}
@ -366,8 +377,6 @@ pub(crate) async fn delete_backup_keys_for_session_route(
.count_keys(sender_user, &body.version)?
.try_into()
.expect("count should fit in UInt"),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
etag: services().key_backups.get_etag(sender_user, &body.version)?,
})
}

View file

@ -1,12 +1,15 @@
use crate::{services, Result, Ruma};
use std::collections::BTreeMap;
use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
};
use std::collections::BTreeMap;
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/r0/capabilities`
///
/// Get information on the supported feature set and other relevent capabilities of this server.
/// Get information on the supported feature set and other relevent capabilities
/// of this server.
pub(crate) async fn get_capabilities_route(
_body: Ruma<get_capabilities::v3::Request>,
) -> Result<get_capabilities::v3::Response> {
@ -24,5 +27,7 @@ pub(crate) async fn get_capabilities_route(
available,
};
Ok(get_capabilities::v3::Response { capabilities })
Ok(get_capabilities::v3::Response {
capabilities,
})
}

View file

@ -1,18 +1,21 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::client::{
config::{
get_global_account_data, get_room_account_data, set_global_account_data,
set_room_account_data,
get_global_account_data, get_room_account_data,
set_global_account_data, set_room_account_data,
},
error::ErrorKind,
},
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
events::{
AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent,
},
serde::Raw,
};
use serde::Deserialize;
use serde_json::{json, value::RawValue as RawJsonValue};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
///
/// Sets some account data for the sender user.
@ -22,7 +25,9 @@ pub(crate) async fn set_global_account_data_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
.map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Data is invalid.")
})?;
let event_type = body.event_type.to_string();
@ -48,7 +53,9 @@ pub(crate) async fn set_room_account_data_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
.map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Data is invalid.")
})?;
let event_type = body.event_type.to_string();
@ -78,11 +85,16 @@ pub(crate) async fn get_global_account_data_route(
.get(None, sender_user, body.event_type.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content;
let account_data =
serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?
.content;
Ok(get_global_account_data::v3::Response { account_data })
Ok(get_global_account_data::v3::Response {
account_data,
})
}
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
@ -98,11 +110,16 @@ pub(crate) async fn get_room_account_data_route(
.get(Some(&body.room_id), sender_user, body.event_type.clone())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content;
let account_data =
serde_json::from_str::<ExtractRoomEventContent>(event.get())
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?
.content;
Ok(get_room_account_data::v3::Response { account_data })
Ok(get_room_account_data::v3::Response {
account_data,
})
}
#[derive(Deserialize)]

View file

@ -1,60 +1,57 @@
use crate::{services, Error, Result, Ruma};
use std::collections::HashSet;
use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
api::client::{
context::get_context, error::ErrorKind, filter::LazyLoadOptions,
},
events::StateEventType,
uint,
};
use std::collections::HashSet;
use tracing::error;
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
///
/// Allows loading room history around an event.
///
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was
/// - Only works if the user is joined (TODO: always allow, but only show events
/// if the user was
/// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_context_route(
body: Ruma<get_context::v3::Request>,
) -> Result<get_context::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options {
LazyLoadOptions::Enabled {
include_redundant_members,
} => (true, *include_redundant_members),
LazyLoadOptions::Disabled => (false, false),
};
let (lazy_load_enabled, lazy_load_send_redundant) =
match &body.filter.lazy_load_options {
LazyLoadOptions::Enabled {
include_redundant_members,
} => (true, *include_redundant_members),
LazyLoadOptions::Disabled => (false, false),
};
let mut lazy_loaded = HashSet::new();
let base_token = services()
.rooms
.timeline
.get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event id not found.",
))?;
let base_token =
services().rooms.timeline.get_pdu_count(&body.event_id)?.ok_or(
Error::BadRequest(ErrorKind::NotFound, "Base event id not found."),
)?;
let base_event =
services()
.rooms
.timeline
.get_pdu(&body.event_id)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event not found.",
))?;
let base_event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or(
Error::BadRequest(ErrorKind::NotFound, "Base event not found."),
)?;
let room_id = base_event.room_id.clone();
if !services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &room_id, &body.event_id)?
{
if !services().rooms.state_accessor.user_can_see_event(
sender_user,
&room_id,
&body.event_id,
)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You don't have permission to view this event.",
@ -72,8 +69,8 @@ pub(crate) async fn get_context_route(
}
// Use limit with maximum 100
let half_limit =
usize::try_from(body.limit.min(uint!(100)) / uint!(2)).expect("0-50 should fit in usize");
let half_limit = usize::try_from(body.limit.min(uint!(100)) / uint!(2))
.expect("0-50 should fit in usize");
let base_event = base_event.to_room_event();
@ -108,10 +105,8 @@ pub(crate) async fn get_context_route(
.last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_before: Vec<_> = events_before
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_before: Vec<_> =
events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
let events_after: Vec<_> = services()
.rooms
@ -140,41 +135,33 @@ pub(crate) async fn get_context_route(
}
}
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
events_after
.last()
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
)? {
Some(s) => s,
None => services()
.rooms
.state
.get_room_shortstatehash(&room_id)?
.expect("All rooms have state"),
};
let shortstatehash =
match services().rooms.state_accessor.pdu_shortstatehash(
events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id),
)? {
Some(s) => s,
None => services()
.rooms
.state
.get_room_shortstatehash(&room_id)?
.expect("All rooms have state"),
};
let state_ids = services()
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?;
let state_ids =
services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
let end_token = events_after
.last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_after: Vec<_> = events_after
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_after: Vec<_> =
events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
let mut state = Vec::new();
for (shortstatekey, id) in state_ids {
let (event_type, state_key) = services()
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
let (event_type, state_key) =
services().rooms.short.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {

View file

@ -1,11 +1,14 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
device::{
self, delete_device, delete_devices, get_device, get_devices,
update_device,
},
error::ErrorKind,
uiaa::{AuthFlow, AuthType, UiaaInfo},
};
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/devices`
///
@ -21,7 +24,9 @@ pub(crate) async fn get_devices_route(
.filter_map(Result::ok)
.collect();
Ok(get_devices::v3::Response { devices })
Ok(get_devices::v3::Response {
devices,
})
}
/// # `GET /_matrix/client/r0/devices/{deviceId}`
@ -37,7 +42,9 @@ pub(crate) async fn get_device_route(
.get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
Ok(get_device::v3::Response { device })
Ok(get_device::v3::Response {
device,
})
}
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
@ -55,9 +62,11 @@ pub(crate) async fn update_device_route(
device.display_name = body.display_name.clone();
services()
.users
.update_device_metadata(sender_user, &body.device_id, &device)?;
services().users.update_device_metadata(
sender_user,
&body.device_id,
&device,
)?;
Ok(update_device::v3::Response {})
}
@ -68,14 +77,16 @@ pub(crate) async fn update_device_route(
///
/// - Requires UIAA to verify user password
/// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
/// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events
/// - Triggers device list updates
pub(crate) async fn delete_device_route(
body: Ruma<delete_device::v3::Request>,
) -> Result<delete_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA
let mut uiaainfo = UiaaInfo {
@ -89,27 +100,25 @@ pub(crate) async fn delete_device_route(
};
if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user,
sender_device,
auth,
&uiaainfo,
)?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
services()
.users
.remove_device(sender_user, &body.device_id)?;
services().users.remove_device(sender_user, &body.device_id)?;
Ok(delete_device::v3::Response {})
}
@ -122,14 +131,16 @@ pub(crate) async fn delete_device_route(
///
/// For each device:
/// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
/// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events
/// - Triggers device list updates
pub(crate) async fn delete_devices_route(
body: Ruma<delete_devices::v3::Request>,
) -> Result<delete_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA
let mut uiaainfo = UiaaInfo {
@ -143,19 +154,19 @@ pub(crate) async fn delete_devices_route(
};
if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user,
sender_device,
auth,
&uiaainfo,
)?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));

View file

@ -1,10 +1,9 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::{
client::{
directory::{
get_public_rooms, get_public_rooms_filtered, get_room_visibility,
set_room_visibility,
get_public_rooms, get_public_rooms_filtered,
get_room_visibility, set_room_visibility,
},
error::ErrorKind,
room,
@ -18,7 +17,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent},
topic::RoomTopicEventContent,
},
@ -28,6 +29,8 @@ use ruma::{
};
use tracing::{error, info, warn};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/publicRooms`
///
/// Lists the public rooms on this server.
@ -91,7 +94,9 @@ pub(crate) async fn set_room_visibility_route(
services().rooms.directory.set_public(&body.room_id)?;
info!("{} made {} public", sender_user, body.room_id);
}
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
room::Visibility::Private => {
services().rooms.directory.set_not_public(&body.room_id)?;
}
_ => {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
@ -115,7 +120,11 @@ pub(crate) async fn get_room_visibility_route(
}
Ok(get_room_visibility::v3::Response {
visibility: if services().rooms.directory.is_public_room(&body.room_id)? {
visibility: if services()
.rooms
.directory
.is_public_room(&body.room_id)?
{
room::Visibility::Public
} else {
room::Visibility::Private
@ -131,8 +140,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
filter: &Filter,
_network: &RoomNetwork,
) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(other_server) =
server.filter(|server| *server != services().globals.server_name().as_str())
if let Some(other_server) = server
.filter(|server| *server != services().globals.server_name().as_str())
{
let response = services()
.sending
@ -174,10 +183,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}
};
num_since = characters
.collect::<String>()
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?;
num_since = characters.collect::<String>().parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token.")
})?;
if backwards {
num_since = num_since.saturating_sub(limit);
@ -195,12 +203,19 @@ pub(crate) async fn get_public_rooms_filtered_helper(
canonical_alias: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
.room_state_get(
&room_id,
&StateEventType::RoomCanonicalAlias,
"",
)?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| {
Error::bad_database("Invalid canonical alias event in database.")
Error::bad_database(
"Invalid canonical alias event in \
database.",
)
})
})?,
name: services().rooms.state_accessor.get_name(&room_id)?,
@ -222,36 +237,55 @@ pub(crate) async fn get_public_rooms_filtered_helper(
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|_| {
error!("Invalid room topic event in database for room {}", room_id);
Error::bad_database("Invalid room topic event in database.")
error!(
"Invalid room topic event in database for \
room {}",
room_id
);
Error::bad_database(
"Invalid room topic event in database.",
)
})
})?,
world_readable: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
.room_state_get(
&room_id,
&StateEventType::RoomHistoryVisibility,
"",
)?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
c.history_visibility
== HistoryVisibility::WorldReadable
})
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in database.",
"Invalid room history visibility event in \
database.",
)
})
})?,
guest_can_join: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
.room_state_get(
&room_id,
&StateEventType::RoomGuestAccess,
"",
)?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| {
c.guest_access == GuestAccess::CanJoin
})
.map_err(|_| {
Error::bad_database("Invalid room guest access event in database.")
Error::bad_database(
"Invalid room guest access event in \
database.",
)
})
})?,
avatar_url: services()
@ -262,7 +296,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
serde_json::from_str(s.content.get())
.map(|c: RoomAvatarEventContent| c.url)
.map_err(|_| {
Error::bad_database("Invalid room avatar event in database.")
Error::bad_database(
"Invalid room avatar event in database.",
)
})
})
.transpose()?
@ -270,33 +306,59 @@ pub(crate) async fn get_public_rooms_filtered_helper(
join_rule: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.room_state_get(
&room_id,
&StateEventType::RoomJoinRules,
"",
)?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => Some(PublicRoomJoinRule::Public),
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
_ => None,
.map(|c: RoomJoinRulesEventContent| {
match c.join_rule {
JoinRule::Public => {
Some(PublicRoomJoinRule::Public)
}
JoinRule::Knock => {
Some(PublicRoomJoinRule::Knock)
}
_ => None,
}
})
.map_err(|e| {
error!("Invalid room join rule event in database: {}", e);
Error::BadDatabase("Invalid room join rule event in database.")
error!(
"Invalid room join rule event in \
database: {}",
e
);
Error::BadDatabase(
"Invalid room join rule event in database.",
)
})
})
.transpose()?
.flatten()
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
.ok_or_else(|| {
Error::bad_database(
"Missing room join rule event for room.",
)
})?,
room_type: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.map(|s| {
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(
|e| {
error!("Invalid room create event in database: {}", e);
Error::BadDatabase("Invalid room create event in database.")
},
serde_json::from_str::<RoomCreateEventContent>(
s.content.get(),
)
.map_err(|e| {
error!(
"Invalid room create event in database: {}",
e
);
Error::BadDatabase(
"Invalid room create event in database.",
)
})
})
.transpose()?
.and_then(|e| e.room_type),
@ -306,10 +368,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
})
.filter_map(Result::<_>::ok)
.filter(|chunk| {
if let Some(query) = filter
.generic_search_term
.as_ref()
.map(|q| q.to_lowercase())
if let Some(query) =
filter.generic_search_term.as_ref().map(|q| q.to_lowercase())
{
if let Some(name) = &chunk.name {
if name.as_str().to_lowercase().contains(&query) {
@ -324,7 +384,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}
if let Some(canonical_alias) = &chunk.canonical_alias {
if canonical_alias.as_str().to_lowercase().contains(&query) {
if canonical_alias.as_str().to_lowercase().contains(&query)
{
return true;
}
}
@ -339,7 +400,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
let total_room_count_estimate = all_rooms.len().try_into().unwrap_or(UInt::MAX);
let total_room_count_estimate =
all_rooms.len().try_into().unwrap_or(UInt::MAX);
let chunk: Vec<_> = all_rooms
.into_iter()
@ -353,11 +415,12 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Some(format!("p{num_since}"))
};
let next_batch = if chunk.len() < limit.try_into().expect("UInt should fit in usize") {
None
} else {
Some(format!("n{}", num_since + limit))
};
let next_batch =
if chunk.len() < limit.try_into().expect("UInt should fit in usize") {
None
} else {
Some(format!("n{}", num_since + limit))
};
Ok(get_public_rooms_filtered::v3::Response {
chunk,

View file

@ -1,9 +1,10 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{
error::ErrorKind,
filter::{create_filter, get_filter},
};
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
///
/// Loads a filter that was previously created.
@ -13,8 +14,13 @@ pub(crate) async fn get_filter_route(
body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let Some(filter) = services().users.get_filter(sender_user, &body.filter_id)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found."));
let Some(filter) =
services().users.get_filter(sender_user, &body.filter_id)?
else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Filter not found.",
));
};
Ok(get_filter::v3::Response::new(filter))

View file

@ -1,13 +1,16 @@
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::{
client::{
error::ErrorKind,
keys::{
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures,
upload_signing_keys,
claim_keys, get_key_changes, get_keys, upload_keys,
upload_signatures, upload_signing_keys,
},
uiaa::{AuthFlow, AuthType, UiaaInfo},
},
@ -17,28 +20,32 @@ use ruma::{
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
};
use serde_json::json;
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use tracing::debug;
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/keys/upload`
///
/// Publish end-to-end encryption keys for the sender device.
///
/// - Adds one time keys
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
/// - If there are no device keys yet: Adds device keys (TODO: merge with
/// existing keys?)
pub(crate) async fn upload_keys_route(
body: Ruma<upload_keys::v3::Request>,
) -> Result<upload_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys {
services()
.users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
services().users.add_one_time_key(
sender_user,
sender_device,
key_key,
key_value,
)?;
}
if let Some(device_keys) = &body.device_keys {
@ -49,9 +56,11 @@ pub(crate) async fn upload_keys_route(
.get_device_keys(sender_user, sender_device)?
.is_none()
{
services()
.users
.add_device_keys(sender_user, sender_device, device_keys)?;
services().users.add_device_keys(
sender_user,
sender_device,
device_keys,
)?;
}
}
@ -68,14 +77,17 @@ pub(crate) async fn upload_keys_route(
///
/// - Always fetches users from other servers over federation
/// - Gets master keys, self-signing keys, user signing keys and device keys.
/// - The master and self-signing keys contain signatures that the user is allowed to see
/// - The master and self-signing keys contain signatures that the user is
/// allowed to see
pub(crate) async fn get_keys_route(
body: Ruma<get_keys::v3::Request>,
) -> Result<get_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let response =
get_keys_helper(Some(sender_user), &body.device_keys, |u| u == sender_user).await?;
let response = get_keys_helper(Some(sender_user), &body.device_keys, |u| {
u == sender_user
})
.await?;
Ok(response)
}
@ -100,7 +112,8 @@ pub(crate) async fn upload_signing_keys_route(
body: Ruma<upload_signing_keys::v3::Request>,
) -> Result<upload_signing_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
// UIAA
let mut uiaainfo = UiaaInfo {
@ -114,19 +127,19 @@ pub(crate) async fn upload_signing_keys_route(
};
if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user,
sender_device,
auth,
&uiaainfo,
)?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -156,8 +169,9 @@ pub(crate) async fn upload_signatures_route(
for (user_id, keys) in &body.signed_keys {
for (key_id, key) in keys {
let key = serde_json::to_value(key)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?;
let key = serde_json::to_value(key).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON")
})?;
for signature in key
.get("signatures")
@ -189,9 +203,12 @@ pub(crate) async fn upload_signatures_route(
))?
.to_owned(),
);
services()
.users
.sign_key(user_id, key_id, signature, sender_user)?;
services().users.sign_key(
user_id,
key_id,
signature,
sender_user,
)?;
}
}
}
@ -204,7 +221,8 @@ pub(crate) async fn upload_signatures_route(
/// # `POST /_matrix/client/r0/keys/changes`
///
/// Gets a list of users who have updated their device identity keys since the previous sync token.
/// Gets a list of users who have updated their device identity keys since the
/// previous sync token.
///
/// - TODO: left users
pub(crate) async fn get_key_changes_route(
@ -219,14 +237,15 @@ pub(crate) async fn get_key_changes_route(
.users
.keys_changed(
sender_user.as_str(),
body.from
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
body.from.parse().map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid `from`.",
)
})?,
Some(body.to.parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
})?),
)
.filter_map(Result::ok),
);
@ -243,10 +262,16 @@ pub(crate) async fn get_key_changes_route(
.keys_changed(
room_id.as_ref(),
body.from.parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.")
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid `from`.",
)
})?,
Some(body.to.parse().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid `to`.",
)
})?),
)
.filter_map(Result::ok),
@ -287,16 +312,24 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut container = BTreeMap::new();
for device_id in services().users.all_device_ids(user_id) {
let device_id = device_id?;
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
if let Some(mut keys) =
services().users.get_device_keys(user_id, &device_id)?
{
let metadata = services()
.users
.get_device_metadata(user_id, &device_id)?
.ok_or_else(|| {
Error::bad_database("all_device_keys contained nonexistent device.")
Error::bad_database(
"all_device_keys contained nonexistent device.",
)
})?;
add_unsigned_device_display_name(&mut keys, metadata)
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
.map_err(|_| {
Error::bad_database(
"invalid device keys in database",
)
})?;
container.insert(device_id, keys);
}
}
@ -304,7 +337,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} else {
for device_id in device_ids {
let mut container = BTreeMap::new();
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
if let Some(mut keys) =
services().users.get_device_keys(user_id, device_id)?
{
let metadata = services()
.users
.get_device_metadata(user_id, device_id)?
@ -314,29 +349,35 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
))?;
add_unsigned_device_display_name(&mut keys, metadata)
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
.map_err(|_| {
Error::bad_database(
"invalid device keys in database",
)
})?;
container.insert(device_id.to_owned(), keys);
}
device_keys.insert(user_id.to_owned(), container);
}
}
if let Some(master_key) =
services()
.users
.get_master_key(sender_user, user_id, &allowed_signatures)?
{
if let Some(master_key) = services().users.get_master_key(
sender_user,
user_id,
&allowed_signatures,
)? {
master_keys.insert(user_id.to_owned(), master_key);
}
if let Some(self_signing_key) =
services()
.users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
{
if let Some(self_signing_key) = services().users.get_self_signing_key(
sender_user,
user_id,
&allowed_signatures,
)? {
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
}
if Some(user_id) == sender_user {
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
if let Some(user_signing_key) =
services().users.get_user_signing_key(user_id)?
{
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
}
}
@ -345,17 +386,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new();
let back_off = |id| async {
match services()
.globals
.bad_query_ratelimiter
.write()
.await
.entry(id)
{
match services().globals.bad_query_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1 + 1);
}
}
};
@ -370,7 +407,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
.get(server)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
let mut min_elapsed_duration =
Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
@ -379,7 +417,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
debug!("Backing off query from {:?}", server);
return (
server,
Err(Error::BadServerResponse("bad query, still backing off")),
Err(Error::BadServerResponse(
"bad query, still backing off",
)),
);
}
}
@ -417,15 +457,19 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
&user,
&allowed_signatures,
)? {
let (_, our_master_key) =
services().users.parse_master_key(&user, &our_master_key)?;
let (_, our_master_key) = services()
.users
.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures);
}
let json = serde_json::to_value(master_key).expect("to_value always works");
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
let json = serde_json::to_value(master_key)
.expect("to_value always works");
let raw = serde_json::from_value(json)
.expect("Raw::from_value always works");
services().users.add_cross_signing_keys(
&user, &raw, &None, &None,
// Dont notify. A notification would trigger another key request resulting in an endless loop
// Dont notify. A notification would trigger another key
// request resulting in an endless loop
false,
)?;
master_keys.insert(user, raw);
@ -454,11 +498,13 @@ fn add_unsigned_device_display_name(
metadata: ruma::api::client::device::Device,
) -> serde_json::Result<()> {
if let Some(display_name) = metadata.display_name {
let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
let mut object = keys
.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
let unsigned = object.entry("unsigned").or_insert_with(|| json!({}));
if let serde_json::Value::Object(unsigned_object) = unsigned {
unsigned_object.insert("device_display_name".to_owned(), display_name.into());
unsigned_object
.insert("device_display_name".to_owned(), display_name.into());
}
*keys = Raw::from_json(serde_json::value::to_raw_value(&object)?);
@ -468,7 +514,10 @@ fn add_unsigned_device_display_name(
}
pub(crate) async fn claim_keys_helper(
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>,
one_time_keys_input: &BTreeMap<
OwnedUserId,
BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>,
>,
) -> Result<claim_keys::v3::Response> {
let mut one_time_keys = BTreeMap::new();
@ -484,11 +533,11 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map {
if let Some(one_time_keys) =
services()
.users
.take_one_time_key(user_id, device_id, key_algorithm)?
{
if let Some(one_time_keys) = services().users.take_one_time_key(
user_id,
device_id,
key_algorithm,
)? {
let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1);
container.insert(device_id.clone(), c);

View file

@ -1,14 +1,15 @@
use std::time::Duration;
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
use ruma::api::client::{
error::ErrorKind,
media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail,
get_media_config,
create_content, get_content, get_content_as_filename,
get_content_thumbnail, get_media_config,
},
};
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
const MXC_LENGTH: usize = 32;
/// # `GET /_matrix/media/r0/config`
@ -110,9 +111,12 @@ pub(crate) async fn get_content_route(
content_disposition,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
} else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?;
get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(remote_content_response)
} else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
@ -130,21 +134,32 @@ pub(crate) async fn get_content_as_filename_route(
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta {
content_type, file, ..
content_type,
file,
..
}) = services().media.get(mxc.clone()).await?
{
Ok(get_content_as_filename::v3::Response {
file,
content_type,
content_disposition: Some(format!("inline; filename={}", body.filename)),
content_disposition: Some(format!(
"inline; filename={}",
body.filename
)),
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
} else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?;
get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(get_content_as_filename::v3::Response {
content_disposition: Some(format!("inline: filename={}", body.filename)),
content_disposition: Some(format!(
"inline: filename={}",
body.filename
)),
content_type: remote_content_response.content_type,
file: remote_content_response.file,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
@ -165,17 +180,19 @@ pub(crate) async fn get_content_thumbnail_route(
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta {
content_type, file, ..
content_type,
file,
..
}) = services()
.media
.get_thumbnail(
mxc.clone(),
body.width
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
body.height
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
body.width.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?,
body.height.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?,
)
.await?
{
@ -184,7 +201,9 @@ pub(crate) async fn get_content_thumbnail_route(
content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
} else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let get_thumbnail_response = services()
.sending
.send_federation_request(

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,8 @@
use crate::{
service::{pdu::PduBuilder, rooms::timeline::PduCount},
services, utils, Error, Result, Ruma,
use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
};
use ruma::{
api::client::{
error::ErrorKind,
@ -10,18 +11,21 @@ use ruma::{
events::{StateEventType, TimelineEventType},
uint,
};
use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
use crate::{
service::{pdu::PduBuilder, rooms::timeline::PduCount},
services, utils, Error, Result, Ruma,
};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}`
///
/// Send a message event into the room.
///
/// - Is a NOOP if the txn id was already used before and returns the same event id again
/// - Is a NOOP if the txn id was already used before and returns the same event
/// id again
/// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed
/// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
pub(crate) async fn send_message_event_route(
body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> {
@ -50,29 +54,37 @@ pub(crate) async fn send_message_event_route(
}
// Check if this is a new transaction id
if let Some(response) =
services()
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)?
{
if let Some(response) = services().transaction_ids.existing_txnid(
sender_user,
sender_device,
&body.txn_id,
)? {
// The client might have sent a txnid of the /sendToDevice endpoint
// This txnid has no response associated with it
if response.is_empty() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to use txn id already used for an incompatible endpoint.",
"Tried to use txn id already used for an incompatible \
endpoint.",
));
}
let event_id = utils::string_from_bytes(&response)
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
.map_err(|_| {
Error::bad_database("Invalid txnid bytes in database.")
})?
.try_into()
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
return Ok(send_message_event::v3::Response { event_id });
.map_err(|_| {
Error::bad_database("Invalid event id in txnid data.")
})?;
return Ok(send_message_event::v3::Response {
event_id,
});
}
let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
unsigned
.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let event_id = services()
.rooms
@ -81,7 +93,12 @@ pub(crate) async fn send_message_event_route(
PduBuilder {
event_type: body.event_type.to_string().into(),
content: serde_json::from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
.map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid JSON body.",
)
})?,
unsigned: Some(unsigned),
state_key: None,
redacts: None,
@ -101,23 +118,23 @@ pub(crate) async fn send_message_event_route(
drop(state_lock);
Ok(send_message_event::v3::Response::new(
(*event_id).to_owned(),
))
Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
}
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
///
/// Allows paginating through room history.
///
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was
/// - Only works if the user is joined (TODO: always allow, but only show events
/// where the user was
/// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_message_events_route(
body: Ruma<get_message_events::v3::Request>,
) -> Result<get_message_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
@ -127,15 +144,17 @@ pub(crate) async fn get_message_events_route(
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
services()
.rooms
.lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
.lazy_load_confirm_delivery(
sender_user,
sender_device,
&body.room_id,
from,
)
.await?;
let limit = body
@ -162,7 +181,11 @@ pub(crate) async fn get_message_events_route(
services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
.user_can_see_event(
sender_user,
&body.room_id,
&pdu.event_id,
)
.unwrap_or(false)
})
.take_while(|&(k, _)| Some(k) != to)
@ -214,7 +237,11 @@ pub(crate) async fn get_message_events_route(
services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
.user_can_see_event(
sender_user,
&body.room_id,
&pdu.event_id,
)
.unwrap_or(false)
})
.take_while(|&(k, _)| Some(k) != to)
@ -254,11 +281,13 @@ pub(crate) async fn get_message_events_route(
resp.state = Vec::new();
for ll_id in &lazy_loaded {
if let Some(member_event) = services().rooms.state_accessor.room_state_get(
&body.room_id,
&StateEventType::RoomMember,
ll_id.as_str(),
)? {
if let Some(member_event) =
services().rooms.state_accessor.room_state_get(
&body.room_id,
&StateEventType::RoomMember,
ll_id.as_str(),
)?
{
resp.state.push(member_event.to_state_event());
}
}

View file

@ -1,20 +1,25 @@
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma};
use std::sync::Arc;
use ruma::{
api::{
client::{
error::ErrorKind,
profile::{
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name,
get_avatar_url, get_display_name, get_profile, set_avatar_url,
set_display_name,
},
},
federation::{self, query::get_profile_information::v1::ProfileField},
},
events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
events::{
room::member::RoomMemberEventContent, StateEventType, TimelineEventType,
},
};
use serde_json::value::to_raw_value;
use std::sync::Arc;
use tracing::warn;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/profile/{userId}/displayname`
///
/// Updates the displayname.
@ -25,9 +30,7 @@ pub(crate) async fn set_displayname_route(
) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.users
.set_displayname(sender_user, body.displayname.clone())?;
services().users.set_displayname(sender_user, body.displayname.clone())?;
// Send a new membership event and presence update into all joined rooms
let all_rooms_joined: Vec<_> = services()
@ -53,14 +56,18 @@ pub(crate) async fn set_displayname_route(
)?
.ok_or_else(|| {
Error::bad_database(
"Tried to send displayname update for user not in the \
room.",
"Tried to send displayname update for \
user not in the room.",
)
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
.map_err(|_| {
Error::bad_database(
"Database contains invalid PDU.",
)
})?
})
.expect("event is valid, we just created it"),
unsigned: None,
@ -88,7 +95,12 @@ pub(crate) async fn set_displayname_route(
if let Err(error) = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await
{
warn!(%error, "failed to add PDU");
@ -138,13 +150,9 @@ pub(crate) async fn set_avatar_url_route(
) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.users
.set_avatar_url(sender_user, body.avatar_url.clone())?;
services().users.set_avatar_url(sender_user, body.avatar_url.clone())?;
services()
.users
.set_blurhash(sender_user, body.blurhash.clone())?;
services().users.set_blurhash(sender_user, body.blurhash.clone())?;
// Send a new membership event and presence update into all joined rooms
let all_joined_rooms: Vec<_> = services()
@ -170,14 +178,18 @@ pub(crate) async fn set_avatar_url_route(
)?
.ok_or_else(|| {
Error::bad_database(
"Tried to send displayname update for user not in the \
room.",
"Tried to send displayname update for \
user not in the room.",
)
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
.map_err(|_| {
Error::bad_database(
"Database contains invalid PDU.",
)
})?
})
.expect("event is valid, we just created it"),
unsigned: None,
@ -205,7 +217,12 @@ pub(crate) async fn set_avatar_url_route(
if let Err(error) = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await
{
warn!(%error, "failed to add PDU");
@ -219,7 +236,8 @@ pub(crate) async fn set_avatar_url_route(
///
/// Returns the `avatar_url` and `blurhash` of the user.
///
/// - If user is on another server: Fetches `avatar_url` and `blurhash` over federation
/// - If user is on another server: Fetches `avatar_url` and `blurhash` over
/// federation
pub(crate) async fn get_avatar_url_route(
body: Ruma<get_avatar_url::v3::Request>,
) -> Result<get_avatar_url::v3::Response> {

View file

@ -1,17 +1,18 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::client::{
error::ErrorKind,
push::{
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled,
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions,
set_pushrule_enabled, RuleScope,
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions,
get_pushrule_enabled, get_pushrules_all, set_pusher, set_pushrule,
set_pushrule_actions, set_pushrule_enabled, RuleScope,
},
},
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
push::{AnyPushRuleRef, InsertPushRuleError, RemovePushRuleError},
};
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/pushrules`
///
/// Retrieves the push rules event for this user.
@ -71,12 +72,11 @@ pub(crate) async fn get_pushrule_route(
.map(Into::into);
if let Some(rule) = rule {
Ok(get_pushrule::v3::Response { rule })
Ok(get_pushrule::v3::Response {
rule,
})
} else {
Err(Error::BadRequest(
ErrorKind::NotFound,
"Push rule not found.",
))
Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
}
}
@ -109,7 +109,9 @@ pub(crate) async fn set_pushrule_route(
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if let Err(error) = account_data.content.global.insert(
body.rule.clone(),
@ -119,16 +121,20 @@ pub(crate) async fn set_pushrule_route(
let err = match error {
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
ErrorKind::InvalidParam,
"Rule IDs starting with a dot are reserved for server-default rules.",
"Rule IDs starting with a dot are reserved for server-default \
rules.",
),
InsertPushRuleError::InvalidRuleId => Error::BadRequest(
ErrorKind::InvalidParam,
"Rule ID containing invalid characters.",
),
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
ErrorKind::InvalidParam,
"Can't place a push rule relatively to a server-default rule.",
),
InsertPushRuleError::RelativeToServerDefaultRule => {
Error::BadRequest(
ErrorKind::InvalidParam,
"Can't place a push rule relatively to a server-default \
rule.",
)
}
InsertPushRuleError::UnknownRuleId => Error::BadRequest(
ErrorKind::NotFound,
"The before or after rule could not be found.",
@ -147,7 +153,8 @@ pub(crate) async fn set_pushrule_route(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
Ok(set_pushrule::v3::Response {})
@ -193,7 +200,9 @@ pub(crate) async fn get_pushrule_actions_route(
"Push rule not found.",
))?;
Ok(get_pushrule_actions::v3::Response { actions })
Ok(get_pushrule_actions::v3::Response {
actions,
})
}
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
@ -224,7 +233,9 @@ pub(crate) async fn set_pushrule_actions_route(
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if account_data
.content
@ -242,7 +253,8 @@ pub(crate) async fn set_pushrule_actions_route(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
Ok(set_pushrule_actions::v3::Response {})
@ -276,7 +288,9 @@ pub(crate) async fn get_pushrule_enabled_route(
))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
let global = account_data.content.global;
let enabled = global
@ -287,7 +301,9 @@ pub(crate) async fn get_pushrule_enabled_route(
"Push rule not found.",
))?;
Ok(get_pushrule_enabled::v3::Response { enabled })
Ok(get_pushrule_enabled::v3::Response {
enabled,
})
}
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
@ -318,7 +334,9 @@ pub(crate) async fn set_pushrule_enabled_route(
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if account_data
.content
@ -336,7 +354,8 @@ pub(crate) async fn set_pushrule_enabled_route(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
Ok(set_pushrule_enabled::v3::Response {})
@ -370,12 +389,12 @@ pub(crate) async fn delete_pushrule_route(
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
.map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
if let Err(error) = account_data
.content
.global
.remove(body.kind.clone(), &body.rule_id)
if let Err(error) =
account_data.content.global.remove(body.kind.clone(), &body.rule_id)
{
let err = match error {
RemovePushRuleError::ServerDefault => Error::BadRequest(
@ -395,7 +414,8 @@ pub(crate) async fn delete_pushrule_route(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
Ok(delete_pushrule::v3::Response {})
@ -424,9 +444,7 @@ pub(crate) async fn set_pushers_route(
) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services()
.pusher
.set_pusher(sender_user, body.action.clone())?;
services().pusher.set_pusher(sender_user, body.action.clone())?;
Ok(set_pusher::v3::Response::default())
}

View file

@ -1,20 +1,27 @@
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
use std::collections::BTreeMap;
use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
api::client::{
error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt,
},
events::{
receipt::{ReceiptThread, ReceiptType},
RoomAccountDataEventType,
},
MilliSecondsSinceUnixEpoch,
};
use std::collections::BTreeMap;
use crate::{
service::rooms::timeline::PduCount, services, Error, Result, Ruma,
};
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
///
/// Sets different types of read markers.
///
/// - Updates fully-read account data event to `fully_read`
/// - If `read_receipt` is set: Update private marker and public read receipt EDU
/// - If `read_receipt` is set: Update private marker and public read receipt
/// EDU
pub(crate) async fn set_read_marker_route(
body: Ruma<set_read_marker::v3::Request>,
) -> Result<set_read_marker::v3::Response> {
@ -30,7 +37,8 @@ pub(crate) async fn set_read_marker_route(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
&serde_json::to_value(fully_read_event)
.expect("to json value always works"),
)?;
}
@ -42,14 +50,9 @@ pub(crate) async fn set_read_marker_route(
}
if let Some(event) = &body.private_read_receipt {
let count = services()
.rooms
.timeline
.get_pdu_count(event)?
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Event does not exist.",
))?;
let count = services().rooms.timeline.get_pdu_count(event)?.ok_or(
Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."),
)?;
let count = match count {
PduCount::Backfilled(_) => {
return Err(Error::BadRequest(
@ -59,11 +62,11 @@ pub(crate) async fn set_read_marker_route(
}
PduCount::Normal(c) => c,
};
services()
.rooms
.edus
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
services().rooms.edus.read_receipt.private_read_set(
&body.room_id,
sender_user,
count,
)?;
}
if let Some(event) = &body.read_receipt {
@ -86,7 +89,9 @@ pub(crate) async fn set_read_marker_route(
sender_user,
&body.room_id,
ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
content: ruma::events::receipt::ReceiptEventContent(
receipt_content,
),
room_id: body.room_id.clone(),
},
)?;
@ -105,7 +110,8 @@ pub(crate) async fn create_receipt_route(
if matches!(
&body.receipt_type,
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
create_receipt::v3::ReceiptType::Read
| create_receipt::v3::ReceiptType::ReadPrivate
) {
services()
.rooms
@ -124,7 +130,8 @@ pub(crate) async fn create_receipt_route(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
&serde_json::to_value(fully_read_event)
.expect("to json value always works"),
)?;
}
create_receipt::v3::ReceiptType::Read => {
@ -146,7 +153,9 @@ pub(crate) async fn create_receipt_route(
sender_user,
&body.room_id,
ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
content: ruma::events::receipt::ReceiptEventContent(
receipt_content,
),
room_id: body.room_id.clone(),
},
)?;

View file

@ -1,13 +1,13 @@
use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
use ruma::{
api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
};
use serde_json::value::to_raw_value;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
///
/// Tries to send a redaction event into the room.
@ -54,5 +54,7 @@ pub(crate) async fn redact_event_route(
drop(state_lock);
let event_id = (*event_id).to_owned();
Ok(redact_event::v3::Response { event_id })
Ok(redact_event::v3::Response {
event_id,
})
}

View file

@ -23,10 +23,7 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
@ -36,27 +33,22 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
.try_into()
.expect("0-100 should fit in usize");
let res = services()
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
Some(&body.event_type),
Some(&body.rel_type),
from,
to,
limit,
)?;
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
Some(&body.event_type),
Some(&body.rel_type),
from,
to,
limit,
)?;
Ok(
get_relating_events_with_rel_type_and_event_type::v1::Response {
chunk: res.chunk,
next_batch: res.next_batch,
prev_batch: res.prev_batch,
},
)
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
chunk: res.chunk,
next_batch: res.next_batch,
prev_batch: res.prev_batch,
})
}
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
@ -74,10 +66,7 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
@ -87,19 +76,16 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
.try_into()
.expect("0-100 should fit in usize");
let res = services()
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
Some(&body.rel_type),
from,
to,
limit,
)?;
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
Some(&body.rel_type),
from,
to,
limit,
)?;
Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk,
@ -123,10 +109,7 @@ pub(crate) async fn get_relating_events_route(
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
@ -136,17 +119,14 @@ pub(crate) async fn get_relating_events_route(
.try_into()
.expect("0-100 should fit in usize");
services()
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
None,
from,
to,
limit,
)
services().rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
None,
from,
to,
limit,
)
}

View file

@ -1,14 +1,14 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::client::{error::ErrorKind, room::report_content},
events::room::message,
int,
};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}`
///
/// Reports an inappropriate event to homeserver admins
///
pub(crate) async fn report_event_route(
body: Ruma<report_content::v3::Request>,
) -> Result<report_content::v3::Response> {

View file

@ -1,6 +1,5 @@
use crate::{
api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma,
};
use std::{cmp::max, collections::BTreeMap, sync::Arc};
use ruma::{
api::client::{
error::ErrorKind,
@ -11,7 +10,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
name::RoomNameEventContent,
@ -26,9 +27,13 @@ use ruma::{
CanonicalJsonObject, OwnedRoomAliasId, RoomAliasId, RoomId, RoomVersionId,
};
use serde_json::{json, value::to_raw_value};
use std::{cmp::max, collections::BTreeMap, sync::Arc};
use tracing::{info, warn};
use crate::{
api::client_server::invite_helper, service::pdu::PduBuilder, services,
Error, Result, Ruma,
};
/// # `POST /_matrix/client/r0/createRoom`
///
/// Creates a new room.
@ -79,32 +84,27 @@ pub(crate) async fn create_room_route(
}
let alias: Option<OwnedRoomAliasId> =
body.room_alias_name
.as_ref()
.map_or(Ok(None), |localpart| {
// TODO: Check for invalid characters and maximum length
let alias = RoomAliasId::parse(format!(
"#{}:{}",
localpart,
services().globals.server_name()
))
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
if services()
.rooms
.alias
.resolve_local_alias(&alias)?
.is_some()
{
Err(Error::BadRequest(
ErrorKind::RoomInUse,
"Room alias already exists.",
))
} else {
Ok(Some(alias))
}
body.room_alias_name.as_ref().map_or(Ok(None), |localpart| {
// TODO: Check for invalid characters and maximum length
let alias = RoomAliasId::parse(format!(
"#{}:{}",
localpart,
services().globals.server_name()
))
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.")
})?;
if services().rooms.alias.resolve_local_alias(&alias)?.is_some() {
Err(Error::BadRequest(
ErrorKind::RoomInUse,
"Room alias already exists.",
))
} else {
Ok(Some(alias))
}
})?;
if let Some(alias) = &alias {
if let Some(info) = &body.appservice_info {
if !info.aliases.is_match(alias.as_str()) {
@ -159,7 +159,10 @@ pub(crate) async fn create_room_route(
content.insert(
"creator".into(),
json!(&sender_user).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?,
);
}
@ -171,7 +174,10 @@ pub(crate) async fn create_room_route(
content.insert(
"room_version".into(),
json!(room_version.as_str()).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?,
);
content
@ -187,20 +193,30 @@ pub(crate) async fn create_room_route(
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()),
| RoomVersionId::V10 => {
RoomCreateEventContent::new_v1(sender_user.clone())
}
RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => unreachable!("Validity of room version already checked"),
};
let mut content = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&content)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?
.map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?
.get(),
)
.unwrap();
content.insert(
"room_version".into(),
json!(room_version.as_str()).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
Error::BadRequest(
ErrorKind::BadJson,
"Invalid creation content",
)
})?,
);
content
@ -209,9 +225,7 @@ pub(crate) async fn create_room_route(
// Validate creation content
let de_result = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&content)
.expect("Invalid creation content")
.get(),
to_raw_value(&content).expect("Invalid creation content").get(),
);
if de_result.is_err() {
@ -228,7 +242,8 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"),
content: to_raw_value(&content)
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
@ -285,17 +300,24 @@ pub(crate) async fn create_room_route(
}
}
let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it");
let mut power_levels_content =
serde_json::to_value(RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it");
if let Some(power_level_content_override) = &body.power_level_content_override {
let json: JsonObject = serde_json::from_str(power_level_content_override.json().get())
.map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.")
})?;
if let Some(power_level_content_override) =
&body.power_level_content_override
{
let json: JsonObject =
serde_json::from_str(power_level_content_override.json().get())
.map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Invalid power_level_content_override.",
)
})?;
for (key, value) in json {
power_levels_content[key] = value;
@ -353,11 +375,13 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(match preset {
RoomPreset::PublicChat => JoinRule::Public,
// according to spec "invite" is the default
_ => JoinRule::Invite,
}))
content: to_raw_value(&RoomJoinRulesEventContent::new(
match preset {
RoomPreset::PublicChat => JoinRule::Public,
// according to spec "invite" is the default
_ => JoinRule::Invite,
},
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
@ -397,10 +421,12 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(match preset {
RoomPreset::PublicChat => GuestAccess::Forbidden,
_ => GuestAccess::CanJoin,
}))
content: to_raw_value(&RoomGuestAccessEventContent::new(
match preset {
RoomPreset::PublicChat => GuestAccess::Forbidden,
_ => GuestAccess::CanJoin,
},
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
@ -414,10 +440,14 @@ pub(crate) async fn create_room_route(
// 6. Events listed in initial_state
for event in &body.initial_state {
let mut pdu_builder = event.deserialize_as::<PduBuilder>().map_err(|e| {
warn!("Invalid initial state event: {:?}", e);
Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.")
})?;
let mut pdu_builder =
event.deserialize_as::<PduBuilder>().map_err(|e| {
warn!("Invalid initial state event: {:?}", e);
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid initial state event.",
)
})?;
// Implicit state key defaults to ""
pdu_builder.state_key.get_or_insert_with(String::new);
@ -432,7 +462,12 @@ pub(crate) async fn create_room_route(
services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await?;
}
@ -444,8 +479,10 @@ pub(crate) async fn create_room_route(
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(name.clone()))
.expect("event is valid, we just created it"),
content: to_raw_value(&RoomNameEventContent::new(
name.clone(),
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
@ -483,7 +520,8 @@ pub(crate) async fn create_room_route(
drop(state_lock);
for user_id in &body.invite {
if let Err(error) =
invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await
invite_helper(sender_user, user_id, &room_id, None, body.is_direct)
.await
{
warn!(%error, "invite helper failed");
};
@ -507,20 +545,19 @@ pub(crate) async fn create_room_route(
///
/// Gets a single event.
///
/// - You have to currently be joined to the room (TODO: Respect history visibility)
/// - You have to currently be joined to the room (TODO: Respect history
/// visibility)
pub(crate) async fn get_room_event_route(
body: Ruma<get_room_event::v3::Request>,
) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services()
.rooms
.timeline
.get_pdu(&body.event_id)?
.ok_or_else(|| {
let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(
|| {
warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?;
},
)?;
if !services().rooms.state_accessor.user_can_see_event(
sender_user,
@ -545,17 +582,14 @@ pub(crate) async fn get_room_event_route(
///
/// Lists all aliases of the room.
///
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if `history_visibility` is world readable
/// - Only users joined to the room are allowed to call this TODO: Allow any
/// user to call it if `history_visibility` is world readable
pub(crate) async fn get_room_aliases_route(
body: Ruma<aliases::v3::Request>,
) -> Result<aliases::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services()
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You don't have permission to view this room.",
@ -588,10 +622,7 @@ pub(crate) async fn upgrade_room_route(
) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services()
.globals
.supported_room_versions()
.contains(&body.new_version)
if !services().globals.supported_room_versions().contains(&body.new_version)
{
return Err(Error::BadRequest(
ErrorKind::UnsupportedRoomVersion,
@ -601,10 +632,7 @@ pub(crate) async fn upgrade_room_route(
// Create a replacement room
let replacement_room = RoomId::new(services().globals.server_name());
services()
.rooms
.short
.get_or_create_shortroomid(&replacement_room)?;
services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
let mutex_state = Arc::clone(
services()
@ -617,8 +645,9 @@ pub(crate) async fn upgrade_room_route(
);
let state_lock = mutex_state.lock().await;
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further
// Fail if the sender does not have the required permissions
// Send a m.room.tombstone event to the old room to indicate that it is not
// intended to be used any further Fail if the sender does not have the
// required permissions
let tombstone_event_id = services()
.rooms
.timeline
@ -659,7 +688,9 @@ pub(crate) async fn upgrade_room_route(
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.ok_or_else(|| {
Error::bad_database("Found room without m.room.create event.")
})?
.content
.get(),
)
@ -671,7 +702,8 @@ pub(crate) async fn upgrade_room_route(
(*tombstone_event_id).to_owned(),
));
// Send a m.room.create event containing a predecessor field and the applicable room_version
// Send a m.room.create event containing a predecessor field and the
// applicable room_version
match body.new_version {
RoomVersionId::V1
| RoomVersionId::V2
@ -686,7 +718,10 @@ pub(crate) async fn upgrade_room_route(
create_event_content.insert(
"creator".into(),
json!(&sender_user).try_into().map_err(|_| {
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")
Error::BadRequest(
ErrorKind::BadJson,
"Error forming creation event",
)
})?,
);
}
@ -698,15 +733,21 @@ pub(crate) async fn upgrade_room_route(
}
create_event_content.insert(
"room_version".into(),
json!(&body.new_version)
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?,
json!(&body.new_version).try_into().map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Error forming creation event",
)
})?,
);
create_event_content.insert(
"predecessor".into(),
json!(predecessor)
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?,
json!(predecessor).try_into().map_err(|_| {
Error::BadRequest(
ErrorKind::BadJson,
"Error forming creation event",
)
})?,
);
// Validate creation event content
@ -784,16 +825,15 @@ pub(crate) async fn upgrade_room_route(
// Replicate transferable state events to the new room
for event_type in transferable_state_events {
let event_content =
match services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &event_type, "")?
{
Some(v) => v.content.clone(),
// Skipping missing events.
None => continue,
};
let event_content = match services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &event_type, "")?
{
Some(v) => v.content.clone(),
// Skipping missing events.
None => continue,
};
services()
.rooms
@ -820,30 +860,39 @@ pub(crate) async fn upgrade_room_route(
.local_aliases_for_room(&body.room_id)
.filter_map(Result::ok)
{
services()
.rooms
.alias
.set_alias(&alias, &replacement_room)?;
services().rooms.alias.set_alias(&alias, &replacement_room)?;
}
// Get the old room power levels
let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str(
services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
let mut power_levels_event_content: RoomPowerLevelsEventContent =
serde_json::from_str(
services()
.rooms
.state_accessor
.room_state_get(
&body.room_id,
&StateEventType::RoomPowerLevels,
"",
)?
.ok_or_else(|| {
Error::bad_database(
"Found room without m.room.create event.",
)
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
// Setting events_default and invite to the greater of 50 and users_default + 1
let new_level = max(int!(50), power_levels_event_content.users_default + int!(1));
// Setting events_default and invite to the greater of 50 and users_default
// + 1
let new_level =
max(int!(50), power_levels_event_content.users_default + int!(1));
power_levels_event_content.events_default = new_level;
power_levels_event_content.invite = new_level;
// Modify the power levels in the old room to prevent sending of events and inviting new users
// Modify the power levels in the old room to prevent sending of events and
// inviting new users
let _ = services()
.rooms
.timeline
@ -865,5 +914,7 @@ pub(crate) async fn upgrade_room_route(
drop(state_lock);
// Return the replacement room id
Ok(upgrade_room::v3::Response { replacement_room })
Ok(upgrade_room::v3::Response {
replacement_room,
})
}

View file

@ -1,22 +1,27 @@
use crate::{services, Error, Result, Ruma};
use std::collections::BTreeMap;
use ruma::{
api::client::{
error::ErrorKind,
search::search_events::{
self,
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
v3::{
EventContextResult, ResultCategories, ResultRoomEvents,
SearchResult,
},
},
},
uint,
};
use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/search`
///
/// Searches rooms for messages.
///
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
/// - Only works if the user is currently joined to the room (TODO: Respect
/// history visibility)
#[allow(clippy::too_many_lines)]
pub(crate) async fn search_events_route(
body: Ruma<search_events::v3::Request>,
@ -46,11 +51,7 @@ pub(crate) async fn search_events_route(
let mut searches = Vec::new();
for room_id in room_ids {
if !services()
.rooms
.state_cache
.is_joined(sender_user, &room_id)?
{
if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You don't have permission to view this room.",
@ -102,7 +103,11 @@ pub(crate) async fn search_events_route(
services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.user_can_see_event(
sender_user,
&pdu.room_id,
&pdu.event_id,
)
.unwrap_or(false)
})
.map(|pdu| pdu.to_room_event())

View file

@ -1,5 +1,3 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use ruma::{
api::client::{
error::ErrorKind,
@ -17,6 +15,9 @@ use ruma::{
use serde::Deserialize;
use tracing::{info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
#[derive(Debug, Deserialize)]
struct Claims {
sub: String,
@ -24,30 +25,36 @@ struct Claims {
/// # `GET /_matrix/client/r0/login`
///
/// Get the supported login types of this server. One of these should be used as the `type` field
/// when logging in.
/// Get the supported login types of this server. One of these should be used as
/// the `type` field when logging in.
pub(crate) async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
Ok(get_login_types::v3::Response::new(vec![
get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
get_login_types::v3::LoginType::ApplicationService(
ApplicationServiceLoginType::default(),
),
]))
}
/// # `POST /_matrix/client/r0/login`
///
/// Authenticates the user and returns an access token it can use in subsequent requests.
/// Authenticates the user and returns an access token it can use in subsequent
/// requests.
///
/// - The user needs to authenticate using their password (or if enabled using a json web token)
/// - The user needs to authenticate using their password (or if enabled using a
/// json web token)
/// - If `device_id` is known: invalidates old access token of that device
/// - If `device_id` is unknown: creates a new device
/// - Returns access token that is associated with the user and device
///
/// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to see
/// supported login types.
/// Note: You can use [`GET /_matrix/client/r0/login`](get_login_types_route) to
/// see supported login types.
#[allow(clippy::too_many_lines)]
pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
pub(crate) async fn login_route(
body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
// To allow deprecated login methods
#![allow(deprecated)]
// Validate login method
@ -59,18 +66,29 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
user,
..
}) => {
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
UserId::parse_with_server_name(
user_id.to_lowercase(),
services().globals.server_name(),
)
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
let user_id =
if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) =
identifier
{
UserId::parse_with_server_name(
user_id.to_lowercase(),
services().globals.server_name(),
)
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Bad login type.",
));
}
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(
@ -79,13 +97,12 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
));
}
let hash = services()
.users
.password_hash(&user_id)?
.ok_or(Error::BadRequest(
let hash = services().users.password_hash(&user_id)?.ok_or(
Error::BadRequest(
ErrorKind::Forbidden,
"Wrong username or password.",
))?;
),
)?;
if hash.is_empty() {
return Err(Error::BadRequest(
@ -94,7 +111,9 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
));
}
let hash_matches = argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes())
.unwrap_or(false);
if !hash_matches {
return Err(Error::BadRequest(
@ -105,20 +124,34 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
user_id
}
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
login::v3::LoginInfo::Token(login::v3::Token {
token,
}) => {
if let Some(jwt_decoding_key) =
services().globals.jwt_decoding_key()
{
let token = jsonwebtoken::decode::<Claims>(
token,
jwt_decoding_key,
&jsonwebtoken::Validation::default(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?;
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Token is invalid.",
)
})?;
let username = token.claims.sub.to_lowercase();
let user_id =
UserId::parse_with_server_name(username, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
let user_id = UserId::parse_with_server_name(
username,
services().globals.server_name(),
)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(
@ -131,26 +164,40 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).",
"Token login is not supported (server has no jwt decoding \
key).",
));
}
}
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
identifier,
user,
}) => {
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
UserId::parse_with_server_name(
user_id.to_lowercase(),
services().globals.server_name(),
)
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
login::v3::LoginInfo::ApplicationService(
login::v3::ApplicationService {
identifier,
user,
},
) => {
let user_id =
if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) =
identifier
{
UserId::parse_with_server_name(
user_id.to_lowercase(),
services().globals.server_name(),
)
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Bad login type.",
));
}
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if let Some(info) = &body.appservice_info {
if !info.is_user_match(&user_id) {
@ -225,12 +272,16 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
/// Log out the current device.
///
/// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
/// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events
/// - Triggers device list updates
pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
pub(crate) async fn logout_route(
body: Ruma<logout::v3::Request>,
) -> Result<logout::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device =
body.sender_device.as_ref().expect("user is authenticated");
if let Some(info) = &body.appservice_info {
if !info.is_user_match(sender_user) {
@ -251,12 +302,13 @@ pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logo
/// Log out all devices of this user.
///
/// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
/// - Deletes all device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets all to-device events
/// - Triggers device list updates
///
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](logout_route)
/// from each device of this user.
/// Note: This is equivalent to calling [`GET
/// /_matrix/client/r0/logout`](logout_route) from each device of this user.
pub(crate) async fn logout_all_route(
body: Ruma<logout_all::v3::Request>,
) -> Result<logout_all::v3::Response> {

View file

@ -1,19 +1,18 @@
use crate::{services, Result, Ruma};
use ruma::{api::client::space::get_hierarchy, uint};
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`
///
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space.
/// Paginates over the space tree in a depth-first manner to locate child rooms
/// of a given space.
pub(crate) async fn get_hierarchy_route(
body: Ruma<get_hierarchy::v1::Request>,
) -> Result<get_hierarchy::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let skip = body
.from
.as_ref()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
let skip =
body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
let limit = body
.limit
@ -23,8 +22,10 @@ pub(crate) async fn get_hierarchy_route(
.expect("0-100 should fit in usize");
// Plus one to skip the space room itself
let max_depth = usize::try_from(body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3)))
.expect("0-10 should fit in usize")
let max_depth = usize::try_from(
body.max_depth.map(|x| x.min(uint!(10))).unwrap_or(uint!(3)),
)
.expect("0-10 should fit in usize")
+ 1;
services()

View file

@ -1,25 +1,30 @@
use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
use ruma::{
api::client::{
error::ErrorKind,
state::{get_state_events, get_state_events_for_key, send_state_event},
},
events::{
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType,
room::canonical_alias::RoomCanonicalAliasEventContent,
AnyStateEventContent, StateEventType,
},
serde::Raw,
EventId, RoomId, UserId,
};
use tracing::log::warn;
use crate::{
service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse,
};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
///
/// Sends a state event into the room.
///
/// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed
/// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_key_route(
body: Ruma<send_state_event::v3::Request>,
@ -37,7 +42,9 @@ pub(crate) async fn send_state_event_for_key_route(
.await?;
let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id })
Ok(send_state_event::v3::Response {
event_id,
})
}
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
@ -45,7 +52,8 @@ pub(crate) async fn send_state_event_for_key_route(
/// Sends a state event into the room.
///
/// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed
/// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_empty_key_route(
body: Ruma<send_state_event::v3::Request>,
@ -53,7 +61,9 @@ pub(crate) async fn send_state_event_for_empty_key_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Forbid m.room.encryption if encryption is disabled
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
if body.event_type == StateEventType::RoomEncryption
&& !services().globals.allow_encryption()
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Encryption has been disabled",
@ -70,14 +80,18 @@ pub(crate) async fn send_state_event_for_empty_key_route(
.await?;
let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }.into())
Ok(send_state_event::v3::Response {
event_id,
}
.into())
}
/// # `GET /_matrix/client/r0/rooms/{roomid}/state`
///
/// Get all state events for a room.
///
/// - If not joined: Only works if current room history visibility is world readable
/// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_route(
body: Ruma<get_state_events::v3::Request>,
) -> Result<get_state_events::v3::Response> {
@ -110,7 +124,8 @@ pub(crate) async fn get_state_events_route(
///
/// Get single state event of a room.
///
/// - If not joined: Only works if current room history visibility is world readable
/// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_for_key_route(
body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<get_state_events_for_key::v3::Response> {
@ -140,8 +155,9 @@ pub(crate) async fn get_state_events_for_key_route(
})?;
Ok(get_state_events_for_key::v3::Response {
content: serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content in database"))?,
content: serde_json::from_str(event.content.get()).map_err(|_| {
Error::bad_database("Invalid event content in database")
})?,
})
}
@ -149,7 +165,8 @@ pub(crate) async fn get_state_events_for_key_route(
///
/// Get single state event of a room.
///
/// - If not joined: Only works if current room history visibility is world readable
/// - If not joined: Only works if current room history visibility is world
/// readable
pub(crate) async fn get_state_events_for_empty_key_route(
body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
@ -179,8 +196,9 @@ pub(crate) async fn get_state_events_for_empty_key_route(
})?;
Ok(get_state_events_for_key::v3::Response {
content: serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content in database"))?,
content: serde_json::from_str(event.content.get()).map_err(|_| {
Error::bad_database("Invalid event content in database")
})?,
}
.into())
}
@ -194,10 +212,11 @@ async fn send_state_event_for_key_helper(
) -> Result<Arc<EventId>> {
let sender_user = sender;
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it
// previously existed
if let Ok(canonical_alias) =
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get())
// TODO: Review this check, error if event is unparsable, use event type,
// allow alias if it previously existed
if let Ok(canonical_alias) = serde_json::from_str::<
RoomCanonicalAliasEventContent,
>(json.json().get())
{
let mut aliases = canonical_alias.alt_aliases.clone();
@ -216,8 +235,8 @@ async fn send_state_event_for_key_helper(
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You are only allowed to send canonical_alias \
events when it's aliases already exists",
"You are only allowed to send canonical_alias events when \
it's aliases already exists",
));
}
}
@ -240,7 +259,8 @@ async fn send_state_event_for_key_helper(
.build_and_append_pdu(
PduBuilder {
event_type: event_type.to_string().into(),
content: serde_json::from_str(json.json().get()).expect("content is valid json"),
content: serde_json::from_str(json.json().get())
.expect("content is valid json"),
unsigned: None,
state_key: Some(state_key),
redacts: None,

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,5 @@
use crate::{services, Error, Result, Ruma};
use std::collections::BTreeMap;
use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags},
events::{
@ -6,7 +7,8 @@ use ruma::{
RoomAccountDataEventType,
},
};
use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
///
@ -33,8 +35,9 @@ pub(crate) async fn update_tag_route(
})
},
|e| {
serde_json::from_str(e.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
serde_json::from_str(e.get()).map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})
},
)?;
@ -78,8 +81,9 @@ pub(crate) async fn delete_tag_route(
})
},
|e| {
serde_json::from_str(e.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
serde_json::from_str(e.get()).map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})
},
)?;
@ -120,8 +124,9 @@ pub(crate) async fn get_tags_route(
})
},
|e| {
serde_json::from_str(e.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
serde_json::from_str(e.get()).map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})
},
)?;

View file

@ -1,7 +1,8 @@
use crate::{Result, Ruma};
use std::collections::BTreeMap;
use ruma::api::client::thirdparty::get_protocols;
use std::collections::BTreeMap;
use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/thirdparty/protocols`
///

View file

@ -9,11 +9,8 @@ pub(crate) async fn get_threads_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|l| l.try_into().ok())
.unwrap_or(10)
.min(100);
let limit =
body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);
let from = if let Some(from) = &body.from {
from.parse()

View file

@ -1,6 +1,5 @@
use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
use ruma::{
api::{
client::{error::ErrorKind, to_device::send_event_to_device},
@ -9,6 +8,8 @@ use ruma::{
to_device::DeviceIdOrAllDevices,
};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
///
/// Send a to-device event to a set of client devices.
@ -29,7 +30,8 @@ pub(crate) async fn send_event_to_device_route(
for (target_user_id, map) in &body.messages {
for (target_device_id_maybe, event) in map {
if target_user_id.server_name() != services().globals.server_name() {
if target_user_id.server_name() != services().globals.server_name()
{
let mut map = BTreeMap::new();
map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new();
@ -38,14 +40,16 @@ pub(crate) async fn send_event_to_device_route(
services().sending.send_reliable_edu(
target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent {
sender: sender_user.clone(),
ev_type: body.event_type.clone(),
message_id: count.to_string().into(),
messages,
},
))
serde_json::to_vec(
&federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent {
sender: sender_user.clone(),
ev_type: body.event_type.clone(),
message_id: count.to_string().into(),
messages,
},
),
)
.expect("DirectToDevice EDU can be serialized"),
count,
)?;
@ -61,20 +65,28 @@ pub(crate) async fn send_event_to_device_route(
target_device_id,
&body.event_type.to_string(),
event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
})?,
)?;
}
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) {
for target_device_id in
services().users.all_device_ids(target_user_id)
{
services().users.add_to_device_event(
sender_user,
target_user_id,
&target_device_id?,
&body.event_type.to_string(),
event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
})?,
)?;
}
@ -84,9 +96,12 @@ pub(crate) async fn send_event_to_device_route(
}
// Save transaction id with empty data
services()
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
services().transaction_ids.add_txnid(
sender_user,
sender_device,
&body.txn_id,
&[],
)?;
Ok(send_event_to_device::v3::Response {})
}

View file

@ -1,6 +1,7 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
use crate::{services, utils, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
///
/// Sets the typing state of the sender user.
@ -11,11 +12,7 @@ pub(crate) async fn create_typing_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services()
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You are not in this room.",

View file

@ -6,14 +6,16 @@ use crate::{Result, Ruma};
/// # `GET /_matrix/client/versions`
///
/// Get the versions of the specification and unstable features supported by this server.
/// Get the versions of the specification and unstable features supported by
/// this server.
///
/// - Versions take the form MAJOR.MINOR.PATCH
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
/// - Unstable features are namespaced and may include version information in their name
/// - Unstable features are namespaced and may include version information in
/// their name
///
/// Note: Unstable features are used while developing new features. Clients should avoid using
/// unstable features in their stable releases
/// Note: Unstable features are used while developing new features. Clients
/// should avoid using unstable features in their stable releases
pub(crate) async fn get_supported_versions_route(
_body: Ruma<get_supported_versions::Request>,
) -> Result<get_supported_versions::Response> {
@ -27,7 +29,10 @@ pub(crate) async fn get_supported_versions_route(
"v1.4".to_owned(),
"v1.5".to_owned(),
],
unstable_features: BTreeMap::from_iter([("org.matrix.e2e_cross_signing".to_owned(), true)]),
unstable_features: BTreeMap::from_iter([(
"org.matrix.e2e_cross_signing".to_owned(),
true,
)]),
};
Ok(resp)

View file

@ -1,4 +1,3 @@
use crate::{services, Result, Ruma};
use ruma::{
api::client::user_directory::search_users,
events::{
@ -7,11 +6,14 @@ use ruma::{
},
};
use crate::{services, Result, Ruma};
/// # `POST /_matrix/client/r0/user_directory/search`
///
/// Searches all known users for a match.
///
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
/// - Hides any local users that aren't in any public rooms (i.e. those that
/// have the join rule set to public)
/// and don't share a room with the sender
pub(crate) async fn search_users_route(
body: Ruma<search_users::v3::Request>,
@ -38,8 +40,7 @@ pub(crate) async fn search_users_route(
.display_name
.as_ref()
.filter(|name| {
name.to_lowercase()
.contains(&body.search_term.to_lowercase())
name.to_lowercase().contains(&body.search_term.to_lowercase())
})
.is_some();
@ -62,10 +63,12 @@ pub(crate) async fn search_users_route(
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
.map_or(false, |event| {
event.map_or(false, |event| {
serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| {
serde_json::from_str(event.content.get()).map_or(
false,
|r: RoomJoinRulesEventContent| {
r.join_rule == JoinRule::Public
})
},
)
})
})
});
@ -96,5 +99,8 @@ pub(crate) async fn search_users_route(
let results = users.by_ref().take(limit).collect();
let limited = users.next().is_some();
Ok(search_users::v3::Response { results, limited })
Ok(search_users::v3::Response {
results,
limited,
})
}

View file

@ -1,9 +1,11 @@
use crate::{services, Result, Ruma};
use std::time::{Duration, SystemTime};
use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
use sha1::Sha1;
use std::time::{Duration, SystemTime};
use crate::{services, Result, Ruma};
type HmacSha1 = Hmac<Sha1>;
@ -24,7 +26,8 @@ pub(crate) async fn turn_server_route(
)
} else {
let expiry = SecondsSinceUnixEpoch::from_system_time(
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
SystemTime::now()
+ Duration::from_secs(services().globals.turn_ttl()),
)
.expect("time is valid");
@ -34,7 +37,8 @@ pub(crate) async fn turn_server_route(
.expect("HMAC can take key of any size");
mac.update(username.as_bytes());
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
let password: String =
general_purpose::STANDARD.encode(mac.finalize().into_bytes());
(username, password)
};

View file

@ -1,10 +1,12 @@
use crate::{service::appservice::RegistrationInfo, Error};
use ruma::{
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
OwnedUserId,
};
use std::ops::Deref;
use ruma::{
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId,
OwnedServerName, OwnedUserId,
};
use crate::{service::appservice::RegistrationInfo, Error};
mod axum;
/// Extractor for Ruma request structs

View file

@ -3,7 +3,9 @@ use std::{collections::BTreeMap, iter::FromIterator, str};
use axum::{
async_trait,
body::{Full, HttpBody},
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
extract::{
rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader,
},
headers::{
authorization::{Bearer, Credentials},
Authorization,
@ -14,7 +16,9 @@ use axum::{
use bytes::{Buf, BufMut, Bytes, BytesMut};
use http::{Request, StatusCode};
use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
api::{
client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse,
},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
};
use serde::Deserialize;
@ -41,7 +45,10 @@ where
type Rejection = Error;
#[allow(clippy::too_many_lines)]
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
async fn from_request(
req: Request<B>,
_state: &S,
) -> Result<Self, Self::Rejection> {
#[derive(Deserialize)]
struct QueryParams {
access_token: Option<String>,
@ -51,22 +58,23 @@ where
let (mut parts, mut body) = match req.with_limited_body() {
Ok(limited_req) => {
let (parts, body) = limited_req.into_parts();
let body = to_bytes(body)
.await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
let body = to_bytes(body).await.map_err(|_| {
Error::BadRequest(ErrorKind::MissingToken, "Missing token.")
})?;
(parts, body)
}
Err(original_req) => {
let (parts, body) = original_req.into_parts();
let body = to_bytes(body)
.await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
let body = to_bytes(body).await.map_err(|_| {
Error::BadRequest(ErrorKind::MissingToken, "Missing token.")
})?;
(parts, body)
}
};
let metadata = T::METADATA;
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
let auth_header: Option<TypedHeader<Authorization<Bearer>>> =
parts.extract().await?;
let path_params: Path<Vec<String>> = parts.extract().await?;
let query = parts.uri.query().unwrap_or_default();
@ -87,9 +95,13 @@ where
};
let token = if let Some(token) = token {
if let Some(reg_info) = services().appservice.find_from_token(token).await {
if let Some(reg_info) =
services().appservice.find_from_token(token).await
{
Token::Appservice(Box::new(reg_info.clone()))
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? {
} else if let Some((user_id, device_id)) =
services().users.find_from_token(token)?
{
Token::User((user_id, OwnedDeviceId::from(device_id)))
} else {
Token::Invalid
@ -98,13 +110,16 @@ where
Token::None
};
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let mut json_body =
serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let (sender_user, sender_device, sender_servername, appservice_info) =
match (metadata.authentication, token) {
(_, Token::Invalid) => {
return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false },
ErrorKind::UnknownToken {
soft_logout: false,
},
"Unknown access token.",
))
}
@ -121,7 +136,10 @@ where
UserId::parse,
)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if !info.is_user_match(&user_id) {
@ -153,7 +171,9 @@ where
));
}
(
AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None,
AuthScheme::AccessToken
| AuthScheme::AccessTokenOptional
| AuthScheme::None,
Token::User((user_id, device_id)),
) => (Some(user_id), Some(device_id), None, None),
(AuthScheme::ServerSignatures, Token::None) => {
@ -161,7 +181,10 @@ where
.extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
warn!(
"Missing or invalid Authorization header: {}",
e
);
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => {
@ -189,7 +212,9 @@ where
let mut request_map = BTreeMap::from_iter([
(
"method".to_owned(),
CanonicalJsonValue::String(parts.method.to_string()),
CanonicalJsonValue::String(
parts.method.to_string(),
),
),
(
"uri".to_owned(),
@ -197,12 +222,18 @@ where
),
(
"origin".to_owned(),
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
CanonicalJsonValue::String(
x_matrix.origin.as_str().to_owned(),
),
),
(
"destination".to_owned(),
CanonicalJsonValue::String(
services().globals.server_name().as_str().to_owned(),
services()
.globals
.server_name()
.as_str()
.to_owned(),
),
),
(
@ -212,13 +243,17 @@ where
]);
if let Some(json_body) = &json_body {
request_map.insert("content".to_owned(), json_body.clone());
request_map
.insert("content".to_owned(), json_body.clone());
};
let keys_result = services()
.rooms
.event_handler
.fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.clone()])
.fetch_signing_keys(
&x_matrix.origin,
vec![x_matrix.key.clone()],
)
.await;
let keys = match keys_result {
@ -232,22 +267,29 @@ where
}
};
let pub_key_map =
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
let pub_key_map = BTreeMap::from_iter([(
x_matrix.origin.as_str().to_owned(),
keys,
)]);
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
match ruma::signatures::verify_json(
&pub_key_map,
&request_map,
) {
Ok(()) => (None, None, Some(x_matrix.origin), None),
Err(e) => {
warn!(
"Failed to verify json request from {}: {}\n{:?}",
"Failed to verify json request from {}: \
{}\n{:?}",
x_matrix.origin, e, request_map
);
if parts.uri.to_string().contains('@') {
warn!(
"Request uri contained '@' character. Make sure your \
reverse proxy gives Grapevine the raw uri (apache: use \
nocanon)"
"Request uri contained '@' character. \
Make sure your reverse proxy gives \
Grapevine the raw uri (apache: use \
nocanon)"
);
}
@ -264,27 +306,36 @@ where
| AuthScheme::AccessTokenOptional,
Token::None,
) => (None, None, None, None),
(AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => {
(
AuthScheme::ServerSignatures,
Token::Appservice(_) | Token::User(_),
) => {
return Err(Error::BadRequest(
ErrorKind::Unauthorized,
"Only server signatures should be used on this endpoint.",
"Only server signatures should be used on this \
endpoint.",
));
}
(AuthScheme::AppserviceToken, Token::User(_)) => {
return Err(Error::BadRequest(
ErrorKind::Unauthorized,
"Only appservice access tokens should be used on this endpoint.",
"Only appservice access tokens should be used on this \
endpoint.",
));
}
};
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
let mut http_request =
http::Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| {
UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid")
UserId::parse_with_server_name(
"",
services().globals.server_name(),
)
.expect("we know this is valid")
});
let uiaa_request = json_body
@ -300,14 +351,17 @@ where
)
});
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
if let Some(CanonicalJsonValue::Object(initial_request)) =
uiaa_request
{
for (key, value) in initial_request {
json_body.entry(key).or_insert(value);
}
}
let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail");
serde_json::to_writer(&mut buf, json_body)
.expect("value serialization can't fail");
body = buf.into_inner().freeze();
}
@ -315,11 +369,15 @@ where
debug!("{:?}", http_request);
let body = T::try_from_http_request(http_request, &path_params).map_err(|e| {
warn!("try_from_http_request failed: {:?}", e);
debug!("JSON body: {:?}", json_body);
Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.")
})?;
let body = T::try_from_http_request(http_request, &path_params)
.map_err(|e| {
warn!("try_from_http_request failed: {:?}", e);
debug!("JSON body: {:?}", json_body);
Error::BadRequest(
ErrorKind::BadJson,
"Failed to deserialize request.",
)
})?;
Ok(Ruma {
body,
@ -345,7 +403,8 @@ impl Credentials for XMatrix {
fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
"HeaderValue to decode should start with \"X-Matrix ..\", \
received = {value:?}",
);
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
@ -359,8 +418,9 @@ impl Credentials for XMatrix {
for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not in the spec,
// let's simply accept either form for every field.
// It's not at all clear why some fields are quoted and others not
// in the spec, let's simply accept either form for
// every field.
let value = value
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))

File diff suppressed because it is too large Load diff

View file

@ -108,7 +108,10 @@ impl Config {
}
if was_deprecated {
warn!("Read grapevine documentation and check your configuration if any new configuration parameters should be adjusted");
warn!(
"Read grapevine documentation and check your configuration if \
any new configuration parameters should be adjusted"
);
}
}
}

View file

@ -24,9 +24,10 @@ use crate::Result;
/// ## Include vs. Exclude
/// If include is an empty list, it is assumed to be `["*"]`.
///
/// If a domain matches both the exclude and include list, the proxy will only be used if it was
/// included because of a more specific rule than it was excluded. In the above example, the proxy
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
/// If a domain matches both the exclude and include list, the proxy will only
/// be used if it was included because of a more specific rule than it was
/// excluded. In the above example, the proxy would be used for
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
@ -43,7 +44,9 @@ impl ProxyConfig {
pub(crate) fn to_proxy(&self) -> Result<Option<Proxy>> {
Ok(match self.clone() {
ProxyConfig::None => None,
ProxyConfig::Global { url } => Some(Proxy::all(url)?),
ProxyConfig::Global {
url,
} => Some(Proxy::all(url)?),
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
// first matching proxy
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned()
@ -112,25 +115,32 @@ impl WildCardedDomain {
WildCardedDomain::Exact(d) => domain == d,
}
}
pub(crate) fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
a != b && a.ends_with(b)
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => {
other.matches(a)
}
(
WildCardedDomain::WildCarded(a),
WildCardedDomain::WildCarded(b),
) => a != b && a.ends_with(b),
_ => false,
}
}
}
impl std::str::FromStr for WildCardedDomain {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// maybe do some domain validation?
Ok(s.strip_prefix("*.")
.map(|x| WildCardedDomain::WildCarded(x.to_owned()))
.or_else(|| (s == "*").then(|| WildCardedDomain::WildCarded(String::new())))
.or_else(|| {
(s == "*").then(|| WildCardedDomain::WildCarded(String::new()))
})
.unwrap_or_else(|| WildCardedDomain::Exact(s.to_owned())))
}
}

View file

@ -1,23 +1,6 @@
pub(crate) mod abstraction;
pub(crate) mod key_value;
use crate::{
service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services,
SERVICES,
};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
UserId,
};
use std::{
collections::{BTreeMap, HashMap, HashSet},
fs,
@ -27,8 +10,25 @@ use std::{
sync::{Arc, Mutex, RwLock},
};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId,
OwnedUserId, RoomId, UserId,
};
use tracing::{debug, error, info, warn};
use crate::{
service::rooms::timeline::PduCount, services, utils, Config, Error,
PduEvent, Result, Services, SERVICES,
};
pub(crate) struct KeyValueDatabase {
db: Arc<dyn KeyValueDatabaseEngine>,
@ -74,8 +74,9 @@ pub(crate) struct KeyValueDatabase {
// Trees "owned" by `self::key_value::uiaa`
// User-interactive authentication
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>,
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
pub(super) userdevicesessionid_uiaarequest: RwLock<
BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>,
>,
// Trees "owned" by `self::key_value::rooms::edus`
// ReadReceiptId = RoomId + Count + UserId
@ -169,13 +170,15 @@ pub(crate) struct KeyValueDatabase {
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
// StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
// StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 +
// (shortstatekey+shorteventid--)
pub(super) shortstatehash_statediff: Arc<dyn KvTree>,
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
pub(super) softfailedeventids: Arc<dyn KvTree>,
@ -214,10 +217,12 @@ pub(crate) struct KeyValueDatabase {
// EduCount: Count of last EDU sync
pub(super) servername_educount: Arc<dyn KvTree>,
// ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
// ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id
// (for edus), Data = EDU content
pub(super) servernameevent_data: Arc<dyn KvTree>,
// ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
// ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for
// edus), Data = EDU content
pub(super) servercurrentevent_data: Arc<dyn KvTree>,
// Trees "owned" by `self::key_value::appservice`
@ -231,10 +236,14 @@ pub(crate) struct KeyValueDatabase {
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(super) statekeyshort_cache:
Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache:
Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache:
RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache:
RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
@ -271,13 +280,15 @@ impl KeyValueDatabase {
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
"Found sqlite at database_path, but is not specified in \
config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
"Found rocksdb at database_path, but is not specified in \
config.",
));
}
@ -294,19 +305,30 @@ impl KeyValueDatabase {
Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
std::fs::create_dir_all(&config.database_path)
.map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?;
std::fs::create_dir_all(&config.database_path).map_err(|_| {
Error::BadConfig(
"Database folder doesn't exists and couldn't be created \
(e.g. due to missing permissions). Please create the \
database folder yourself.",
)
})?;
}
#[cfg_attr(
not(any(feature = "rocksdb", feature = "sqlite")),
allow(unused_variables)
)]
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config
.database_backend
{
#[cfg(feature = "sqlite")]
"sqlite" => Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?),
"sqlite" => {
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
#[cfg(feature = "rocksdb")]
"rocksdb" => Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?),
"rocksdb" => {
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
_ => {
return Err(Error::BadConfig("Database backend not found."));
}
@ -327,28 +349,38 @@ impl KeyValueDatabase {
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
userdeviceid_metadata: builder
.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder
.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
onetimekeyid_onetimekeys: builder
.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder
.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userid_selfsigningkeyid: builder
.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder
.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaainfo: builder
.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
readreceiptid_readreceipt: builder
.open_tree("readreceiptid_readreceipt")?,
// "Private" read receipt
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?,
roomuserid_privateread: builder
.open_tree("roomuserid_privateread")?,
roomuserid_lastprivatereadupdate: builder
.open_tree("roomuserid_lastprivatereadupdate")?,
presenceid_presence: builder.open_tree("presenceid_presence")?,
userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?,
userid_lastpresenceupdate: builder
.open_tree("userid_lastpresenceupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
@ -367,9 +399,12 @@ impl KeyValueDatabase {
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
roomuseroncejoinedids: builder
.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder
.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder
.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
@ -377,41 +412,57 @@ impl KeyValueDatabase {
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
userroomid_notificationcount: builder
.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder
.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder
.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
statekey_shortstatekey: builder
.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder
.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
shorteventid_authchain: builder
.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
shortstatehash_statediff: builder
.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
shorteventid_shortstatehash: builder
.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder
.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder
.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder
.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
roomuserdataid_accountdata: builder
.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder
.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
userdevicetxnid_response: builder
.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
servercurrentevent_data: builder
.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder
.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
@ -489,11 +540,13 @@ impl KeyValueDatabase {
if !services().users.exists(&grapevine_user)? {
error!(
"The {} server user does not exist, and the database is not new.",
"The {} server user does not exist, and the database is \
not new.",
grapevine_user
);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first."
"Cannot reuse an existing database after changing the \
server name, please delete the old one first.",
));
}
}
@ -505,14 +558,15 @@ impl KeyValueDatabase {
// MIGRATIONS
if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xff);
let room_id = parts.next().expect("split always returns one element");
let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id =
parts.next().expect("split always returns one element");
let Some(servername) = parts.next() else {
error!("Migration: Invalid roomserverid in db.");
continue;
};
let mut serverroomid = servername.to_vec();
serverroomid.push(0xff);
serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?;
@ -524,13 +578,16 @@ impl KeyValueDatabase {
}
if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just ""
// We accidentally inserted hashed versions of "" into the db
// instead of just ""
for (userid, password) in db.userid_password.iter() {
let password = utils::string_from_bytes(&password);
let empty_hashed_password = password.map_or(false, |password| {
argon2::verify_encoded(&password, b"").unwrap_or(false)
});
let empty_hashed_password =
password.map_or(false, |password| {
argon2::verify_encoded(&password, b"")
.unwrap_or(false)
});
if empty_hashed_password {
db.userid_password.insert(&userid, b"")?;
@ -567,10 +624,16 @@ impl KeyValueDatabase {
if services().users.is_deactivated(&our_user)? {
continue;
}
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
for room in
services().rooms.state_cache.rooms_joined(&our_user)
{
for user in
services().rooms.state_cache.room_members(&room?)
{
let user = user?;
if user.server_name() != services().globals.server_name() {
if user.server_name()
!= services().globals.server_name()
{
info!(?user, "Migration: creating user");
services().users.create(&user, None)?;
}
@ -585,16 +648,18 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 5 {
// Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xff);
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter()
{
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap();
let event_type =
roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id);
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid
@ -611,7 +676,10 @@ impl KeyValueDatabase {
for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?;
services()
.rooms
.state_cache
.update_joined_count(room_id)?;
}
services().globals.bump_database_version(6)?;
@ -621,7 +689,8 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut last_roomstates: HashMap<OwnedRoomId, u64> =
HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
@ -633,7 +702,8 @@ impl KeyValueDatabase {
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
counter += 1;
let last_roomsstatehash = last_roomstates.get(current_room);
let last_roomsstatehash =
last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
@ -641,12 +711,16 @@ impl KeyValueDatabase {
services()
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
.load_shortstatehash_info(
last_roomsstatehash,
)
},
)?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
if let Some(parent_stateinfo) =
states_parents.last()
{
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.copied()
@ -663,21 +737,28 @@ impl KeyValueDatabase {
(current_state, HashSet::new())
};
services().rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
// every state change is 2 event changes on average
2,
states_parents,
)?;
services()
.rooms
.state_compressor
.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
// every state change is 2 event changes on
// average
2,
states_parents,
)?;
Ok::<_, Error>(())
};
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
for (k, seventid) in
db.db.open_tree("stateid_shorteventid")?.iter()
{
let sstatehash =
utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
@ -687,15 +768,23 @@ impl KeyValueDatabase {
current_state,
&mut last_roomstates,
)?;
last_roomstates
.insert(current_room.clone().unwrap(), current_sstatehash);
last_roomstates.insert(
current_room.clone().unwrap(),
current_sstatehash,
);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let event_id = db
.shorteventid_eventid
.get(&seventid)
.unwrap()
.unwrap();
let string =
utils::string_from_bytes(&event_id).unwrap();
let event_id =
<&EventId>::try_from(string.as_str()).unwrap();
let pdu = services()
.rooms
.timeline
@ -710,7 +799,8 @@ impl KeyValueDatabase {
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
current_state
.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
@ -730,7 +820,8 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
let shortroomid =
services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
@ -739,7 +830,7 @@ impl KeyValueDatabase {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xff);
let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
@ -757,25 +848,26 @@ impl KeyValueDatabase {
db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let mut batch2 =
db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
Some((k, new_value))
});
db.eventid_pduid.insert_batch(&mut batch2)?;
@ -793,7 +885,7 @@ impl KeyValueDatabase {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xff);
let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
@ -806,7 +898,7 @@ impl KeyValueDatabase {
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(word);
new_key.push(0xff);
new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count);
Some((new_key, Vec::new()))
})
@ -836,12 +928,15 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
for (statekey, shortstatekey) in
db.statekey_shortstatekey.iter()
{
db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
// Force E2EE device list updates so we can send them over
// federation
for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?;
}
@ -852,9 +947,7 @@ impl KeyValueDatabase {
}
if services().globals.database_version()? < 11 {
db.db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?;
services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished");
@ -878,24 +971,34 @@ impl KeyValueDatabase {
.get(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
)
.unwrap()
.expect("Username is invalid");
let mut account_data =
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
serde_json::from_str::<PushRulesEvent>(
raw_rules_list.get(),
)
.unwrap();
let rules_list = &mut account_data.content.global;
//content rule
{
let content_rule_transformation =
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let content_rule_transformation = [
".m.rules.contains_user_name",
".m.rule.contains_user_name",
];
let rule = rules_list.content.get(content_rule_transformation[0]);
let rule = rules_list
.content
.get(content_rule_transformation[0]);
if rule.is_some() {
let mut rule = rule.unwrap().clone();
rule.rule_id = content_rule_transformation[1].to_owned();
rule.rule_id =
content_rule_transformation[1].to_owned();
rules_list
.content
.shift_remove(content_rule_transformation[0]);
@ -907,7 +1010,10 @@ impl KeyValueDatabase {
{
let underride_rule_transformation = [
[".m.rules.call", ".m.rule.call"],
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
[
".m.rules.room_one_to_one",
".m.rule.room_one_to_one",
],
[
".m.rules.encrypted_room_one_to_one",
".m.rule.encrypted_room_one_to_one",
@ -917,11 +1023,14 @@ impl KeyValueDatabase {
];
for transformation in underride_rule_transformation {
let rule = rules_list.underride.get(transformation[0]);
let rule =
rules_list.underride.get(transformation[0]);
if let Some(rule) = rule {
let mut rule = rule.clone();
rule.rule_id = transformation[1].to_owned();
rules_list.underride.shift_remove(transformation[0]);
rules_list
.underride
.shift_remove(transformation[0]);
rules_list.underride.insert(rule);
}
}
@ -930,8 +1039,11 @@ impl KeyValueDatabase {
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
}
@ -940,7 +1052,8 @@ impl KeyValueDatabase {
warn!("Migration: 11 -> 12 finished");
}
// This migration can be reused as-is anytime the server-default rules are updated.
// This migration can be reused as-is anytime the server-default
// rules are updated.
if services().globals.database_version()? < 13 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(
@ -959,15 +1072,21 @@ impl KeyValueDatabase {
.get(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
)
.unwrap()
.expect("Username is invalid");
let mut account_data =
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
serde_json::from_str::<PushRulesEvent>(
raw_rules_list.get(),
)
.unwrap();
let user_default_rules = ruma::push::Ruleset::server_default(&user);
let user_default_rules =
ruma::push::Ruleset::server_default(&user);
account_data
.content
.global
@ -976,8 +1095,11 @@ impl KeyValueDatabase {
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)?;
}
@ -1018,13 +1140,24 @@ impl KeyValueDatabase {
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!");
services().admin.send_message(RoomMessageEventContent::text_plain("The Grapevine account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
warn!(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin account \
recovery!"
);
services().admin.send_message(
RoomMessageEventContent::text_plain(
"The Grapevine account emergency password is set! \
Please unset it as soon as you finish admin \
account recovery!",
),
);
}
}
Err(e) => {
error!(
"Could not set the configured emergency password for the grapevine user: {}",
"Could not set the configured emergency password for the \
grapevine user: {}",
e
);
}
@ -1050,15 +1183,15 @@ impl KeyValueDatabase {
#[tracing::instrument]
pub(crate) async fn start_cleanup_task() {
use tokio::time::interval;
use std::time::{Duration, Instant};
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tokio::time::interval;
use std::time::{Duration, Instant};
let timer_interval =
Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
let timer_interval = Duration::from_secs(u64::from(
services().globals.config.cleanup_second_interval,
));
tokio::spawn(async move {
let mut i = interval(timer_interval);
@ -1092,11 +1225,14 @@ impl KeyValueDatabase {
}
}
/// Sets the emergency password and push rules for the @grapevine account in case emergency password is set
/// Sets the emergency password and push rules for the @grapevine account in
/// case emergency password is set
fn set_emergency_access() -> Result<bool> {
let grapevine_user =
UserId::parse_with_server_name("grapevine", services().globals.server_name())
.expect("@grapevine:server_name is a valid UserId");
let grapevine_user = UserId::parse_with_server_name(
"grapevine",
services().globals.server_name(),
)
.expect("@grapevine:server_name is a valid UserId");
services().users.set_password(
&grapevine_user,
@ -1113,7 +1249,9 @@ fn set_emergency_access() -> Result<bool> {
&grapevine_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent { global: ruleset },
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;

View file

@ -1,8 +1,8 @@
use std::{future::Future, pin::Pin, sync::Arc};
use super::Config;
use crate::Result;
use std::{future::Future, pin::Pin, sync::Arc};
#[cfg(feature = "sqlite")]
pub(crate) mod sqlite;
@ -22,7 +22,8 @@ pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
Ok(())
}
fn memory_usage(&self) -> Result<String> {
Ok("Current database engine does not support memory usage reporting.".to_owned())
Ok("Current database engine does not support memory usage reporting."
.to_owned())
}
fn clear_caches(&self) {}
}
@ -31,7 +32,10 @@ pub(crate) trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()>;
fn remove(&self, key: &[u8]) -> Result<()>;
@ -44,14 +48,20 @@ pub(crate) trait KvTree: Send + Sync {
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()>;
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn clear(&self) -> Result<()> {
for (key, _) in self.iter() {

View file

@ -1,17 +1,21 @@
use rocksdb::{
perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache,
ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType, DBRecoveryMode, DBWithThreadMode,
Direction, IteratorMode, MultiThreaded, Options, ReadOptions, WriteOptions,
};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{utils, Result};
use std::{
future::Future,
pin::Pin,
sync::{Arc, RwLock},
};
use rocksdb::{
perf::get_memory_usage_stats, BlockBasedOptions, BoundColumnFamily, Cache,
ColumnFamilyDescriptor, DBCompactionStyle, DBCompressionType,
DBRecoveryMode, DBWithThreadMode, Direction, IteratorMode, MultiThreaded,
Options, ReadOptions, WriteOptions,
};
use super::{
super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree,
};
use crate::{utils, Result};
pub(crate) struct Engine {
rocks: DBWithThreadMode<MultiThreaded>,
max_open_files: i32,
@ -38,7 +42,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &Cache) -> Options {
let mut db_opts = Options::default();
db_opts.set_block_based_table_factory(&block_based_options);
db_opts.create_if_missing(true);
db_opts.increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX));
db_opts
.increase_parallelism(num_cpus::get().try_into().unwrap_or(i32::MAX));
db_opts.set_max_open_files(max_open_files);
db_opts.set_compression_type(DBCompressionType::Lz4);
db_opts.set_bottommost_compression_type(DBCompressionType::Zstd);
@ -69,13 +74,17 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)]
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let cache_capacity_bytes =
(config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let rocksdb_cache = Cache::new_lru_cache(cache_capacity_bytes);
let db_opts = db_options(config.rocksdb_max_open_files, &rocksdb_cache);
let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(&db_opts, &config.database_path)
.unwrap_or_default();
let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(
&db_opts,
&config.database_path,
)
.unwrap_or_default();
let db = DBWithThreadMode::<MultiThreaded>::open_cf_descriptors(
&db_opts,
@ -119,14 +128,14 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(clippy::as_conversions, clippy::cast_precision_loss)]
fn memory_usage(&self) -> Result<String> {
let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
let stats =
get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
Ok(format!(
"Approximate memory usage of all the mem-tables: {:.3} MB\n\
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\
Approximate memory usage of all the table readers: {:.3} MB\n\
Approximate memory usage by cache: {:.3} MB\n\
Approximate memory usage by cache pinned: {:.3} MB\n\
",
"Approximate memory usage of all the mem-tables: {:.3} \
MB\nApproximate memory usage of un-flushed mem-tables: {:.3} \
MB\nApproximate memory usage of all the table readers: {:.3} \
MB\nApproximate memory usage by cache: {:.3} MB\nApproximate \
memory usage by cache pinned: {:.3} MB\n",
stats.mem_table_total as f64 / 1024.0 / 1024.0,
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
@ -154,9 +163,7 @@ impl KvTree for RocksDbEngineTree<'_> {
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default();
let lock = self.write_lock.read().unwrap();
self.db
.rocks
.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
drop(lock);
self.watchers.wake(key);
@ -164,12 +171,13 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(())
}
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()> {
let writeoptions = WriteOptions::default();
for (key, value) in iter {
self.db
.rocks
.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
}
Ok(())
@ -177,10 +185,7 @@ impl KvTree for RocksDbEngineTree<'_> {
fn remove(&self, key: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default();
Ok(self
.db
.rocks
.delete_cf_opt(&self.cf(), key, &writeoptions)?)
Ok(self.db.rocks.delete_cf_opt(&self.cf(), key, &writeoptions)?)
}
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
@ -230,26 +235,26 @@ impl KvTree for RocksDbEngineTree<'_> {
let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?;
let new = utils::increment(old.as_deref());
self.db
.rocks
.put_cf_opt(&self.cf(), key, &new, &writeoptions)?;
self.db.rocks.put_cf_opt(&self.cf(), key, &new, &writeoptions)?;
drop(lock);
Ok(new)
}
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()> {
let readoptions = ReadOptions::default();
let writeoptions = WriteOptions::default();
let lock = self.write_lock.write().unwrap();
for key in iter {
let old = self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?;
let old =
self.db.rocks.get_cf_opt(&self.cf(), &key, &readoptions)?;
let new = utils::increment(old.as_deref());
self.db
.rocks
.put_cf_opt(&self.cf(), key, new, &writeoptions)?;
self.db.rocks.put_cf_opt(&self.cf(), key, new, &writeoptions)?;
}
drop(lock);
@ -277,7 +282,10 @@ impl KvTree for RocksDbEngineTree<'_> {
)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix)
}
}

View file

@ -1,7 +1,3 @@
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use std::{
cell::RefCell,
future::Future,
@ -9,9 +5,15 @@ use std::{
pin::Pin,
sync::Arc,
};
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use thread_local::ThreadLocal;
use tracing::debug;
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
@ -68,7 +70,11 @@ impl Engine {
conn.pragma_update(Some(Main), "page_size", 2048)?;
conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
conn.pragma_update(
Some(Main),
"cache_size",
-i64::from(cache_size_kb),
)?;
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
Ok(conn)
@ -79,18 +85,23 @@ impl Engine {
}
fn read_lock(&self) -> &Connection {
self.read_conn_tls
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
self.read_conn_tls.get_or(|| {
Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()
})
}
fn read_lock_iterator(&self) -> &Connection {
self.read_iterator_conn_tls
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
self.read_iterator_conn_tls.get_or(|| {
Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()
})
}
pub(crate) fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock()
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
self.write_lock().pragma_update(
Some(Main),
"wal_checkpoint",
"RESTART",
)?;
Ok(())
}
}
@ -108,7 +119,8 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
// calculates cache-size per permanent connection
// 1. convert MB to KiB
// 2. divide by permanent connections + permanent iter connections + write connection
// 2. divide by permanent connections + permanent iter connections +
// write connection
// 3. round down to nearest integer
#[allow(
clippy::as_conversions,
@ -117,9 +129,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_sign_loss
)]
let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0)
/ ((num_cpus::get() as f64 * 2.0) + 1.0)) as u32;
/ ((num_cpus::get() as f64 * 2.0) + 1.0))
as u32;
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
let writer =
Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
let arc = Arc::new(Engine {
writer,
@ -133,7 +147,13 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
}
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?;
self.write_lock().execute(
&format!(
"CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY \
KEY, \"value\" BLOB NOT NULL )"
),
[],
)?;
Ok(Arc::new(SqliteTable {
engine: Arc::clone(self),
@ -161,14 +181,26 @@ pub(crate) struct SqliteTable {
type TupleOfBytes = (Vec<u8>, Vec<u8>);
impl SqliteTable {
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
fn get_with_guard(
&self,
guard: &Connection,
key: &[u8],
) -> Result<Option<Vec<u8>>> {
Ok(guard
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
.prepare(
format!("SELECT value FROM {} WHERE key = ?", self.name)
.as_str(),
)?
.query_row([key], |row| row.get(0))
.optional()?)
}
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
fn insert_with_guard(
&self,
guard: &Connection,
key: &[u8],
value: &[u8],
) -> Result<()> {
guard.execute(
format!(
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
@ -222,7 +254,10 @@ impl KvTree for SqliteTable {
Ok(())
}
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
) -> Result<()> {
let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?;
@ -236,7 +271,10 @@ impl KvTree for SqliteTable {
Ok(())
}
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
) -> Result<()> {
let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?;
@ -282,7 +320,8 @@ impl KvTree for SqliteTable {
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
"SELECT key, value FROM {} WHERE key <= ? ORDER BY \
key DESC",
&self.name
))
.unwrap(),
@ -292,7 +331,9 @@ impl KvTree for SqliteTable {
let iterator = Box::new(
statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.query_map([from], |row| {
Ok((row.get_unwrap(0), row.get_unwrap(1)))
})
.unwrap()
.map(Result::unwrap),
);
@ -304,7 +345,8 @@ impl KvTree for SqliteTable {
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
"SELECT key, value FROM {} WHERE key >= ? ORDER BY \
key ASC",
&self.name
))
.unwrap(),
@ -314,7 +356,9 @@ impl KvTree for SqliteTable {
let iterator = Box::new(
statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.query_map([from], |row| {
Ok((row.get_unwrap(0), row.get_unwrap(1)))
})
.unwrap()
.map(Result::unwrap),
);
@ -338,14 +382,20 @@ impl KvTree for SqliteTable {
Ok(new)
}
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
Box::new(
self.iter_from(&prefix, false)
.take_while(move |(key, _)| key.starts_with(&prefix)),
)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
fn watch_prefix<'a>(
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix)
}

View file

@ -4,12 +4,14 @@ use std::{
pin::Pin,
sync::RwLock,
};
use tokio::sync::watch;
#[derive(Default)]
pub(super) struct Watchers {
#[allow(clippy::type_complexity)]
watchers: RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>,
watchers:
RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>,
}
impl Watchers {
@ -17,7 +19,8 @@ impl Watchers {
&'a self,
prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec())
{
hash_map::Entry::Occupied(o) => o.get().1.clone(),
hash_map::Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(());
@ -31,6 +34,7 @@ impl Watchers {
rx.changed().await.unwrap();
})
}
pub(super) fn wake(&self, key: &[u8]) {
let watchers = self.watchers.read().unwrap();
let mut triggered = Vec::new();

View file

@ -7,10 +7,13 @@ use ruma::{
RoomId, UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry.
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
&self,
@ -24,13 +27,14 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xff);
roomuserdataid
.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix;
@ -45,13 +49,13 @@ impl service::account_data::Data for KeyValueDatabase {
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
&serde_json::to_vec(&data)
.expect("to_vec always works on json values"),
)?;
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
@ -74,17 +78,15 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
self.roomuserdataid_accountdata.get(&roomuserdataid).transpose()
})
.transpose()?
.map(|data| {
@ -101,7 +103,8 @@ impl service::account_data::Data for KeyValueDatabase {
room_id: Option<&RoomId>,
user_id: &UserId,
since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>
{
let mut userdata = HashMap::new();
let mut prefix = room_id
@ -109,9 +112,9 @@ impl service::account_data::Data for KeyValueDatabase {
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
@ -124,14 +127,27 @@ impl service::account_data::Data for KeyValueDatabase {
.map(|(k, v)| {
Ok::<_, Error>((
RoomAccountDataEventType::from(
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|| Error::bad_database("RoomUserData ID in db is invalid."),
)?)
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
utils::string_from_bytes(
k.rsplit(|&b| b == 0xFF).next().ok_or_else(
|| {
Error::bad_database(
"RoomUserData ID in db is invalid.",
)
},
)?,
)
.map_err(|_| {
Error::bad_database(
"RoomUserData ID in db is invalid.",
)
})?,
),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
Error::bad_database("Database contains invalid account data.")
})?,
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
.map_err(|_| {
Error::bad_database(
"Database contains invalid account data.",
)
})?,
))
})
{

View file

@ -20,8 +20,7 @@ impl service::appservice::Data for KeyValueDatabase {
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
Ok(())
}
@ -30,20 +29,25 @@ impl service::appservice::Data for KeyValueDatabase {
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes).map_err(|_| {
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
Error::bad_database(
"Invalid registration bytes in \
id_appserviceregistrations.",
)
})
})
.transpose()
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|(id, _)| {
utils::string_from_bytes(&id).map_err(|_| {
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
})
},
)))
fn iter_ids<'a>(
&'a self,
) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id).map_err(|_| {
Error::bad_database(
"Invalid id bytes in id_appserviceregistrations.",
)
})
})))
}
fn all(&self) -> Result<Vec<(String, Registration)>> {

View file

@ -6,10 +6,13 @@ use lru_cache::LruCache;
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName,
UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
pub(crate) const COUNTER: &[u8] = b"c";
@ -27,14 +30,18 @@ impl service::globals::Data for KeyValueDatabase {
})
}
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
async fn watch(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xff);
userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xff);
userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new();
@ -46,10 +53,10 @@ impl service::globals::Data for KeyValueDatabase {
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
self.userroomid_notificationcount.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
futures
.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
@ -70,17 +77,24 @@ impl service::globals::Data for KeyValueDatabase {
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xff);
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.edus.typing.wait_for_update(&room_id).await;
let _result = services()
.rooms
.edus
.typing
.wait_for_update(&room_id)
.await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
futures.push(
self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix),
);
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
@ -90,12 +104,11 @@ impl service::globals::Data for KeyValueDatabase {
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xff];
let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
@ -107,7 +120,8 @@ impl service::globals::Data for KeyValueDatabase {
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures
.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch()));
@ -126,10 +140,14 @@ impl service::globals::Data for KeyValueDatabase {
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let statekeyshort_cache =
self.statekeyshort_cache.lock().unwrap().len();
let our_real_users_cache =
self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache =
self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache =
self.lasttimelinecount_cache.lock().unwrap().len();
let mut response = format!(
"\
@ -194,27 +212,29 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|s| Ok(s.clone()),
)?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes(
// 1. version
parts
.next()
.expect("splitn always returns at least one element"),
parts.next().expect("splitn always returns at least one element"),
)
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| {
// 2. key
parts
.next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.ok_or_else(|| {
Error::bad_database("Invalid keypair format in database.")
})
.map(|key| (version, key))
})
.and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
Ed25519KeyPair::from_der(key, version).map_err(|_| {
Error::bad_database("Private or public keys are invalid.")
})
})
}
fn remove_keypair(&self) -> Result<()> {
self.global.remove(b"keypair")
}
@ -231,7 +251,10 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
ServerSigningKeys::new(
origin.to_owned(),
MilliSecondsSinceUnixEpoch::now(),
)
});
let ServerSigningKeys {
@ -245,7 +268,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
&serde_json::to_vec(&keys)
.expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys;
@ -258,7 +282,8 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
Ok(tree)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(
&self,
origin: &ServerName,
@ -283,8 +308,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version)
.map_err(|_| Error::bad_database("Database version id is invalid."))
utils::u64_from_bytes(&version).map_err(|_| {
Error::bad_database("Database version id is invalid.")
})
})
}

View file

@ -9,7 +9,9 @@ use ruma::{
OwnedRoomId, RoomId, UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(
@ -20,12 +22,13 @@ impl service::key_backups::Data for KeyValueDatabase {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
&serde_json::to_vec(backup_metadata)
.expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
@ -34,13 +37,13 @@ impl service::key_backups::Data for KeyValueDatabase {
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
key.push(0xff);
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
@ -56,7 +59,7 @@ impl service::key_backups::Data for KeyValueDatabase {
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
@ -73,9 +76,12 @@ impl service::key_backups::Data for KeyValueDatabase {
Ok(version.to_owned())
}
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
fn get_latest_backup_version(
&self,
user_id: &UserId,
) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -85,11 +91,13 @@ impl service::key_backups::Data for KeyValueDatabase {
.next()
.map(|(key, _)| {
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
.map_err(|_| {
Error::bad_database("backupid_algorithm key is invalid.")
})
})
.transpose()
}
@ -99,7 +107,7 @@ impl service::key_backups::Data for KeyValueDatabase {
user_id: &UserId,
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -109,33 +117,42 @@ impl service::key_backups::Data for KeyValueDatabase {
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
.map_err(|_| {
Error::bad_database("backupid_algorithm key is invalid.")
})?;
Ok((
version,
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
Error::bad_database(
"Algorithm in backupid_algorithm is invalid.",
)
})?,
))
})
.transpose()
}
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
fn get_backup(
&self,
user_id: &UserId,
version: &str,
) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database(
"Algorithm in backupid_algorithm is invalid.",
)
})
})
}
fn add_key(
@ -147,7 +164,7 @@ impl service::key_backups::Data for KeyValueDatabase {
key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
@ -160,9 +177,9 @@ impl service::key_backups::Data for KeyValueDatabase {
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
@ -173,7 +190,7 @@ impl service::key_backups::Data for KeyValueDatabase {
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
@ -181,7 +198,7 @@ impl service::key_backups::Data for KeyValueDatabase {
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
@ -200,40 +217,56 @@ impl service::key_backups::Data for KeyValueDatabase {
version: &str,
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff);
for result in
self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id =
utils::string_from_bytes(parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.")
})?)
.map_err(|_| {
Error::bad_database("backupkeyid_backup session_id is invalid.")
})?;
let room_id = RoomId::parse(
utils::string_from_bytes(parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.")
})?)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
let session_id = utils::string_from_bytes(
parts.next().ok_or_else(|| {
Error::bad_database(
"backupkeyid_backup key is invalid.",
)
})?,
)
.map_err(|_| {
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
Error::bad_database(
"backupkeyid_backup session_id is invalid.",
)
})?;
let key_data = serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
let room_id = RoomId::parse(
utils::string_from_bytes(parts.next().ok_or_else(
|| {
Error::bad_database(
"backupkeyid_backup key is invalid.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"backupkeyid_backup room_id is invalid.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"backupkeyid_backup room_id is invalid room id.",
)
})?;
let key_data =
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
})?;
Ok::<_, Error>((room_id, session_id, key_data))
})
{
@ -257,30 +290,38 @@ impl service::key_backups::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff);
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id =
utils::string_from_bytes(parts.next().ok_or_else(|| {
Error::bad_database("backupkeyid_backup key is invalid.")
})?)
.map_err(|_| {
Error::bad_database("backupkeyid_backup session_id is invalid.")
})?;
let key_data = serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
let session_id = utils::string_from_bytes(
parts.next().ok_or_else(|| {
Error::bad_database(
"backupkeyid_backup key is invalid.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"backupkeyid_backup session_id is invalid.",
)
})?;
let key_data =
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
})?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
@ -295,18 +336,20 @@ impl service::key_backups::Data for KeyValueDatabase {
session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
Error::bad_database(
"KeyBackupData in backupkeyid_backup is invalid.",
)
})
})
.transpose()
@ -314,9 +357,9 @@ impl service::key_backups::Data for KeyValueDatabase {
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xff);
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
@ -325,13 +368,18 @@ impl service::key_backups::Data for KeyValueDatabase {
Ok(())
}
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
fn delete_room_keys(
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
@ -348,11 +396,11 @@ impl service::key_backups::Data for KeyValueDatabase {
session_id: &str,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {

View file

@ -12,22 +12,19 @@ impl service::media::Data for KeyValueDatabase {
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
@ -42,24 +39,25 @@ impl service::media::Data for KeyValueDatabase {
height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xff);
prefix.push(0xFF);
let (key, _) = self
.mediaid_file
.scan_prefix(prefix)
.next()
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let (key, _) =
self.mediaid_file.scan_prefix(prefix).next().ok_or(
Error::BadRequest(ErrorKind::NotFound, "Media not found"),
)?;
let mut parts = key.rsplit(|&b| b == 0xff);
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
.map(|bytes| {
utils::string_from_bytes(bytes).map_err(|_| {
Error::bad_database("Content type in mediaid_file is invalid unicode.")
Error::bad_database(
"Content type in mediaid_file is invalid unicode.",
)
})
})
.transpose()?;
@ -71,11 +69,14 @@ impl service::media::Data for KeyValueDatabase {
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
})?,
)
Some(utils::string_from_bytes(content_disposition_bytes).map_err(
|_| {
Error::bad_database(
"Content Disposition in mediaid_file is invalid \
unicode.",
)
},
)?)
};
Ok((content_disposition, content_type, key))
}

View file

@ -6,30 +6,39 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
fn set_pusher(
&self,
sender: &UserId,
pusher: set_pusher::v3::PusherAction,
) -> Result<()> {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher.insert(
&key,
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
&serde_json::to_vec(&pusher)
.expect("Pusher is valid JSON value"),
)?;
Ok(())
}
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into)
}
}
}
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xff);
senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher
@ -43,7 +52,7 @@ impl service::pusher::Data for KeyValueDatabase {
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
self.senderkey_pusher
.scan_prefix(prefix)
@ -59,16 +68,20 @@ impl service::pusher::Data for KeyValueDatabase {
sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xff);
let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next();
let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
let push_key = parts.next().ok_or_else(|| {
Error::bad_database("Invalid senderkey_pusher in db")
})?;
let push_key_string =
utils::string_from_bytes(push_key).map_err(|_| {
Error::bad_database(
"Invalid pusher bytes in senderkey_pusher",
)
})?;
Ok(push_key_string)
}))

View file

@ -1,6 +1,11 @@
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use ruma::{
api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId,
RoomId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::alias::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
@ -8,17 +13,20 @@ impl service::rooms::alias::Data for KeyValueDatabase {
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xff);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
aliasid.push(0xFF);
aliasid
.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
#[tracing::instrument(skip(self))]
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
if let Some(room_id) =
self.alias_roomid.get(alias.alias().as_bytes())?
{
let mut prefix = room_id.clone();
prefix.push(0xff);
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
@ -34,14 +42,23 @@ impl service::rooms::alias::Data for KeyValueDatabase {
}
#[tracing::instrument(skip(self))]
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
fn resolve_local_alias(
&self,
alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
RoomId::parse(utils::string_from_bytes(&bytes).map_err(
|_| {
Error::bad_database(
"Room ID in alias_roomid is invalid unicode.",
)
},
)?)
.map_err(|_| {
Error::bad_database("Room ID in alias_roomid is invalid.")
})
})
.transpose()
}
@ -52,13 +69,17 @@ impl service::rooms::alias::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.map_err(|_| {
Error::bad_database("Invalid alias bytes in aliasid_alias.")
})?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
.map_err(|_| {
Error::bad_database("Invalid alias in aliasid_alias.")
})
}))
}
}

View file

@ -3,9 +3,13 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{database::KeyValueDatabase, service, utils, Result};
impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
fn get_cached_eventid_authchain(
&self,
key: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key)
{
return Ok(Some(Arc::clone(result)));
}
@ -18,7 +22,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.map(|chunk| {
utils::u64_from_bytes(chunk)
.expect("byte length is correct")
})
.collect()
});
@ -38,7 +45,11 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
Ok(None)
}
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
fn cache_auth_chain(
&self,
key: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
@ -51,10 +62,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
}
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
Ok(())
}

View file

@ -19,14 +19,18 @@ impl service::rooms::directory::Data for KeyValueDatabase {
}
#[tracing::instrument(skip(self))]
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
fn public_rooms<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Room ID in publicroomids is invalid unicode.",
)
})?)
.map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid.")
})
}))
}
}

View file

@ -1,10 +1,13 @@
use std::mem;
use ruma::{
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject,
OwnedUserId, RoomId, UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(
@ -14,7 +17,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
event: ReceiptEvent,
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -25,7 +28,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
@ -36,13 +39,15 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xff);
room_latest_id
.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
&serde_json::to_vec(&event)
.expect("EduEvent::to_string always works"),
)?;
Ok(())
@ -64,7 +69,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
> + 'a,
> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
@ -79,21 +84,35 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
let count = utils::u64_from_bytes(
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
)
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
.map_err(|_| {
Error::bad_database(
"Invalid readreceiptid count in db.",
)
})?;
let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
.map_err(|_| {
Error::bad_database("Invalid readreceiptid userid bytes in db.")
})?,
utils::string_from_bytes(
&k[prefix.len() + mem::size_of::<u64>() + 1..],
)
.map_err(|_| {
Error::bad_database(
"Invalid readreceiptid userid bytes in db.",
)
})?,
)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
.map_err(|_| {
Error::bad_database(
"Invalid readreceiptid userid in db.",
)
})?;
let mut json =
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
Error::bad_database(
"Read receipt in roomlatestid_roomlatest is invalid json.",
)
})?;
serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| {
Error::bad_database(
"Read receipt in roomlatestid_roomlatest \
is invalid json.",
)
})?;
json.remove("room_id");
Ok((
@ -109,36 +128,46 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
}
#[tracing::instrument(skip(self))]
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
fn private_read_set(
&self,
room_id: &RoomId,
user_id: &UserId,
count: u64,
) -> Result<()> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
}
#[tracing::instrument(skip(self))]
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
Error::bad_database("Invalid private read marker bytes")
})?))
})
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
Error::bad_database("Invalid private read marker bytes")
})?))
})
}
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
@ -146,7 +175,9 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
Error::bad_database(
"Count in roomuserid_lastprivatereadupdate is invalid.",
)
})
})
.transpose()?

View file

@ -11,11 +11,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
@ -28,11 +28,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
@ -50,11 +50,11 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;

View file

@ -1,6 +1,8 @@
use ruma::{OwnedRoomId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::metadata::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
@ -19,14 +21,18 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
fn iter_ids<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Room ID in publicroomids is invalid unicode.",
)
})?)
.map_err(|_| {
Error::bad_database("Room ID in roomid_shortroomid is invalid.")
})
}))
}

View file

@ -3,24 +3,35 @@ use ruma::{CanonicalJsonObject, EventId};
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
fn get_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(
Ok(None),
|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
},
)
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(
Ok(None),
|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
},
)
}
#[tracing::instrument(skip(self, pdu))]
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
fn add_pdu_outlier(
&self,
event_id: &EventId,
pdu: &CanonicalJsonObject,
) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),

View file

@ -22,7 +22,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
shortroomid: u64,
target: u64,
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
@ -40,8 +41,12 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let from = utils::u64_from_bytes(
&tofrom[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database("Invalid count in tofrom_relation.")
})?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
@ -50,7 +55,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
.ok_or_else(|| {
Error::bad_database(
"Pdu in tofrom_relation is invalid.",
)
})?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
@ -59,7 +68,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
fn mark_as_referenced(
&self,
room_id: &RoomId,
event_ids: &[Arc<EventId>],
) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
@ -69,7 +82,11 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
Ok(())
}
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
@ -80,8 +97,6 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
}
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
}
}

View file

@ -4,7 +4,12 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result};
impl service::rooms::search::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
fn index_pdu(
&self,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
@ -13,7 +18,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xff);
key.push(0xFF);
// TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id);
(key, Vec::new())
@ -28,7 +33,8 @@ impl service::rooms::search::Data for KeyValueDatabase {
&'a self,
room_id: &RoomId,
search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>> {
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>
{
let prefix = services()
.rooms
.short
@ -46,7 +52,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xff);
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
@ -60,7 +66,9 @@ impl service::rooms::search::Data for KeyValueDatabase {
});
// We compare b with a because we reversed the iterator earlier
let Some(common_elements) = utils::common_elements(iterators, |a, b| b.cmp(a)) else {
let Some(common_elements) =
utils::common_elements(iterators, |a, b| b.cmp(a))
else {
return Ok(None);
};

View file

@ -2,26 +2,32 @@ use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
if let Some(short) =
self.eventidshort_cache.lock().unwrap().get_mut(event_id)
{
return Ok(*short);
}
let short =
if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
let short = if let Some(shorteventid) =
self.eventid_shorteventid.get(event_id.as_bytes())?
{
utils::u64_from_bytes(&shorteventid).map_err(|_| {
Error::bad_database("Invalid shorteventid in db.")
})?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
self.eventidshort_cache
.lock()
@ -46,15 +52,16 @@ impl service::rooms::short::Data for KeyValueDatabase {
}
let mut db_key = event_type.to_string().as_bytes().to_vec();
db_key.push(0xff);
db_key.push(0xFF);
db_key.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&db_key)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
utils::u64_from_bytes(&shortstatekey).map_err(|_| {
Error::bad_database("Invalid shortstatekey in db.")
})
})
.transpose()?;
@ -83,12 +90,15 @@ impl service::rooms::short::Data for KeyValueDatabase {
}
let mut db_key = event_type.to_string().as_bytes().to_vec();
db_key.push(0xff);
db_key.push(0xFF);
db_key.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&db_key)? {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
let short = if let Some(shortstatekey) =
self.statekey_shortstatekey.get(&db_key)?
{
utils::u64_from_bytes(&shortstatekey).map_err(|_| {
Error::bad_database("Invalid shortstatekey in db.")
})?
} else {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
@ -106,12 +116,12 @@ impl service::rooms::short::Data for KeyValueDatabase {
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
if let Some(id) = self
.shorteventid_cache
.lock()
.unwrap()
.get_mut(&shorteventid)
fn get_eventid_from_short(
&self,
shorteventid: u64,
) -> Result<Arc<EventId>> {
if let Some(id) =
self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid)
{
return Ok(Arc::clone(id));
}
@ -119,12 +129,20 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
.ok_or_else(|| {
Error::bad_database("Shorteventid does not exist")
})?;
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"EventID in shorteventid_eventid is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database("EventId in shorteventid_eventid is invalid.")
})?;
self.shorteventid_cache
.lock()
@ -134,12 +152,12 @@ impl service::rooms::short::Data for KeyValueDatabase {
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
fn get_statekey_from_short(
&self,
shortstatekey: u64,
) -> Result<(StateEventType, String)> {
if let Some(id) =
self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey)
{
return Ok(id.clone());
}
@ -147,23 +165,32 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
.ok_or_else(|| {
Error::bad_database("Shortstatekey does not exist")
})?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type =
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes =
parts.next().expect("split always returns one entry");
let statekey_bytes = parts.next().ok_or_else(|| {
Error::bad_database("Invalid statekey in shortstatekey_statekey.")
})?;
let event_type = StateEventType::from(
utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database(
"Event type in shortstatekey_statekey is invalid unicode.",
)
})?,
);
let state_key =
utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database(
"Statekey in shortstatekey_statekey is invalid unicode.",
)
})?;
let result = (event_type, state_key);
self.shortstatekey_cache
@ -175,12 +202,18 @@ impl service::rooms::short::Data for KeyValueDatabase {
}
/// Returns `(shortstatehash, already_existed)`
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)> {
Ok(
if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
if let Some(shortstatehash) =
self.statehash_shortstatehash.get(state_hash)?
{
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
utils::u64_from_bytes(&shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in db.")
})?,
true,
)
} else {
@ -196,17 +229,21 @@ impl service::rooms::short::Data for KeyValueDatabase {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})
})
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(
if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
if let Some(short) =
self.roomid_shortroomid.get(room_id.as_bytes())?
{
utils::u64_from_bytes(&short).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid

View file

@ -1,20 +1,22 @@
use ruma::{EventId, OwnedEventId, RoomId};
use std::collections::HashSet;
use std::{collections::HashSet, sync::Arc};
use std::sync::Arc;
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(
Ok(None),
|bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
Error::bad_database(
"Invalid shortstatehash in roomid_shortstatehash",
)
})?))
})
},
)
}
fn set_room_state(
@ -29,23 +31,40 @@ impl service::rooms::state::Data for KeyValueDatabase {
Ok(())
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
fn set_event_state(
&self,
shorteventid: u64,
shortstatehash: u64,
) -> Result<()> {
self.shorteventid_shortstatehash.insert(
&shorteventid.to_be_bytes(),
&shortstatehash.to_be_bytes(),
)?;
Ok(())
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
fn get_forward_extremities(
&self,
room_id: &RoomId,
) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(
|_| {
Error::bad_database(
"EventID in roomid_pduleaves is invalid unicode.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"EventId in roomid_pduleaves is invalid.",
)
})
})
.collect()
}
@ -58,7 +77,7 @@ impl service::rooms::state::Data for KeyValueDatabase {
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;

View file

@ -1,12 +1,19 @@
use std::{collections::HashMap, sync::Arc};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
#[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
async fn state_full_ids(
&self,
shortstatehash: u64,
) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
@ -56,7 +63,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.ok_or_else(|| {
Error::bad_database(
"State event has no state key.",
)
})?
.clone(),
),
pdu,
@ -72,17 +83,16 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
let Some(shortstatekey) =
services().rooms.short.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
@ -106,7 +116,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self,
shortstatehash: u64,
@ -121,20 +132,22 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(
Ok(None),
|shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
"Invalid shortstatehash bytes in \
shorteventid_shortstatehash",
)
})
})
.transpose()
})
},
)
}
/// Returns the full room state.
@ -151,7 +164,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self,
room_id: &RoomId,
@ -167,7 +181,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self,
room_id: &RoomId,

View file

@ -13,20 +13,24 @@ use crate::{
};
impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
fn mark_as_once_joined(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
}
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
@ -46,11 +50,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
@ -72,11 +76,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
// TODO
@ -112,7 +116,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
joinedcount += 1;
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
for _invited in
self.room_members_invited(room_id).filter_map(Result::ok)
{
invitedcount += 1;
}
@ -127,15 +133,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
for old_joined_server in
self.room_servers(room_id).filter_map(Result::ok)
{
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff);
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xff);
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
@ -146,49 +154,45 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff);
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xff);
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
self.appservice_in_room_cache.write().unwrap().remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
fn get_our_real_users(
&self,
room_id: &RoomId,
) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe =
self.our_real_users_cache.read().unwrap().get(room_id).cloned();
if let Some(users) = maybe {
Ok(users)
} else {
self.update_joined_count(room_id)?;
Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
self.our_real_users_cache.read().unwrap().get(room_id).unwrap(),
))
}
}
#[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
fn appservice_in_room(
&self,
room_id: &RoomId,
appservice: &RegistrationInfo,
) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
@ -206,11 +210,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
)
.ok();
let in_room = bridge_user_id
.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self.room_members(room_id).any(|userid| {
userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))
});
let in_room = bridge_user_id.map_or(false, |id| {
self.is_joined(&id, room_id).unwrap_or(false)
}) || self.room_members(room_id).any(|userid| {
userid.map_or(false, |userid| {
appservice.users.is_match(userid.as_str())
})
});
self.appservice_in_room_cache
.write()
@ -227,11 +233,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
@ -247,51 +253,66 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Server name in roomserverids is invalid unicode.")
Error::bad_database(
"Server name in roomserverids is invalid unicode.",
)
})?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
.map_err(|_| {
Error::bad_database("Server name in roomserverids is invalid.")
})
}))
}
#[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
fn server_in_room(
&self,
server: &ServerName,
room_id: &RoomId,
) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we know).
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))]
fn server_rooms<'a>(
&'a self,
server: &ServerName,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
.map_err(|_| {
Error::bad_database(
"RoomId in serverroomids is invalid unicode.",
)
})?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
.map_err(|_| {
Error::bad_database("RoomId in serverroomids is invalid.")
})
}))
}
@ -302,20 +323,24 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
Error::bad_database(
"User ID in roomuserid_joined is invalid unicode.",
)
})?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
.map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid.")
})
}))
}
@ -324,8 +349,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| {
utils::u64_from_bytes(&b)
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
utils::u64_from_bytes(&b).map_err(|_| {
Error::bad_database("Invalid joinedcount in db.")
})
})
.transpose()
}
@ -335,8 +361,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| {
utils::u64_from_bytes(&b)
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
utils::u64_from_bytes(&b).map_err(|_| {
Error::bad_database("Invalid joinedcount in db.")
})
})
.transpose()
}
@ -348,27 +375,30 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database(
"User ID in room_useroncejoined is invalid unicode.",
)
})?,
Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(
|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
.map_err(|_| {
Error::bad_database(
"User ID in room_useroncejoined is invalid \
unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"User ID in room_useroncejoined is invalid.",
)
})
},
))
}
/// Returns an iterator over all invited members of a room.
@ -378,53 +408,64 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
})?,
Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(
|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
.map_err(|_| {
Error::bad_database(
"User ID in roomuserid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"User ID in roomuserid_invited is invalid.",
)
})
},
))
}
#[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
fn get_invite_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid invitecount in db.")
})?))
})
self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid invitecount in db.")
})?))
})
}
#[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
fn get_left_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid leftcount in db."))
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid leftcount in db.")
})
})
.transpose()
}
@ -441,15 +482,22 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
Error::bad_database(
"Room ID in userroomid_joined is invalid \
unicode.",
)
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
.map_err(|_| {
Error::bad_database(
"Room ID in userroomid_joined is invalid.",
)
})
}),
)
}
@ -460,35 +508,43 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn rooms_invited<'a>(
&'a self,
user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a> {
) -> Box<
dyn Iterator<
Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>,
> + 'a,
> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(
|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid.")
})?;
Error::bad_database(
"Room ID in userroomid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"Room ID in userroomid_invited is invalid.",
)
})?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database("Invalid state in userroomid_invitestate.")
})?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database(
"Invalid state in userroomid_invitestate.",
)
})?;
Ok((room_id, state))
}),
)
Ok((room_id, state))
},
))
}
#[tracing::instrument(skip(self))]
@ -498,14 +554,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database(
"Invalid state in userroomid_invitestate.",
)
})?;
Ok(state)
})
@ -519,14 +578,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database(
"Invalid state in userroomid_leftstate.",
)
})?;
Ok(state)
})
@ -539,41 +601,48 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
fn rooms_left<'a>(
&'a self,
user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a> {
) -> Box<
dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>>
+ 'a,
> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(
|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid.")
})?;
Error::bad_database(
"Room ID in userroomid_invited is invalid unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database(
"Room ID in userroomid_invited is invalid.",
)
})?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database("Invalid state in userroomid_leftstate.")
})?;
let state = serde_json::from_slice(&state).map_err(|_| {
Error::bad_database(
"Invalid state in userroomid_leftstate.",
)
})?;
Ok((room_id, state))
}),
)
Ok((room_id, state))
},
))
}
#[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
@ -582,7 +651,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
@ -591,7 +660,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
@ -600,7 +669,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())

View file

@ -12,8 +12,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent =
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.expect("bytes have right length");
let parent = (parent != 0).then_some(parent);
let mut add_mode = true;
@ -30,7 +30,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
removed
.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
@ -42,7 +43,11 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
})
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
fn save_statediff(
&self,
shortstatehash: u64,
diff: StateDiff,
) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);

View file

@ -1,8 +1,14 @@
use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use ruma::{
api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId,
UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
impl service::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>(
@ -28,14 +34,22 @@ impl service::rooms::threads::Data for KeyValueDatabase {
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let count = utils::u64_from_bytes(
&pduid[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database(
"Invalid pduid in threadid_userids.",
)
})?;
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| {
Error::bad_database("Invalid pduid reference in threadid_userids")
Error::bad_database(
"Invalid pduid reference in threadid_userids",
)
})?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
@ -45,28 +59,43 @@ impl service::rooms::threads::Data for KeyValueDatabase {
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
fn update_participants(
&self,
root_id: &[u8],
participants: &[OwnedUserId],
) -> Result<()> {
let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xff][..]);
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
fn get_participants(
&self,
root_id: &[u8],
) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xff)
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| {
Error::bad_database("Invalid UserId bytes in threadid_userids.")
})?)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
UserId::parse(utils::string_from_bytes(bytes).map_err(
|_| {
Error::bad_database(
"Invalid UserId bytes in threadid_userids.",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"Invalid UserId in threadid_userids.",
)
})
})
.filter_map(Result::ok)
.collect(),

View file

@ -1,16 +1,23 @@
use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId,
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId,
RoomId, UserId,
};
use service::rooms::timeline::PduCount;
use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use service::rooms::timeline::PduCount;
use crate::{
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
fn last_timeline_count(
&self,
sender_user: &UserId,
room_id: &RoomId,
) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
@ -45,14 +52,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
fn get_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
serde_json::from_slice(&pdu).map_err(|_| {
Error::bad_database("Invalid PDU in db.")
})
})
.transpose()
},
@ -61,17 +72,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
Error::bad_database("Invalid pduid in eventid_pduid.")
})
})
.transpose()?
.map(|pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
})
.transpose()
}
@ -82,17 +97,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
fn get_non_outlier_pdu(
&self,
event_id: &EventId,
) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
Error::bad_database("Invalid pduid in eventid_pduid.")
})
})
.transpose()?
.map(|pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
})
.transpose()
}
@ -112,8 +131,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
serde_json::from_slice(&pdu).map_err(|_| {
Error::bad_database("Invalid PDU in db.")
})
})
.transpose()
},
@ -144,7 +164,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
fn get_pdu_json_from_id(
&self,
pdu_id: &[u8],
) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu)
@ -162,7 +185,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
&serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?;
self.lasttimelinecount_cache
@ -184,7 +208,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
&serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
@ -203,7 +228,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
&serde_json::to_vec(pdu_json)
.expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(
@ -212,22 +238,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
));
}
self.pdu_cache
.lock()
.unwrap()
.remove(&(*pdu.event_id).to_owned());
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that happened before the
/// event with id `until` in reverse-chronological order.
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
@ -238,7 +263,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
.map_err(|_| {
Error::bad_database("PDU in db is invalid.")
})?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
@ -254,7 +281,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
user_id: &UserId,
room_id: &RoomId,
from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
{
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
@ -265,7 +293,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
.map_err(|_| {
Error::bad_database("PDU in db is invalid.")
})?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
@ -286,13 +316,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let mut highlights_batch = Vec::new();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
@ -307,10 +337,12 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let last_u64 =
utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 = utils::u64_from_bytes(
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()
..pdu_id.len() - size_of::<u64>()],
);
if matches!(second_last_u64, Ok(0)) {
@ -330,7 +362,9 @@ fn count_to_id(
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.ok_or_else(|| {
Error::bad_database("Looked for bad shortroomid in timeline")
})?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();

View file

@ -1,14 +1,20 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
use crate::{
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
fn reset_notification_counts(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
@ -24,35 +30,51 @@ impl service::rooms::user::Data for KeyValueDatabase {
Ok(())
}
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
fn notification_count(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid notification count in db."))
})
self.userroomid_notificationcount.get(&userroom_id)?.map_or(
Ok(0),
|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid notification count in db.")
})
},
)
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
fn highlight_count(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
self.userroomid_highlightcount.get(&userroom_id)?.map_or(
Ok(0),
|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid highlight count in db.")
})
},
)
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
fn last_notification_read(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
@ -60,7 +82,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
Error::bad_database(
"Count in roomuserid_lastprivatereadupdate is invalid.",
)
})
})
.transpose()?
@ -86,7 +110,11 @@ impl service::rooms::user::Data for KeyValueDatabase {
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
@ -100,7 +128,10 @@ impl service::rooms::user::Data for KeyValueDatabase {
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
Error::bad_database(
"Invalid shortstatehash in \
roomsynctoken_shortstatehash",
)
})
})
.transpose()
@ -112,7 +143,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
@ -121,8 +152,12 @@ impl service::rooms::user::Data for KeyValueDatabase {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xff)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| {
Error::bad_database(
"Invalid userroomid_joined in db.",
)
})?
.0
+ 1;
@ -133,15 +168,24 @@ impl service::rooms::user::Data for KeyValueDatabase {
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not reversed)
// We use the default compare function because keys are sorted correctly
// (not reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
RoomId::parse(utils::string_from_bytes(&bytes).map_err(
|_| {
Error::bad_database(
"Invalid RoomId bytes in userroomid_joined",
)
},
)?)
.map_err(|_| {
Error::bad_database(
"Invalid RoomId in userroomid_joined.",
)
})
}),
))
}

View file

@ -12,31 +12,34 @@ use crate::{
impl service::sending::Data for KeyValueDatabase {
fn active_requests<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
) -> Box<
dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>>
+ 'a,
> {
Box::new(self.servercurrentevent_data.iter().map(|(key, v)| {
parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))
}))
}
fn active_requests_for<'a>(
&'a self,
outgoing_kind: &OutgoingKind,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a>
{
let prefix = outgoing_kind.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
Box::new(self.servercurrentevent_data.scan_prefix(prefix).map(
|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e)),
))
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
self.servercurrentevent_data.remove(&key)
}
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
fn delete_all_active_requests_for(
&self,
outgoing_kind: &OutgoingKind,
) -> Result<()> {
let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
@ -45,9 +48,13 @@ impl service::sending::Data for KeyValueDatabase {
Ok(())
}
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
fn delete_all_requests_for(
&self,
outgoing_kind: &OutgoingKind,
) -> Result<()> {
let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone())
{
self.servercurrentevent_data.remove(&key).unwrap();
}
@ -69,7 +76,9 @@ impl service::sending::Data for KeyValueDatabase {
if let SendingEventType::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
key.extend_from_slice(
&services().globals.next_count()?.to_be_bytes(),
);
}
let value = if let SendingEventType::Edu(value) = &event {
&**value
@ -79,24 +88,25 @@ impl service::sending::Data for KeyValueDatabase {
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>(
&'a self,
outgoing_kind: &OutgoingKind,
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a>
{
let prefix = outgoing_kind.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
return Box::new(self.servernameevent_data.scan_prefix(prefix).map(
|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k)),
));
}
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
fn mark_as_active(
&self,
events: &[(SendingEventType, Vec<u8>)],
) -> Result<()> {
for (e, key) in events {
let value = if let SendingEventType::Edu(value) = &e {
&**value
@ -110,18 +120,24 @@ impl service::sending::Data for KeyValueDatabase {
Ok(())
}
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
fn set_latest_educount(
&self,
server_name: &ServerName,
last_count: u64,
) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
self.servername_educount.get(server_name.as_bytes())?.map_or(
Ok(0),
|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid u64 in servername_educount.")
})
},
)
}
}
@ -132,15 +148,17 @@ fn parse_servercurrentevent(
) -> Result<(OutgoingKind, SendingEventType)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let event = parts.next().ok_or_else(|| {
Error::bad_database("Invalid bytes in servercurrentpdus.")
})?;
let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
Error::bad_database(
"Invalid server bytes in server_currenttransaction",
)
})?;
(
@ -152,23 +170,27 @@ fn parse_servercurrentevent(
},
)
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id = UserId::parse(user_string)
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let user_string = utils::string_from_bytes(user).map_err(|_| {
Error::bad_database("Invalid user string in servercurrentevent")
})?;
let user_id = UserId::parse(user_string).map_err(|_| {
Error::bad_database("Invalid user id in servercurrentevent")
})?;
let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let pushkey = parts.next().ok_or_else(|| {
Error::bad_database("Invalid bytes in servercurrentpdus.")
})?;
let pushkey_string =
utils::string_from_bytes(pushkey).map_err(|_| {
Error::bad_database("Invalid pushkey in servercurrentevent")
})?;
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let event = parts.next().ok_or_else(|| {
Error::bad_database("Invalid bytes in servercurrentpdus.")
})?;
(
OutgoingKind::Push(user_id, pushkey_string),
@ -180,20 +202,24 @@ fn parse_servercurrentevent(
},
)
} else {
let mut parts = key.splitn(2, |&b| b == 0xff);
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let event = parts.next().ok_or_else(|| {
Error::bad_database("Invalid bytes in servercurrentpdus.")
})?;
let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
Error::bad_database(
"Invalid server bytes in server_currenttransaction",
)
})?;
(
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction")
Error::bad_database(
"Invalid server string in server_currenttransaction",
)
})?),
if value.is_empty() {
SendingEventType::Pdu(event.to_vec())

View file

@ -11,9 +11,11 @@ impl service::transaction_ids::Data for KeyValueDatabase {
data: &[u8],
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(
device_id.map(DeviceId::as_bytes).unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?;
@ -28,9 +30,11 @@ impl service::transaction_ids::Data for KeyValueDatabase {
txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(
device_id.map(DeviceId::as_bytes).unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction

View file

@ -13,13 +13,10 @@ impl service::uiaa::Data for KeyValueDatabase {
session: &str,
request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
self.userdevicesessionid_uiaarequest.write().unwrap().insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
@ -33,7 +30,11 @@ impl service::uiaa::Data for KeyValueDatabase {
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.get(&(
user_id.to_owned(),
device_id.to_owned(),
session.to_owned(),
))
.map(ToOwned::to_owned)
}
@ -45,19 +46,19 @@ impl service::uiaa::Data for KeyValueDatabase {
uiaainfo: Option<&UiaaInfo>,
) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
&serde_json::to_vec(&uiaainfo)
.expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
}
Ok(())
@ -70,9 +71,9 @@ impl service::uiaa::Data for KeyValueDatabase {
session: &str,
) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
@ -84,6 +85,8 @@ impl service::uiaa::Data for KeyValueDatabase {
"UIAA session does not exist.",
))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
.map_err(|_| {
Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")
})
}
}

View file

@ -5,8 +5,8 @@ use ruma::{
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch,
OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
@ -40,67 +40,100 @@ impl service::users::Data for KeyValueDatabase {
}
/// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xff);
fn find_from_token(
&self,
token: &str,
) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid.get(token.as_bytes())?.map_or(
Ok(None),
|bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts.next().ok_or_else(|| {
Error::bad_database("User ID in token_userdeviceid is invalid.")
Error::bad_database(
"User ID in token_userdeviceid is invalid.",
)
})?;
let device_bytes = parts.next().ok_or_else(|| {
Error::bad_database("Device ID in token_userdeviceid is invalid.")
Error::bad_database(
"Device ID in token_userdeviceid is invalid.",
)
})?;
Ok(Some((
UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid unicode.")
})?)
UserId::parse(
utils::string_from_bytes(user_bytes).map_err(|_| {
Error::bad_database(
"User ID in token_userdeviceid is invalid \
unicode.",
)
})?,
)
.map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid.")
Error::bad_database(
"User ID in token_userdeviceid is invalid.",
)
})?,
utils::string_from_bytes(device_bytes).map_err(|_| {
Error::bad_database("Device ID in token_userdeviceid is invalid.")
Error::bad_database(
"Device ID in token_userdeviceid is invalid.",
)
})?,
)))
})
},
)
}
/// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
fn iter<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.")
Error::bad_database(
"User ID in userid_password is invalid unicode.",
)
})?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
.map_err(|_| {
Error::bad_database("User ID in userid_password is invalid.")
})
}))
}
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of it's password is greater then zero.
/// A user account is considered `local` if the length of it's password is
/// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
.iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
.filter_map(|(username, pw)| {
get_username_with_valid_password(&username, &pw)
})
.collect();
Ok(users)
}
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
self.userid_password.get(user_id.as_bytes())?.map_or(
Ok(None),
|bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
Error::bad_database(
"Password hash in db is not valid string.",
)
})?))
})
},
)
}
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
fn set_password(
&self,
user_id: &UserId,
password: Option<&str>,
) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password
@ -120,17 +153,23 @@ impl service::users::Data for KeyValueDatabase {
/// Returns the `displayname` of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
self.userid_displayname.get(user_id.as_bytes())?.map_or(
Ok(None),
|bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Displayname in db is invalid.")
})?))
})
},
)
}
/// Sets a new `displayname` or removes it if `displayname` is `None`. You still need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
/// Sets a new `displayname` or removes it if `displayname` is `None`. You
/// still need to nofify all rooms of this change.
fn set_displayname(
&self,
user_id: &UserId,
displayname: Option<String>,
) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
@ -147,17 +186,25 @@ impl service::users::Data for KeyValueDatabase {
.get(user_id.as_bytes())?
.map(|bytes| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))
.map_err(|_| {
Error::bad_database("Avatar URL in db is invalid.")
})
.map(Into::into)
})
.transpose()
}
/// Sets a new `avatar_url` or removes it if `avatar_url` is `None`.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
fn set_avatar_url(
&self,
user_id: &UserId,
avatar_url: Option<OwnedMxcUri>,
) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
self.userid_avatarurl.insert(
user_id.as_bytes(),
avatar_url.to_string().as_bytes(),
)?;
} else {
self.userid_avatarurl.remove(user_id.as_bytes())?;
}
@ -170,8 +217,9 @@ impl service::users::Data for KeyValueDatabase {
self.userid_blurhash
.get(user_id.as_bytes())?
.map(|bytes| {
let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
let s = utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Avatar URL in db is invalid.")
})?;
Ok(s)
})
@ -179,7 +227,11 @@ impl service::users::Data for KeyValueDatabase {
}
/// Sets a new `avatar_url` or removes it if `avatar_url` is `None`.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
fn set_blurhash(
&self,
user_id: &UserId,
blurhash: Option<String>,
) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
@ -204,11 +256,10 @@ impl service::users::Data for KeyValueDatabase {
);
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userid_devicelistversion.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
@ -228,9 +279,13 @@ impl service::users::Data for KeyValueDatabase {
}
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
fn remove_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens
@ -241,7 +296,7 @@ impl service::users::Data for KeyValueDatabase {
// Remove todevice events
let mut prefix = userdeviceid.clone();
prefix.push(0xff);
prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?;
@ -249,8 +304,7 @@ impl service::users::Data for KeyValueDatabase {
// TODO: Remove onetimekeys
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userid_devicelistversion.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?;
@ -263,29 +317,34 @@ impl service::users::Data for KeyValueDatabase {
user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
// All devices have metadata
Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
Error::bad_database("UserDevice ID in db is invalid.")
})?,
Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(
|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes.rsplit(|&b| b == 0xFF).next().ok_or_else(|| {
Error::bad_database("UserDevice ID in db is invalid.")
})?,
)
.map_err(|_| {
Error::bad_database(
"Device ID in userdeviceid_metadata is invalid.",
)
.map_err(|_| {
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
})?
.into())
}),
)
})?
.into())
},
))
}
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
fn set_token(
&self,
user_id: &UserId,
device_id: &DeviceId,
token: &str,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
assert!(
@ -300,10 +359,8 @@ impl service::users::Data for KeyValueDatabase {
}
// Assign token to user device combination
self.userdeviceid_token
.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?;
Ok(())
}
@ -316,17 +373,19 @@ impl service::users::Data for KeyValueDatabase {
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
assert!(
self.userdeviceid_metadata.get(&key)?.is_some(),
"devices should have metadata and this method should only be called with existing devices"
"devices should have metadata and this method should only be \
called with existing devices"
);
key.push(0xff);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update
// everything, because there are no wrapping quotation marks
// anymore)
key.extend_from_slice(
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
@ -335,7 +394,8 @@ impl service::users::Data for KeyValueDatabase {
self.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
&serde_json::to_vec(&one_time_key_value)
.expect("OneTimeKey::to_vec always works"),
)?;
self.userid_lastonetimekeyupdate.insert(
@ -347,13 +407,16 @@ impl service::users::Data for KeyValueDatabase {
}
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
self.userid_lastonetimekeyupdate.get(user_id.as_bytes())?.map_or(
Ok(0),
|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
Error::bad_database(
"Count in roomid_lastroomactiveupdate is invalid.",
)
})
})
},
)
}
fn take_one_time_key(
@ -363,9 +426,9 @@ impl service::users::Data for KeyValueDatabase {
key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
// Annoying quotation mark
prefix.push(b'"');
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
@ -384,13 +447,18 @@ impl service::users::Data for KeyValueDatabase {
Ok((
serde_json::from_slice(
key.rsplit(|&b| b == 0xff)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
key.rsplit(|&b| b == 0xFF).next().ok_or_else(|| {
Error::bad_database(
"OneTimeKeyId in db is invalid.",
)
})?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
.map_err(|_| {
Error::bad_database("OneTimeKeyId in db is invalid.")
})?,
serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("OneTimeKeys in db are invalid.")
})?,
))
})
.transpose()
@ -402,25 +470,31 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new();
for algorithm in
self.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
Error::bad_database("OneTimeKey ID in db is invalid.")
})?,
)
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes.rsplit(|&b| b == 0xFF).next().ok_or_else(
|| {
Error::bad_database(
"OneTimeKey ID in db is invalid.",
)
},
)?,
)
})
.map_err(|_| {
Error::bad_database("DeviceKeyId in db is invalid.")
})?
.algorithm(),
)
})
{
*counts.entry(algorithm?).or_default() += UInt::from(1_u32);
}
@ -435,12 +509,13 @@ impl service::users::Data for KeyValueDatabase {
device_keys: &Raw<DeviceKeys>,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
&serde_json::to_vec(&device_keys)
.expect("DeviceKeys::to_vec always works"),
)?;
self.mark_device_key_update(user_id)?;
@ -458,30 +533,33 @@ impl service::users::Data for KeyValueDatabase {
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key")
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid self signing key",
)
})?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained no key.",
))?;
let self_signing_key_id =
self_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained no key.",
))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
@ -491,7 +569,8 @@ impl service::users::Data for KeyValueDatabase {
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self_signing_key_key
.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key.insert(
&self_signing_key_key,
@ -507,15 +586,19 @@ impl service::users::Data for KeyValueDatabase {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key")
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid user signing key",
)
})?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained no key.",
))?;
let user_signing_key_id =
user_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained no key.",
))?;
if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
@ -525,7 +608,8 @@ impl service::users::Data for KeyValueDatabase {
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
user_signing_key_key
.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key.insert(
&user_signing_key_key,
@ -551,32 +635,44 @@ impl service::users::Data for KeyValueDatabase {
sender_id: &UserId,
) -> Result<()> {
let mut key = target_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value =
serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest(
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
&self.keyid_key.get(&key)?.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to sign nonexistent key.",
))?)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
))?,
)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))?
.ok_or_else(|| {
Error::bad_database("key in keyid_key has no signatures field.")
})?
.as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))?
.ok_or_else(|| {
Error::bad_database(
"key in keyid_key has invalid signatures field.",
)
})?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))?
.ok_or_else(|| {
Error::bad_database(
"signatures in keyid_key for a user is invalid.",
)
})?
.insert(signature.0, signature.1.into());
self.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
&serde_json::to_vec(&cross_signing_key)
.expect("CrossSigningKey::to_vec always works"),
)?;
self.mark_device_key_update(target_id)?;
@ -591,7 +687,7 @@ impl service::users::Data for KeyValueDatabase {
to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let mut start = prefix.clone();
start.extend_from_slice(&(from + 1).to_be_bytes());
@ -603,26 +699,39 @@ impl service::users::Data for KeyValueDatabase {
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) {
&& if let Some(current) =
k.splitn(2, |&b| b == 0xFF).nth(1)
{
if let Ok(c) = utils::u64_from_bytes(current) {
c <= to
} else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
warn!(
"BadDatabase: Could not parse \
keychangeid_userid bytes"
);
false
}
} else {
warn!("BadDatabase: Could not parse keychangeid_userid");
warn!(
"BadDatabase: Could not parse \
keychangeid_userid"
);
false
}
})
.map(|(_, bytes)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"User ID in devicekeychangeid_userid is invalid unicode.",
)
})?)
UserId::parse(utils::string_from_bytes(&bytes).map_err(
|_| {
Error::bad_database(
"User ID in devicekeychangeid_userid is \
invalid unicode.",
)
},
)?)
.map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
Error::bad_database(
"User ID in devicekeychangeid_userid is invalid.",
)
})
}),
)
@ -647,14 +756,14 @@ impl service::users::Data for KeyValueDatabase {
}
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
@ -667,7 +776,7 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId,
) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
@ -683,11 +792,11 @@ impl service::users::Data for KeyValueDatabase {
master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let master_key = master_key.deserialize().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key")
})?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
@ -712,8 +821,12 @@ impl service::users::Data for KeyValueDatabase {
allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
let mut cross_signing_key = serde_json::from_slice::<
serde_json::Value,
>(&bytes)
.map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.")
})?;
clean_signatures(
&mut cross_signing_key,
sender_user,
@ -754,16 +867,20 @@ impl service::users::Data for KeyValueDatabase {
})
}
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
fn get_user_signing_key(
&self,
user_id: &UserId,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(
Ok(None),
|key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.")
})?))
})
})
},
)
}
fn add_to_device_event(
@ -775,9 +892,9 @@ impl service::users::Data for KeyValueDatabase {
content: serde_json::Value,
) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
@ -785,7 +902,8 @@ impl service::users::Data for KeyValueDatabase {
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
let value =
serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?;
@ -800,15 +918,14 @@ impl service::users::Data for KeyValueDatabase {
let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push(
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
events.push(serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("Event in todeviceid_events is invalid.")
})?);
}
Ok(events)
@ -821,9 +938,9 @@ impl service::users::Data for KeyValueDatabase {
until: u64,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
@ -836,8 +953,14 @@ impl service::users::Data for KeyValueDatabase {
.map(|(key, _)| {
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
utils::u64_from_bytes(
&key[key.len() - size_of::<u64>()..key.len()],
)
.map_err(|_| {
Error::bad_database(
"ToDeviceId has invalid count bytes.",
)
})?,
))
})
.filter_map(Result::ok)
@ -856,7 +979,7 @@ impl service::users::Data for KeyValueDatabase {
device: &Device,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
assert!(
@ -864,12 +987,12 @@ impl service::users::Data for KeyValueDatabase {
"this method should only be called with existing devices"
);
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userid_devicelistversion.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
&serde_json::to_vec(device)
.expect("Device::to_string always works"),
)?;
Ok(())
@ -882,26 +1005,32 @@ impl service::users::Data for KeyValueDatabase {
device_id: &DeviceId,
) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
self.userdeviceid_metadata.get(&userdeviceid)?.map_or(
Ok(None),
|bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
Error::bad_database(
"Metadata in userdeviceid_metadata is invalid.",
)
})?))
})
},
)
}
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(
Ok(None),
|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map_err(|_| {
Error::bad_database("Invalid devicelistversion in db.")
})
.map(Some)
})
},
)
}
fn all_devices_metadata<'a>(
@ -909,25 +1038,29 @@ impl service::users::Data for KeyValueDatabase {
user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes).map_err(|_| {
Error::bad_database("Device in userdeviceid_metadata is invalid.")
})
}),
)
Box::new(self.userdeviceid_metadata.scan_prefix(key).map(
|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes).map_err(|_| {
Error::bad_database(
"Device in userdeviceid_metadata is invalid.",
)
})
},
))
}
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
fn create_filter(
&self,
user_id: &UserId,
filter: &FilterDefinition,
) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter.insert(
@ -938,9 +1071,13 @@ impl service::users::Data for KeyValueDatabase {
Ok(filter_id)
}
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
fn get_filter(
&self,
user_id: &UserId,
filter_id: &str,
) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
@ -956,9 +1093,12 @@ impl service::users::Data for KeyValueDatabase {
/// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed.
/// If [`utils::string_from_bytes`] returns an error that username will be skipped
/// and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
/// If [`utils::string_from_bytes`] returns an error that username will be
/// skipped and the error will be logged.
fn get_username_with_valid_password(
username: &[u8],
password: &[u8],
) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
None
@ -967,7 +1107,8 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<
Ok(u) => Some(u),
Err(e) => {
warn!(
"Failed to parse username while calling get_local_users(): {}",
"Failed to parse username while calling \
get_local_users(): {}",
e.to_string()
);
None

View file

@ -12,7 +12,9 @@ use axum::{
routing::{any, get, on, MethodFilter},
Router,
};
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
};
use figment::{
providers::{Env, Format, Toml},
Figment,
@ -50,16 +52,16 @@ use api::{client_server, server_server};
pub(crate) use config::Config;
pub(crate) use database::KeyValueDatabase;
pub(crate) use service::{pdu::PduEvent, Services};
pub(crate) use utils::error::{Error, Result};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
use tikv_jemallocator::Jemalloc;
pub(crate) use utils::error::{Error, Result};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
pub(crate) static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None);
pub(crate) static SERVICES: RwLock<Option<&'static Services>> =
RwLock::new(None);
/// Convenient access to the global [`Services`] instance
pub(crate) fn services() -> &'static Services {
@ -71,9 +73,9 @@ pub(crate) fn services() -> &'static Services {
/// Returns the current version of the crate with extra info if supplied
///
/// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string to
/// include it in parenthesis after the SemVer version. A common value are git
/// commit hashes.
/// Set the environment variable `GRAPEVINE_VERSION_EXTRA` to any UTF-8 string
/// to include it in parenthesis after the SemVer version. A common value are
/// git commit hashes.
fn version() -> String {
let cargo_pkg_version = env!("CARGO_PKG_VERSION");
@ -91,7 +93,8 @@ async fn main() {
let raw_config = Figment::new()
.merge(
Toml::file(Env::var("GRAPEVINE_CONFIG").expect(
"The GRAPEVINE_CONFIG env var needs to be set. Example: /etc/grapevine.toml",
"The GRAPEVINE_CONFIG env var needs to be set. Example: \
/etc/grapevine.toml",
))
.nested(),
)
@ -100,7 +103,10 @@ async fn main() {
let config = match raw_config.extract::<Config>() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occurred: {e}");
eprintln!(
"It looks like your config is invalid. The following error \
occurred: {e}"
);
std::process::exit(1);
}
};
@ -108,7 +114,9 @@ async fn main() {
config.warn_deprecated();
if config.allow_jaeger {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
opentelemetry::global::set_text_map_propagator(
opentelemetry_jaeger::Propagator::new(),
);
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_auto_split_batch(true)
.with_service_name("grapevine")
@ -120,7 +128,8 @@ async fn main() {
Ok(s) => s,
Err(e) => {
eprintln!(
"It looks like your log config is invalid. The following error occurred: {e}"
"It looks like your log config is invalid. The following \
error occurred: {e}"
);
EnvFilter::try_new("warn").unwrap()
}
@ -146,7 +155,10 @@ async fn main() {
let filter_layer = match EnvFilter::try_new(&config.log) {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
eprintln!(
"It looks like your config is invalid. The following \
error occured while parsing it: {e}"
);
EnvFilter::try_new("warn").unwrap()
}
};
@ -163,7 +175,8 @@ async fn main() {
// * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6
// * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741
#[cfg(unix)]
maximize_fd_limit().expect("should be able to increase the soft limit to the hard limit");
maximize_fd_limit()
.expect("should be able to increase the soft limit to the hard limit");
info!("Loading database");
if let Err(error) = KeyValueDatabase::load_or_create(config).await {
@ -190,17 +203,19 @@ async fn run_server() -> io::Result<()> {
let middlewares = ServiceBuilder::new()
.sensitive_headers([header::AUTHORIZATION])
.layer(axum::middleware::from_fn(spawn_task))
.layer(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
.layer(TraceLayer::new_for_http().make_span_with(
|request: &http::Request<_>| {
let path = if let Some(path) =
request.extensions().get::<MatchedPath>()
{
path.as_str()
} else {
request.uri().path()
};
tracing::info_span!("http_request", %path)
}),
)
},
))
.layer(axum::middleware::from_fn(unrecognized_method))
.layer(
CorsLayer::new()
@ -235,7 +250,8 @@ async fn run_server() -> io::Result<()> {
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let conf =
RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
#[cfg(feature = "systemd")]
@ -411,9 +427,10 @@ fn routes(config: &Config) -> Router {
.ruma_route(c2s::get_relating_events_route)
.ruma_route(c2s::get_hierarchy_route);
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route.
// These two endpoints also allow trailing slashes.
// Ruma doesn't have support for multiple paths for a single endpoint yet,
// and these routes share one Ruma request / response type pair with
// {get,send}_state_event_for_key_route. These two endpoints also allow
// trailing slashes.
let router = router
.route(
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
@ -483,9 +500,7 @@ fn routes(config: &Config) -> Router {
async fn shutdown_signal(handle: ServerHandle) {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
@ -554,9 +569,9 @@ impl RouterExt for Router {
}
pub(crate) trait RumaHandler<T> {
// Can't transform to a handler without boxing or relying on the nightly-only
// impl-trait-in-traits feature. Moving a small amount of extra logic into the trait
// allows bypassing both.
// Can't transform to a handler without boxing or relying on the
// nightly-only impl-trait-in-traits feature. Moving a small amount of
// extra logic into the trait allows bypassing both.
fn add_to_router(self, router: Router) -> Router;
}

View file

@ -4,10 +4,9 @@ use std::{
};
use lru_cache::LruCache;
use tokio::sync::{broadcast, Mutex};
use tokio::sync::{broadcast, Mutex, RwLock};
use crate::{Config, Result};
use tokio::sync::RwLock;
pub(crate) mod account_data;
pub(crate) mod admin;
@ -58,10 +57,14 @@ impl Services {
) -> Result<Self> {
Ok(Self {
appservice: appservice::Service::build(db)?,
pusher: pusher::Service { db },
pusher: pusher::Service {
db,
},
rooms: rooms::Service {
alias: db,
auth_chain: rooms::auth_chain::Service { db },
auth_chain: rooms::auth_chain::Service {
db,
},
directory: db,
edus: rooms::edus::Service {
read_receipt: db,
@ -78,10 +81,14 @@ impl Services {
},
metadata: db,
outlier: db,
pdu_metadata: rooms::pdu_metadata::Service { db },
pdu_metadata: rooms::pdu_metadata::Service {
db,
},
search: db,
short: db,
state: rooms::state::Service { db },
state: rooms::state::Service {
db,
},
state_accessor: rooms::state_accessor::Service {
db,
#[allow(
@ -101,7 +108,9 @@ impl Services {
(100.0 * config.cache_capacity_modifier) as usize,
)),
},
state_cache: rooms::state_cache::Service { db },
state_cache: rooms::state_cache::Service {
db,
},
state_compressor: rooms::state_compressor::Service {
db,
#[allow(
@ -117,14 +126,18 @@ impl Services {
db,
lasttimelinecount_cache: Mutex::new(HashMap::new()),
},
threads: rooms::threads::Service { db },
threads: rooms::threads::Service {
db,
},
spaces: rooms::spaces::Service {
roomid_spacechunk_cache: Mutex::new(LruCache::new(200)),
},
user: db,
},
transaction_ids: db,
uiaa: uiaa::Service { db },
uiaa: uiaa::Service {
db,
},
users: users::Service {
db,
connections: StdMutex::new(BTreeMap::new()),
@ -132,14 +145,18 @@ impl Services {
account_data: db,
admin: admin::Service::build(),
key_backups: db,
media: media::Service { db },
media: media::Service {
db,
},
sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config)?,
})
}
async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let lazy_load_waiting =
self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self
.rooms
.state_accessor
@ -154,21 +171,12 @@ impl Services {
.lock()
.unwrap()
.len();
let stateinfo_cache = self
.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.len();
let lasttimelinecount_cache = self
.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
let stateinfo_cache =
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
let lasttimelinecount_cache =
self.rooms.timeline.lasttimelinecount_cache.lock().await.len();
let roomid_spacechunk_cache =
self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
format!(
"\
@ -177,18 +185,13 @@ server_visibility_cache: {server_visibility_cache}
user_visibility_cache: {user_visibility_cache}
stateinfo_cache: {stateinfo_cache}
lasttimelinecount_cache: {lasttimelinecount_cache}
roomid_spacechunk_cache: {roomid_spacechunk_cache}\
"
roomid_spacechunk_cache: {roomid_spacechunk_cache}"
)
}
async fn clear_caches(&self, amount: u32) {
if amount > 0 {
self.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.await
.clear();
self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear();
}
if amount > 1 {
self.rooms
@ -207,28 +210,13 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\
.clear();
}
if amount > 3 {
self.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.clear();
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
}
if amount > 4 {
self.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.clear();
self.rooms.timeline.lasttimelinecount_cache.lock().await.clear();
}
if amount > 5 {
self.rooms
.spaces
.roomid_spacechunk_cache
.lock()
.await
.clear();
self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear();
}
}
}

View file

@ -1,14 +1,16 @@
use std::collections::HashMap;
use crate::Result;
use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use crate::Result;
pub(crate) trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the previous entry.
/// Places one event in the account data of the user and removes the
/// previous entry.
fn update(
&self,
room_id: Option<&RoomId>,

View file

@ -16,7 +16,9 @@ use ruma::{
canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
history_visibility::{
HistoryVisibility, RoomHistoryVisibilityEventContent,
},
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
message::RoomMessageEventContent,
@ -26,19 +28,19 @@ use ruma::{
},
TimelineEventType,
},
EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId,
ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::warn;
use super::pdu::PduBuilder;
use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
services, utils, Error, PduEvent, Result,
};
use super::pdu::PduBuilder;
#[cfg_attr(test, derive(Debug))]
#[derive(Parser)]
#[command(name = "@grapevine:server.name:", version = env!("CARGO_PKG_VERSION"))]
@ -46,11 +48,12 @@ enum AdminCommand {
#[command(verbatim_doc_comment)]
/// Register an appservice using its registration YAML
///
/// This command needs a YAML generated by an appservice (such as a bridge),
/// which must be provided in a Markdown code-block below the command.
/// This command needs a YAML generated by an appservice (such as a
/// bridge), which must be provided in a Markdown code-block below the
/// command.
///
/// Registering a new bridge using the ID of an existing bridge will replace
/// the old one.
/// Registering a new bridge using the ID of an existing bridge will
/// replace the old one.
///
/// [commandbody]()
/// # ```
@ -95,8 +98,9 @@ enum AdminCommand {
///
/// Users will not be removed from joined rooms by default.
/// Can be overridden with --leave-rooms flag.
/// Removing a mass amount of users from a room may cause a significant amount of leave events.
/// The time to leave rooms may depend significantly on joined rooms and servers.
/// Removing a mass amount of users from a room may cause a significant
/// amount of leave events. The time to leave rooms may depend
/// significantly on joined rooms and servers.
///
/// [commandbody]()
/// # ```
@ -138,11 +142,17 @@ enum AdminCommand {
/// Print database memory usage statistics
MemoryUsage,
/// Clears all of Grapevine's database caches with index smaller than the amount
ClearDatabaseCaches { amount: u32 },
/// Clears all of Grapevine's database caches with index smaller than the
/// amount
ClearDatabaseCaches {
amount: u32,
},
/// Clears all of Grapevine's service caches with index smaller than the amount
ClearServiceCaches { amount: u32 },
/// Clears all of Grapevine's service caches with index smaller than the
/// amount
ClearServiceCaches {
amount: u32,
},
/// Show configuration values
ShowConfig,
@ -162,9 +172,13 @@ enum AdminCommand {
},
/// Disables incoming federation handling for a room.
DisableRoom { room_id: Box<RoomId> },
DisableRoom {
room_id: Box<RoomId>,
},
/// Enables incoming federation handling for a room again.
EnableRoom { room_id: Box<RoomId> },
EnableRoom {
room_id: Box<RoomId>,
},
/// Verify json signatures
/// [commandbody]()
@ -267,31 +281,38 @@ impl Service {
}
pub(crate) fn process_message(&self, room_message: String) {
self.sender
.send(AdminRoomEvent::ProcessMessage(room_message))
.unwrap();
self.sender.send(AdminRoomEvent::ProcessMessage(room_message)).unwrap();
}
pub(crate) fn send_message(&self, message_content: RoomMessageEventContent) {
self.sender
.send(AdminRoomEvent::SendMessage(message_content))
.unwrap();
pub(crate) fn send_message(
&self,
message_content: RoomMessageEventContent,
) {
self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap();
}
// Parse and process a message from the admin room
async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent {
async fn process_admin_message(
&self,
room_message: String,
) -> RoomMessageEventContent {
let mut lines = room_message.lines().filter(|l| !l.trim().is_empty());
let command_line = lines.next().expect("each string has at least one line");
let command_line =
lines.next().expect("each string has at least one line");
let body: Vec<_> = lines.collect();
let admin_command = match Self::parse_admin_command(command_line) {
Ok(command) => command,
Err(error) => {
let server_name = services().globals.server_name();
let message = error.replace("server.name", server_name.as_str());
let message =
error.replace("server.name", server_name.as_str());
let html_message = Self::usage_to_html(&message, server_name);
return RoomMessageEventContent::text_html(message, html_message);
return RoomMessageEventContent::text_html(
message,
html_message,
);
}
};
@ -299,22 +320,28 @@ impl Service {
Ok(reply_message) => reply_message,
Err(error) => {
let markdown_message = format!(
"Encountered an error while handling the command:\n\
```\n{error}\n```",
"Encountered an error while handling the \
command:\n```\n{error}\n```",
);
let html_message = format!(
"Encountered an error while handling the command:\n\
<pre>\n{error}\n</pre>",
"Encountered an error while handling the \
command:\n<pre>\n{error}\n</pre>",
);
RoomMessageEventContent::text_html(markdown_message, html_message)
RoomMessageEventContent::text_html(
markdown_message,
html_message,
)
}
}
}
// Parse chat messages from the admin room into an AdminCommand object
fn parse_admin_command(command_line: &str) -> std::result::Result<AdminCommand, String> {
// Note: argv[0] is `@grapevine:servername:`, which is treated as the main command
fn parse_admin_command(
command_line: &str,
) -> std::result::Result<AdminCommand, String> {
// Note: argv[0] is `@grapevine:servername:`, which is treated as the
// main command
let mut argv: Vec<_> = command_line.split_whitespace().collect();
// Replace `help command` with `command --help`
@ -342,18 +369,26 @@ impl Service {
) -> Result<RoomMessageEventContent> {
let reply_message_content = match command {
AdminCommand::RegisterAppservice => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```"
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{
let appservice_config = body[1..body.len() - 1].join("\n");
let parsed_config = serde_yaml::from_str::<Registration>(&appservice_config);
let parsed_config = serde_yaml::from_str::<Registration>(
&appservice_config,
);
match parsed_config {
Ok(yaml) => match services().appservice.register_appservice(yaml).await {
Ok(id) => RoomMessageEventContent::text_plain(format!(
"Appservice registered with ID: {id}."
)),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Failed to register appservice: {e}"
)),
Ok(yaml) => match services()
.appservice
.register_appservice(yaml)
.await
{
Ok(id) => RoomMessageEventContent::text_plain(
format!("Appservice registered with ID: {id}."),
),
Err(e) => RoomMessageEventContent::text_plain(
format!("Failed to register appservice: {e}"),
),
},
Err(e) => RoomMessageEventContent::text_plain(format!(
"Could not parse appservice config: {e}"
@ -361,7 +396,8 @@ impl Service {
}
} else {
RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
"Expected code block in command body. Add --help for \
details.",
)
}
}
@ -372,7 +408,9 @@ impl Service {
.unregister_appservice(&appservice_identifier)
.await
{
Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."),
Ok(()) => RoomMessageEventContent::text_plain(
"Appservice unregistered.",
),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Failed to unregister appservice: {e}"
)),
@ -407,17 +445,25 @@ impl Service {
);
RoomMessageEventContent::text_plain(output)
}
AdminCommand::ListLocalUsers => match services().users.list_local_users() {
AdminCommand::ListLocalUsers => match services()
.users
.list_local_users()
{
Ok(users) => {
let mut msg: String = format!("Found {} local user account(s):\n", users.len());
let mut msg: String = format!(
"Found {} local user account(s):\n",
users.len()
);
msg += &users.join("\n");
RoomMessageEventContent::text_plain(&msg)
}
Err(e) => RoomMessageEventContent::text_plain(e.to_string()),
},
AdminCommand::IncomingFederation => {
let map = services().globals.roomid_federationhandletime.read().await;
let mut msg: String = format!("Handling {} incoming pdus:\n", map.len());
let map =
services().globals.roomid_federationhandletime.read().await;
let mut msg: String =
format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() {
let elapsed = i.elapsed();
@ -431,17 +477,26 @@ impl Service {
}
RoomMessageEventContent::text_plain(&msg)
}
AdminCommand::GetAuthChain { event_id } => {
AdminCommand::GetAuthChain {
event_id,
} => {
let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? {
if let Some(event) =
services().rooms.timeline.get_pdu_json(&event_id)?
{
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
.ok_or_else(|| {
Error::bad_database("Invalid event in database")
})?;
let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| {
Error::bad_database("Invalid room id field in event in database")
})?;
let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| {
Error::bad_database(
"Invalid room id field in event in database",
)
})?;
let start = Instant::now();
let count = services()
.rooms
@ -458,29 +513,47 @@ impl Service {
}
}
AdminCommand::ParsePdu => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```"
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{
let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) {
Ok(value) => {
match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) {
match ruma::signatures::reference_hash(
&value,
&RoomVersionId::V6,
) {
Ok(hash) => {
let event_id = EventId::parse(format!("${hash}"));
let event_id =
EventId::parse(format!("${hash}"));
match serde_json::from_value::<PduEvent>(
serde_json::to_value(value).expect("value is json"),
serde_json::to_value(value)
.expect("value is json"),
) {
Ok(pdu) => RoomMessageEventContent::text_plain(format!(
"EventId: {event_id:?}\n{pdu:#?}"
)),
Err(e) => RoomMessageEventContent::text_plain(format!(
"EventId: {event_id:?}\nCould not parse event: {e}"
)),
Ok(pdu) => {
RoomMessageEventContent::text_plain(
format!(
"EventId: {event_id:?}\\
n{pdu:#?}"
),
)
}
Err(e) => {
RoomMessageEventContent::text_plain(
format!(
"EventId: {event_id:?}\\
nCould not parse event: \
{e}"
),
)
}
}
}
Err(e) => RoomMessageEventContent::text_plain(format!(
"Could not parse PDU JSON: {e:?}"
)),
Err(e) => RoomMessageEventContent::text_plain(
format!("Could not parse PDU JSON: {e:?}"),
),
}
}
Err(e) => RoomMessageEventContent::text_plain(format!(
@ -488,10 +561,14 @@ impl Service {
)),
}
} else {
RoomMessageEventContent::text_plain("Expected code block in command body.")
RoomMessageEventContent::text_plain(
"Expected code block in command body.",
)
}
}
AdminCommand::GetPdu { event_id } => {
AdminCommand::GetPdu {
event_id,
} => {
let mut outlier = false;
let mut pdu_json = services()
.rooms
@ -499,7 +576,8 @@ impl Service {
.get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() {
outlier = true;
pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?;
pdu_json =
services().rooms.timeline.get_pdu_json(&event_id)?;
}
match pdu_json {
Some(json) => {
@ -516,7 +594,8 @@ impl Service {
json_text
),
format!(
"<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n",
"<p>{}</p>\n<pre><code \
class=\"language-json\">{}\n</code></pre>\n",
if outlier {
"PDU is outlier"
} else {
@ -526,7 +605,9 @@ impl Service {
),
)
}
None => RoomMessageEventContent::text_plain("PDU not found."),
None => {
RoomMessageEventContent::text_plain("PDU not found.")
}
}
}
AdminCommand::MemoryUsage => {
@ -537,30 +618,42 @@ impl Service {
"Services:\n{response1}\n\nDatabase:\n{response2}"
))
}
AdminCommand::ClearDatabaseCaches { amount } => {
AdminCommand::ClearDatabaseCaches {
amount,
} => {
services().globals.db.clear_caches(amount);
RoomMessageEventContent::text_plain("Done.")
}
AdminCommand::ClearServiceCaches { amount } => {
AdminCommand::ClearServiceCaches {
amount,
} => {
services().clear_caches(amount).await;
RoomMessageEventContent::text_plain("Done.")
}
AdminCommand::ShowConfig => {
// Construct and send the response
RoomMessageEventContent::text_plain(format!("{}", services().globals.config))
RoomMessageEventContent::text_plain(format!(
"{}",
services().globals.config
))
}
AdminCommand::ResetPassword { username } => {
AdminCommand::ResetPassword {
username,
} => {
let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(),
services().globals.server_name(),
) {
Ok(id) => id,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
return Ok(RoomMessageEventContent::text_plain(
format!(
"The supplied username is not a valid \
username: {e}"
),
))
}
};
@ -589,23 +682,29 @@ impl Service {
));
}
let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
let new_password =
utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
match services()
.users
.set_password(&user_id, Some(new_password.as_str()))
{
Ok(()) => RoomMessageEventContent::text_plain(format!(
"Successfully reset the password for user {user_id}: {new_password}"
"Successfully reset the password for user {user_id}: \
{new_password}"
)),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Couldn't reset the password for user {user_id}: {e}"
)),
}
}
AdminCommand::CreateUser { username, password } => {
let password =
password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH));
AdminCommand::CreateUser {
username,
password,
} => {
let password = password.unwrap_or_else(|| {
utils::random_string(AUTO_GEN_PASSWORD_LENGTH)
});
// Validate user id
let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(),
@ -613,9 +712,12 @@ impl Service {
) {
Ok(id) => id,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
return Ok(RoomMessageEventContent::text_plain(
format!(
"The supplied username is not a valid \
username: {e}"
),
))
}
};
if user_id.is_historical() {
@ -647,24 +749,32 @@ impl Service {
.into(),
&serde_json::to_value(PushRulesEvent {
content: PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id),
global: ruma::push::Ruleset::server_default(
&user_id,
),
},
})
.expect("to json value always works"),
)?;
// we dont add a device since we're not the user, just the creator
// we dont add a device since we're not the user, just the
// creator
// Inhibit login does not work for guests
RoomMessageEventContent::text_plain(format!(
"Created user with user_id: {user_id} and password: {password}"
"Created user with user_id: {user_id} and password: \
{password}"
))
}
AdminCommand::DisableRoom { room_id } => {
AdminCommand::DisableRoom {
room_id,
} => {
services().rooms.metadata.disable_room(&room_id, true)?;
RoomMessageEventContent::text_plain("Room disabled.")
}
AdminCommand::EnableRoom { room_id } => {
AdminCommand::EnableRoom {
room_id,
} => {
services().rooms.metadata.disable_room(&room_id, false)?;
RoomMessageEventContent::text_plain("Room enabled.")
}
@ -677,13 +787,16 @@ impl Service {
RoomMessageEventContent::text_plain(format!(
"User {user_id} doesn't exist on this server"
))
} else if user_id.server_name() != services().globals.server_name() {
} else if user_id.server_name()
!= services().globals.server_name()
{
RoomMessageEventContent::text_plain(format!(
"User {user_id} is not from this server"
))
} else {
RoomMessageEventContent::text_plain(format!(
"Making {user_id} leave all rooms before deactivation..."
"Making {user_id} leave all rooms before \
deactivation..."
));
services().users.deactivate_account(&user_id)?;
@ -697,10 +810,18 @@ impl Service {
))
}
}
AdminCommand::DeactivateAll { leave_rooms, force } => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```"
AdminCommand::DeactivateAll {
leave_rooms,
force,
} => {
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{
let users = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>();
let users = body
.clone()
.drain(1..body.len() - 1)
.collect::<Vec<_>>();
let mut user_ids = Vec::new();
let mut remote_ids = Vec::new();
@ -710,7 +831,9 @@ impl Service {
for &user in &users {
match <&UserId>::try_from(user) {
Ok(user_id) => {
if user_id.server_name() != services().globals.server_name() {
if user_id.server_name()
!= services().globals.server_name()
{
remote_ids.push(user_id);
} else if !services().users.exists(user_id)? {
non_existant_ids.push(user_id);
@ -727,39 +850,59 @@ impl Service {
let mut markdown_message = String::new();
let mut html_message = String::new();
if !invalid_users.is_empty() {
markdown_message.push_str("The following user ids are not valid:\n```\n");
html_message.push_str("The following user ids are not valid:\n<pre>\n");
markdown_message.push_str(
"The following user ids are not valid:\n```\n",
);
html_message.push_str(
"The following user ids are not valid:\n<pre>\n",
);
for invalid_user in invalid_users {
writeln!(markdown_message, "{invalid_user}")
.expect("write to in-memory buffer should succeed");
writeln!(html_message, "{invalid_user}")
.expect("write to in-memory buffer should succeed");
.expect(
"write to in-memory buffer should succeed",
);
writeln!(html_message, "{invalid_user}").expect(
"write to in-memory buffer should succeed",
);
}
markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n");
}
if !remote_ids.is_empty() {
markdown_message
.push_str("The following users are not from this server:\n```\n");
html_message
.push_str("The following users are not from this server:\n<pre>\n");
markdown_message.push_str(
"The following users are not from this \
server:\n```\n",
);
html_message.push_str(
"The following users are not from this \
server:\n<pre>\n",
);
for remote_id in remote_ids {
writeln!(markdown_message, "{remote_id}")
.expect("write to in-memory buffer should succeed");
writeln!(html_message, "{remote_id}")
.expect("write to in-memory buffer should succeed");
writeln!(markdown_message, "{remote_id}").expect(
"write to in-memory buffer should succeed",
);
writeln!(html_message, "{remote_id}").expect(
"write to in-memory buffer should succeed",
);
}
markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n");
}
if !non_existant_ids.is_empty() {
markdown_message.push_str("The following users do not exist:\n```\n");
html_message.push_str("The following users do not exist:\n<pre>\n");
markdown_message.push_str(
"The following users do not exist:\n```\n",
);
html_message.push_str(
"The following users do not exist:\n<pre>\n",
);
for non_existant_id in non_existant_ids {
writeln!(markdown_message, "{non_existant_id}")
.expect("write to in-memory buffer should succeed");
writeln!(html_message, "{non_existant_id}")
.expect("write to in-memory buffer should succeed");
.expect(
"write to in-memory buffer should succeed",
);
writeln!(html_message, "{non_existant_id}").expect(
"write to in-memory buffer should succeed",
);
}
markdown_message.push_str("```\n\n");
html_message.push_str("</pre>\n\n");
@ -775,21 +918,24 @@ impl Service {
let mut admins = Vec::new();
if !force {
user_ids.retain(|&user_id| match services().users.is_admin(user_id) {
Ok(is_admin) => {
if is_admin {
admins.push(user_id.localpart());
false
} else {
true
user_ids.retain(|&user_id| {
match services().users.is_admin(user_id) {
Ok(is_admin) => {
if is_admin {
admins.push(user_id.localpart());
false
} else {
true
}
}
Err(_) => false,
}
Err(_) => false,
});
}
for &user_id in &user_ids {
if services().users.deactivate_account(user_id).is_ok() {
if services().users.deactivate_account(user_id).is_ok()
{
deactivation_count += 1;
}
}
@ -807,16 +953,25 @@ impl Service {
"Deactivated {deactivation_count} accounts."
))
} else {
RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", ")))
RoomMessageEventContent::text_plain(format!(
"Deactivated {} accounts.\nSkipped admin \
accounts: {:?}. Use --force to deactivate admin \
accounts",
deactivation_count,
admins.join(", ")
))
}
} else {
RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
"Expected code block in command body. Add --help for \
details.",
)
}
}
AdminCommand::SignJson => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```"
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{
let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) {
@ -827,20 +982,26 @@ impl Service {
&mut value,
)
.expect("our request json is what ruma expects");
let json_text = serde_json::to_string_pretty(&value)
.expect("canonical json is valid json");
let json_text =
serde_json::to_string_pretty(&value)
.expect("canonical json is valid json");
RoomMessageEventContent::text_plain(json_text)
}
Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Invalid json: {e}"
)),
}
} else {
RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
"Expected code block in command body. Add --help for \
details.",
)
}
}
AdminCommand::VerifyJson => {
if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```"
if body.len() > 2
&& body[0].trim() == "```"
&& body.last().unwrap().trim() == "```"
{
let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) {
@ -850,22 +1011,35 @@ impl Service {
services()
.rooms
.event_handler
.fetch_required_signing_keys(&value, &pub_key_map)
.fetch_required_signing_keys(
&value,
&pub_key_map,
)
.await?;
let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) {
Ok(()) => RoomMessageEventContent::text_plain("Signature correct"),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Signature verification failed: {e}"
)),
match ruma::signatures::verify_json(
&pub_key_map,
&value,
) {
Ok(()) => RoomMessageEventContent::text_plain(
"Signature correct",
),
Err(e) => RoomMessageEventContent::text_plain(
format!(
"Signature verification failed: {e}"
),
),
}
}
Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")),
Err(e) => RoomMessageEventContent::text_plain(format!(
"Invalid json: {e}"
)),
}
} else {
RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
"Expected code block in command body. Add --help for \
details.",
)
}
}
@ -876,7 +1050,8 @@ impl Service {
// Utility to turn clap's `--help` text to HTML.
fn usage_to_html(text: &str, server_name: &ServerName) -> String {
// Replace `@grapevine:servername:-subcmdname` with `@grapevine:servername: subcmdname`
// Replace `@grapevine:servername:-subcmdname` with
// `@grapevine:servername: subcmdname`
let localpart = if services().globals.config.conduit_compat {
"conduit"
} else {
@ -892,11 +1067,13 @@ impl Service {
let text = text.replace("SUBCOMMAND", "COMMAND");
let text = text.replace("subcommand", "command");
// Escape option names (e.g. `<element-id>`) since they look like HTML tags
// Escape option names (e.g. `<element-id>`) since they look like HTML
// tags
let text = text.replace('<', "&lt;").replace('>', "&gt;");
// Italicize the first line (command name and version text)
let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail");
let re =
Regex::new("^(.*?)\n").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<em>$1</em>\n");
// Unmerge wrapped lines
@ -911,8 +1088,8 @@ impl Service {
.expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<code>$1</code>: $4");
// Look for a `[commandbody]()` tag. If it exists, use all lines below it that
// start with a `#` in the USAGE section.
// Look for a `[commandbody]()` tag. If it exists, use all lines below
// it that start with a `#` in the USAGE section.
let mut text_lines: Vec<&str> = text.lines().collect();
let command_body = text_lines
.iter()
@ -936,8 +1113,11 @@ impl Service {
// This makes the usage of e.g. `register-appservice` more accurate
let re = Regex::new("(?m)^USAGE:\n (.*?)\n\n")
.expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>")
.replace("[commandbodyblock]", &command_body)
re.replace_all(
&text,
"USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>",
)
.replace("[commandbodyblock]", &command_body)
};
// Add HTML line-breaks
@ -949,8 +1129,9 @@ impl Service {
/// Create the admin room.
///
/// Users in this room are considered admins by grapevine, and the room can be
/// used to issue admin commands by talking to the server user inside it.
/// Users in this room are considered admins by grapevine, and the room can
/// be used to issue admin commands by talking to the server user inside
/// it.
#[allow(clippy::too_many_lines)]
pub(crate) async fn create_admin_room(&self) -> Result<()> {
let room_id = RoomId::new(services().globals.server_name());
@ -993,7 +1174,9 @@ impl Service {
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => RoomCreateEventContent::new_v1(grapevine_user.clone()),
| RoomVersionId::V10 => {
RoomCreateEventContent::new_v1(grapevine_user.clone())
}
RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => unreachable!("Validity of room version already checked"),
};
@ -1008,7 +1191,8 @@ impl Service {
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"),
content: to_raw_value(&content)
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
@ -1079,8 +1263,10 @@ impl Service {
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite))
.expect("event is valid, we just created it"),
content: to_raw_value(&RoomJoinRulesEventContent::new(
JoinRule::Invite,
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
@ -1098,9 +1284,11 @@ impl Service {
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
HistoryVisibility::Shared,
))
content: to_raw_value(
&RoomHistoryVisibilityEventContent::new(
HistoryVisibility::Shared,
),
)
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
@ -1134,15 +1322,18 @@ impl Service {
.await?;
// 5. Events implied by name and topic
let room_name = format!("{} Admin Room", services().globals.server_name());
let room_name =
format!("{} Admin Room", services().globals.server_name());
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(room_name))
.expect("event is valid, we just created it"),
content: to_raw_value(&RoomNameEventContent::new(
room_name,
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
@ -1160,7 +1351,10 @@ impl Service {
PduBuilder {
event_type: TimelineEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent {
topic: format!("Manage {}", services().globals.server_name()),
topic: format!(
"Manage {}",
services().globals.server_name()
),
})
.expect("event is valid, we just created it"),
unsigned: None,
@ -1174,9 +1368,10 @@ impl Service {
.await?;
// 6. Room alias
let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
let alias: OwnedRoomAliasId =
format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
services()
.rooms
@ -1206,7 +1401,8 @@ impl Service {
/// Gets the room ID of the admin room
///
/// Errors are propagated from the database, and will have None if there is no admin room
/// Errors are propagated from the database, and will have None if there is
/// no admin room
// Allowed because this function uses `services()`
#[allow(clippy::unused_self)]
pub(crate) fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> {
@ -1215,10 +1411,7 @@ impl Service {
.try_into()
.expect("#admins:server_name is a valid alias name");
services()
.rooms
.alias
.resolve_local_alias(&admin_room_alias)
services().rooms.alias.resolve_local_alias(&admin_room_alias)
}
/// Invite the user to the grapevine admin room.
@ -1356,11 +1549,13 @@ mod test {
}
fn get_help_inner(input: &str) {
let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.unwrap_err()
.to_string();
let error =
AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.unwrap_err()
.to_string();
// Search for a handful of keywords that suggest the help printed properly
// Search for a handful of keywords that suggest the help printed
// properly
assert!(error.contains("Usage:"));
assert!(error.contains("Commands:"));
assert!(error.contains("Options:"));

View file

@ -3,7 +3,6 @@ mod data;
use std::collections::BTreeMap;
pub(crate) use data::Data;
use futures_util::Future;
use regex::RegexSet;
use ruma::{
@ -48,6 +47,8 @@ impl NamespaceRegex {
}
impl TryFrom<Vec<Namespace>> for NamespaceRegex {
type Error = regex::Error;
fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> {
let mut exclusive = vec![];
let mut non_exclusive = vec![];
@ -73,8 +74,6 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
},
})
}
type Error = regex::Error;
}
/// Appservice registration combined with its compiled regular expressions.
@ -99,6 +98,8 @@ impl RegistrationInfo {
}
impl TryFrom<Registration> for RegistrationInfo {
type Error = regex::Error;
fn try_from(value: Registration) -> Result<RegistrationInfo, regex::Error> {
Ok(RegistrationInfo {
users: value.namespaces.users.clone().try_into()?,
@ -107,8 +108,6 @@ impl TryFrom<Registration> for RegistrationInfo {
registration: value,
})
}
type Error = regex::Error;
}
pub(crate) struct Service {
@ -135,8 +134,12 @@ impl Service {
registration_info: RwLock::new(registration_info),
})
}
/// Registers an appservice and returns the ID to the caller.
pub(crate) async fn register_appservice(&self, yaml: Registration) -> Result<String> {
pub(crate) async fn register_appservice(
&self,
yaml: Registration,
) -> Result<String> {
//TODO: Check for collisions between exclusive appservice namespaces
self.registration_info
.write()
@ -151,19 +154,27 @@ impl Service {
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
pub(crate) async fn unregister_appservice(&self, service_name: &str) -> Result<()> {
pub(crate) async fn unregister_appservice(
&self,
service_name: &str,
) -> Result<()> {
services()
.appservice
.registration_info
.write()
.await
.remove(service_name)
.ok_or_else(|| crate::Error::AdminCommand("Appservice not found"))?;
.ok_or_else(|| {
crate::Error::AdminCommand("Appservice not found")
})?;
self.db.unregister_appservice(service_name)
}
pub(crate) async fn get_registration(&self, id: &str) -> Option<Registration> {
pub(crate) async fn get_registration(
&self,
id: &str,
) -> Option<Registration> {
self.registration_info
.read()
.await
@ -173,15 +184,13 @@ impl Service {
}
pub(crate) async fn iter_ids(&self) -> Vec<String> {
self.registration_info
.read()
.await
.keys()
.cloned()
.collect()
self.registration_info.read().await.keys().cloned().collect()
}
pub(crate) async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
pub(crate) async fn find_from_token(
&self,
token: &str,
) -> Option<RegistrationInfo> {
self.read()
.await
.values()
@ -207,8 +216,12 @@ impl Service {
pub(crate) fn read(
&self,
) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>>
{
) -> impl Future<
Output = tokio::sync::RwLockReadGuard<
'_,
BTreeMap<String, RegistrationInfo>,
>,
> {
self.registration_info.read()
}
}

View file

@ -15,7 +15,9 @@ pub(crate) trait Data: Send + Sync {
fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
fn iter_ids<'a>(
&'a self,
) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
fn all(&self) -> Result<Vec<(String, Registration)>>;
}

View file

@ -1,26 +1,4 @@
mod data;
pub(crate) use data::Data;
use ruma::{
serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId,
};
use crate::api::server_server::FedDest;
use crate::{services, Config, Error, Result};
use futures_util::FutureExt;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service as HyperService,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
use ruma::{
api::{
client::sync::sync_events,
federation::discovery::{ServerSigningKeys, VerifyKey},
},
DeviceId, RoomVersionId, ServerName, UserId,
};
use std::{
collections::{BTreeMap, HashMap},
error::Error as StdError,
@ -35,11 +13,29 @@ use std::{
},
time::{Duration, Instant},
};
use base64::{engine::general_purpose, Engine as _};
pub(crate) use data::Data;
use futures_util::FutureExt;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service as HyperService,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
use ruma::{
api::{
client::sync::sync_events,
federation::discovery::{ServerSigningKeys, VerifyKey},
},
serde::Base64,
DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId, RoomVersionId, ServerName, UserId,
};
use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
use tracing::{error, info};
use trust_dns_resolver::TokioAsyncResolver;
use base64::{engine::general_purpose, Engine as _};
use crate::{api::server_server::FedDest, services, Config, Error, Result};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
@ -66,27 +62,40 @@ pub(crate) struct Service {
default_client: reqwest::Client,
pub(crate) stable_room_versions: Vec<RoomVersionId>,
pub(crate) unstable_room_versions: Vec<RoomVersionId>,
pub(crate) bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub(crate) bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub(crate) bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub(crate) servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub(crate) sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub(crate) roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) bad_event_ratelimiter:
Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub(crate) bad_signature_ratelimiter:
Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub(crate) bad_query_ratelimiter:
Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub(crate) servername_ratelimiter:
Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub(crate) sync_receivers:
RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub(crate) roomid_mutex_insert:
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
// this lock will be held longer
pub(crate) roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub(crate) roomid_mutex_federation:
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_federationhandletime:
RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub(crate) stateres_mutex: Arc<Mutex<()>>,
pub(crate) rotate: RotationHandler,
pub(crate) shutdown: AtomicBool,
}
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
/// Handles "rotation" of long-polling requests. "Rotation" in this context is
/// similar to "rotation" of log files and the like.
///
/// This is utilized to have sync workers return early and release read locks on the database.
pub(crate) struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>);
/// This is utilized to have sync workers return early and release read locks on
/// the database.
pub(crate) struct RotationHandler(
broadcast::Sender<()>,
broadcast::Receiver<()>,
);
impl RotationHandler {
pub(crate) fn new() -> Self {
@ -136,7 +145,10 @@ impl Resolve for Resolver {
.and_then(|(override_name, port)| {
override_name.first().map(|first_name| {
let x: Box<dyn Iterator<Item = SocketAddr> + Send> =
Box::new(iter::once(SocketAddr::new(*first_name, *port)));
Box::new(iter::once(SocketAddr::new(
*first_name,
*port,
)));
let x: Resolving = Box::pin(future::ready(Ok(x)));
x
})
@ -144,9 +156,11 @@ impl Resolve for Resolver {
.unwrap_or_else(|| {
let this = &mut self.inner.clone();
Box::pin(HyperService::<Name>::call(this, name).map(|result| {
result
.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> Box<dyn StdError + Send + Sync> { Box::new(err) })
result.map(|addrs| -> Addrs { Box::new(addrs) }).map_err(
|err| -> Box<dyn StdError + Send + Sync> {
Box::new(err)
},
)
}))
})
}
@ -167,10 +181,9 @@ impl Service {
let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new()));
let jwt_decoding_key = config
.jwt_secret
.as_ref()
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| {
jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())
});
let default_client = reqwest_client_builder(&config)?.build()?;
let federation_client = reqwest_client_builder(&config)?
@ -187,20 +200,28 @@ impl Service {
RoomVersionId::V11,
];
// Experimental, partially supported room versions
let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let unstable_room_versions =
vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let mut s = Self {
db,
config,
keypair: Arc::new(keypair),
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
error!(
"Failed to set up trust dns resolver with system config: {}",
e
);
Error::bad_config("Failed to set up trust dns resolver with system config.")
})?,
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())),
dns_resolver: TokioAsyncResolver::tokio_from_system_conf()
.map_err(|e| {
error!(
"Failed to set up trust dns resolver with system \
config: {}",
e
);
Error::bad_config(
"Failed to set up trust dns resolver with system \
config.",
)
})?,
actual_destination_cache: Arc::new(
RwLock::new(WellKnownMap::new()),
),
tls_name_override,
federation_client,
default_client,
@ -223,12 +244,11 @@ impl Service {
fs::create_dir_all(s.get_media_folder())?;
if !s
.supported_room_versions()
.contains(&s.config.default_room_version)
if !s.supported_room_versions().contains(&s.config.default_room_version)
{
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = crate::config::default_default_room_version();
s.config.default_room_version =
crate::config::default_default_room_version();
};
Ok(s)
@ -261,7 +281,11 @@ impl Service {
self.db.current_count()
}
pub(crate) async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
pub(crate) async fn watch(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<()> {
self.db.watch(user_id, device_id).await
}
@ -313,7 +337,9 @@ impl Service {
&self.dns_resolver
}
pub(crate) fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> {
pub(crate) fn jwt_decoding_key(
&self,
) -> Option<&jsonwebtoken::DecodingKey> {
self.jwt_decoding_key.as_ref()
}
@ -353,7 +379,8 @@ impl Service {
/// TODO: the key valid until timestamp is only honored in room version > 4
/// Remove the outdated keys and insert the new ones.
///
/// This doesn't actually check that the keys provided are newer than the old set.
/// This doesn't actually check that the keys provided are newer than the
/// old set.
pub(crate) fn add_signing_key(
&self,
origin: &ServerName,
@ -362,7 +389,8 @@ impl Service {
self.db.add_signing_key(origin, new_keys)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub(crate) fn signing_keys_for(
&self,
origin: &ServerName,

View file

@ -13,7 +13,8 @@ use crate::Result;
pub(crate) trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>;
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
async fn watch(&self, user_id: &UserId, device_id: &DeviceId)
-> Result<()>;
fn cleanup(&self) -> Result<()>;
fn memory_usage(&self) -> String;
fn clear_caches(&self, amount: u32);
@ -25,7 +26,8 @@ pub(crate) trait Data: Send + Sync {
new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(
&self,
origin: &ServerName,

View file

@ -1,12 +1,13 @@
use std::collections::BTreeMap;
use crate::Result;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn create_backup(
&self,
@ -23,12 +24,21 @@ pub(crate) trait Data: Send + Sync {
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String>;
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
fn get_latest_backup_version(
&self,
user_id: &UserId,
) -> Result<Option<String>>;
fn get_latest_backup(&self, user_id: &UserId)
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
fn get_latest_backup(
&self,
user_id: &UserId,
) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
fn get_backup(
&self,
user_id: &UserId,
version: &str,
) -> Result<Option<Raw<BackupAlgorithm>>>;
fn add_key(
&self,
@ -66,7 +76,12 @@ pub(crate) trait Data: Send + Sync {
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
fn delete_room_keys(
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
) -> Result<()>;
fn delete_room_key(
&self,

View file

@ -2,15 +2,14 @@ mod data;
use std::io::Cursor;
pub(crate) use data::Data;
use crate::{services, Result};
use image::imageops::FilterType;
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
};
use crate::{services, Result};
pub(crate) struct FileMeta {
pub(crate) content_disposition: Option<String>,
pub(crate) content_type: Option<String>,
@ -31,9 +30,13 @@ impl Service {
file: &[u8],
) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail
let key = self
.db
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
let key = self.db.create_file_metadata(
mxc,
0,
0,
content_disposition,
content_type,
)?;
let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?;
@ -52,9 +55,13 @@ impl Service {
height: u32,
file: &[u8],
) -> Result<()> {
let key =
self.db
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
let key = self.db.create_file_metadata(
mxc,
width,
height,
content_disposition,
content_type,
)?;
let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?;
@ -84,9 +91,12 @@ impl Service {
}
}
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
/// the server should send the original file.
fn thumbnail_properties(width: u32, height: u32) -> Option<(u32, u32, bool)> {
/// Returns width, height of the thumbnail and whether it should be cropped.
/// Returns None when the server should send the original file.
fn thumbnail_properties(
width: u32,
height: u32,
) -> Option<(u32, u32, bool)> {
match (width, height) {
(0..=32, 0..=32) => Some((32, 32, true)),
(0..=96, 0..=96) => Some((96, 96, true)),
@ -102,11 +112,15 @@ impl Service {
/// Here's an example on how it works:
///
/// - Client requests an image with width=567, height=567
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
/// - Server rounds that up to (800, 600), so it doesn't have to save too
/// many thumbnails
/// - Server rounds that up again to (958, 600) to fix the aspect ratio
/// (only for width,height>96)
/// - Server creates the thumbnail and sends it to the user
///
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
/// For width,height <= 96 the server uses another thumbnailing algorithm
/// which crops the image afterwards.
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_thumbnail(
&self,
mxc: String,
@ -154,7 +168,8 @@ impl Service {
} else {
let (exact_width, exact_height) = {
// Copied from image::dynimage::resize_dimensions
let use_width = (u64::from(width) * u64::from(original_height))
let use_width = (u64::from(width)
* u64::from(original_height))
<= (u64::from(original_width) * u64::from(height));
let intermediate = if use_width {
u64::from(original_height) * u64::from(width)
@ -165,21 +180,31 @@ impl Service {
};
if use_width {
if intermediate <= u64::from(::std::u32::MAX) {
(width, intermediate.try_into().unwrap_or(u32::MAX))
(
width,
intermediate.try_into().unwrap_or(u32::MAX),
)
} else {
(
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
(u64::from(width)
* u64::from(::std::u32::MAX)
/ intermediate)
.try_into()
.unwrap_or(u32::MAX),
::std::u32::MAX,
)
}
} else if intermediate <= u64::from(::std::u32::MAX) {
(intermediate.try_into().unwrap_or(u32::MAX), height)
(
intermediate.try_into().unwrap_or(u32::MAX),
height,
)
} else {
(
::std::u32::MAX,
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
(u64::from(height)
* u64::from(::std::u32::MAX)
/ intermediate)
.try_into()
.unwrap_or(u32::MAX),
)
@ -195,7 +220,8 @@ impl Service {
image::ImageOutputFormat::Png,
)?;
// Save thumbnail in database so we don't have to generate it again next time
// Save thumbnail in database so we don't have to generate it
// again next time
let thumbnail_key = self.db.create_file_metadata(
mxc,
width,

View file

@ -1,24 +1,31 @@
use crate::Error;
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use ruma::{
canonical_json::redact_content_in_place,
events::{
room::{member::RoomMemberEventContent, redaction::RoomRedactionEventContent},
room::{
member::RoomMemberEventContent,
redaction::RoomRedactionEventContent,
},
space::child::HierarchySpaceChildEvent,
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent,
AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType,
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent,
AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
AnyTimelineEvent, StateEvent, TimelineEventType,
},
serde::Raw,
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId,
MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
RoomVersionId, UInt, UserId,
};
use serde::{Deserialize, Serialize};
use serde_json::{
json,
value::{to_raw_value, RawValue as RawJsonValue},
};
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use tracing::warn;
use crate::Error;
/// Content hashes of a PDU.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct EventHash {
@ -61,10 +68,18 @@ impl PduEvent {
) -> crate::Result<()> {
self.unsigned = None;
let mut content = serde_json::from_str(self.content.get())
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
.map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?;
let mut content =
serde_json::from_str(self.content.get()).map_err(|_| {
Error::bad_database("PDU in db has invalid content.")
})?;
redact_content_in_place(
&mut content,
&room_version_id,
self.kind.to_string(),
)
.map_err(|e| {
Error::Redaction(self.sender.server_name().to_owned(), e)
})?;
self.unsigned = Some(to_raw_value(&json!({
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
@ -78,10 +93,12 @@ impl PduEvent {
pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> {
if let Some(unsigned) = &self.unsigned {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
serde_json::from_str(unsigned.get())
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
serde_json::from_str(unsigned.get()).map_err(|_| {
Error::bad_database("Invalid unsigned in pdu event")
})?;
unsigned.remove("transaction_id");
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
self.unsigned =
Some(to_raw_value(&unsigned).expect("unsigned is valid"));
}
Ok(())
@ -91,31 +108,45 @@ impl PduEvent {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned
.as_ref()
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
.map_or_else(
|| Ok(BTreeMap::new()),
|u| serde_json::from_str(u.get()),
)
.map_err(|_| {
Error::bad_database("Invalid unsigned in pdu event")
})?;
unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap());
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
self.unsigned =
Some(to_raw_value(&unsigned).expect("unsigned is valid"));
Ok(())
}
/// Copies the `redacts` property of the event to the `content` dict and vice-versa.
/// Copies the `redacts` property of the event to the `content` dict and
/// vice-versa.
///
/// This follows the specification's
/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property):
///
/// > For backwards-compatibility with older clients, servers should add a redacts
/// > property to the top level of m.room.redaction events in when serving such events
/// > For backwards-compatibility with older clients, servers should add a
/// > redacts
/// > property to the top level of m.room.redaction events in when serving
/// > such events
/// > over the Client-Server API.
/// >
/// > For improved compatibility with newer clients, servers should add a redacts property
/// > to the content of m.room.redaction events in older room versions when serving
/// > For improved compatibility with newer clients, servers should add a
/// > redacts property
/// > to the content of m.room.redaction events in older room versions when
/// > serving
/// > such events over the Client-Server API.
pub(crate) fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) {
pub(crate) fn copy_redacts(
&self,
) -> (Option<Arc<EventId>>, Box<RawJsonValue>) {
if self.kind == TimelineEventType::RoomRedaction {
if let Ok(mut content) =
serde_json::from_str::<RoomRedactionEventContent>(self.content.get())
if let Ok(mut content) = serde_json::from_str::<
RoomRedactionEventContent,
>(self.content.get())
{
if let Some(redacts) = content.redacts {
return (Some(redacts.into()), self.content.clone());
@ -123,7 +154,9 @@ impl PduEvent {
content.redacts = Some(redacts.into());
return (
self.redacts.clone(),
to_raw_value(&content).expect("Must be valid, we only added redacts field"),
to_raw_value(&content).expect(
"Must be valid, we only added redacts field",
),
);
}
}
@ -281,7 +314,9 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
pub(crate) fn to_stripped_spacechild_state_event(
&self,
) -> Raw<HierarchySpaceChildEvent> {
let json = json!({
"content": self.content,
"type": self.kind,
@ -294,7 +329,9 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
pub(crate) fn to_member_event(
&self,
) -> Raw<StateEvent<RoomMemberEventContent>> {
let mut json = json!({
"content": self.content,
"type": self.kind,
@ -318,16 +355,16 @@ impl PduEvent {
pub(crate) fn convert_to_outgoing_federation_event(
mut pdu_json: CanonicalJsonObject,
) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
if let Some(unsigned) =
pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut())
{
unsigned.remove("transaction_id");
}
pdu_json.remove("event_id");
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
to_raw_value(&pdu_json)
.expect("CanonicalJson is valid serde_json::Value")
}
pub(crate) fn from_id_val(
@ -374,11 +411,15 @@ impl state_res::Event for PduEvent {
self.state_key.as_deref()
}
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
fn prev_events(
&self,
) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
Box::new(self.prev_events.iter())
}
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
fn auth_events(
&self,
) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
Box::new(self.auth_events.iter())
}
@ -408,15 +449,17 @@ impl Ord for PduEvent {
/// Generates a correct eventId for the incoming pdu.
///
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
/// CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json(
pdu: &RawJsonValue,
room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let value: CanonicalJsonObject =
serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let event_id = format!(
"${}",

View file

@ -1,27 +1,34 @@
mod data;
pub(crate) use data::Data;
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
use std::{fmt::Debug, mem};
use crate::{services, Error, PduEvent, Result};
use bytes::BytesMut;
pub(crate) use data::Data;
use ruma::{
api::{
client::push::{set_pusher, Pusher, PusherKind},
push_gateway::send_event_notification::{
self,
v1::{Device, Notification, NotificationCounts, NotificationPriority},
v1::{
Device, Notification, NotificationCounts, NotificationPriority,
},
},
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
},
events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
events::{
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent,
StateEventType, TimelineEventType,
},
push::{
Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat,
Ruleset, Tweak,
},
serde::Raw,
uint, RoomId, UInt, UserId,
};
use std::{fmt::Debug, mem};
use tracing::{info, warn};
use crate::{services, Error, PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
}
@ -35,7 +42,11 @@ impl Service {
self.db.set_pusher(sender, pusher)
}
pub(crate) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
pub(crate) fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>> {
self.db.get_pusher(sender, pushkey)
}
@ -43,7 +54,10 @@ impl Service {
self.db.get_pushers(sender)
}
pub(crate) fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
pub(crate) fn get_pushkeys(
&self,
sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>>> {
self.db.get_pushkeys(sender)
}
@ -73,11 +87,8 @@ impl Service {
let reqwest_request = reqwest::Request::try_from(http_request)?;
let url = reqwest_request.url().clone();
let response = services()
.globals
.default_client()
.execute(reqwest_request)
.await;
let response =
services().globals.default_client().execute(reqwest_request).await;
match response {
Ok(mut response) => {
@ -119,11 +130,16 @@ impl Service {
"Push gateway returned invalid response bytes {}\n{}",
destination, url
);
Error::BadServerResponse("Push gateway returned bad response.")
Error::BadServerResponse(
"Push gateway returned bad response.",
)
})
}
Err(e) => {
warn!("Could not send request to pusher {}: {}", destination, e);
warn!(
"Could not send request to pusher {}: {}",
destination, e
);
Err(e.into())
}
}
@ -146,8 +162,9 @@ impl Service {
.state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
serde_json::from_str(ev.content.get())
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
serde_json::from_str(ev.content.get()).map_err(|_| {
Error::bad_database("invalid m.room.power_levels event")
})
})
.transpose()?
.unwrap_or_default();
@ -228,11 +245,16 @@ impl Service {
PusherKind::Http(http) => {
// TODO:
// Two problems with this
// 1. if "event_id_only" is the only format kind it seems we should never add more info
// 1. if "event_id_only" is the only format kind it seems we
// should never add more info
// 2. can pusher/devices have conflicting formats
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
let event_id_only =
http.format == Some(PushFormat::EventIdOnly);
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
let mut device = Device::new(
pusher.ids.app_id.clone(),
pusher.ids.pushkey.clone(),
);
device.data.default_payload = http.default_payload.clone();
device.data.format = http.format.clone();
@ -251,32 +273,43 @@ impl Service {
notifi.counts = NotificationCounts::new(unread, uint!(0));
if event.kind == TimelineEventType::RoomEncrypted
|| tweaks
.iter()
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|| tweaks.iter().any(|t| {
matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))
})
{
notifi.prio = NotificationPriority::High;
}
if event_id_only {
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
self.send_request(
&http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
} else {
notifi.sender = Some(event.sender.clone());
notifi.event_type = Some(event.kind.clone());
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
notifi.content =
serde_json::value::to_raw_value(&event.content).ok();
if event.kind == TimelineEventType::RoomMember {
notifi.user_is_target =
event.state_key.as_deref() == Some(event.sender.as_str());
notifi.user_is_target = event.state_key.as_deref()
== Some(event.sender.as_str());
}
notifi.sender_display_name = services().users.displayname(&event.sender)?;
notifi.sender_display_name =
services().users.displayname(&event.sender)?;
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
notifi.room_name = services()
.rooms
.state_accessor
.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
self.send_request(
&http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
}
Ok(())

View file

@ -1,16 +1,27 @@
use crate::Result;
use ruma::{
api::client::push::{set_pusher, Pusher},
UserId,
};
pub(crate) trait Data: Send + Sync {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
use crate::Result;
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
pub(crate) trait Data: Send + Sync {
fn set_pusher(
&self,
sender: &UserId,
pusher: set_pusher::v3::PusherAction,
) -> Result<()>;
fn get_pusher(
&self,
sender: &UserId,
pushkey: &str,
) -> Result<Option<Pusher>>;
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
fn get_pushkeys<'a>(&'a self, sender: &UserId)
-> Box<dyn Iterator<Item = Result<String>> + 'a>;
fn get_pushkeys<'a>(
&'a self,
sender: &UserId,
) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
}

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
/// Creates or updates the alias to the given room id.
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
@ -9,7 +10,10 @@ pub(crate) trait Data: Send + Sync {
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
/// Looks up the roomid for the given alias.
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>;
fn resolve_local_alias(
&self,
alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>>;
/// Returns all local aliases that point to the given room
fn local_aliases_for_room<'a>(

View file

@ -43,7 +43,8 @@ impl Service {
let mut i = 0;
for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?;
let short =
services().rooms.short.get_or_create_shorteventid(&id)?;
// I'm afraid to change this in case there is accidental reliance on
// the truncation
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
@ -64,7 +65,8 @@ impl Service {
continue;
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
let chunk_key: Vec<u64> =
chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services()
.rooms
.auth_chain
@ -90,11 +92,13 @@ impl Service {
chunk_cache.extend(cached.iter().copied());
} else {
misses2 += 1;
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
services()
.rooms
.auth_chain
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
let auth_chain = Arc::new(
self.get_auth_chain_inner(room_id, &event_id)?,
);
services().rooms.auth_chain.cache_auth_chain(
vec![sevent_id],
Arc::clone(&auth_chain),
)?;
debug!(
event_id = ?event_id,
chain_length = ?auth_chain.len(),
@ -129,13 +133,17 @@ impl Service {
"Auth chain stats",
);
Ok(full_auth_chain
.into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
Ok(full_auth_chain.into_iter().filter_map(move |sid| {
services().rooms.short.get_eventid_from_short(sid).ok()
}))
}
#[tracing::instrument(skip(self, event_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
@ -143,7 +151,10 @@ impl Service {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Evil event in db",
));
}
for auth_event in &pdu.auth_events {
let sauthevent = services()
@ -158,10 +169,17 @@ impl Service {
}
}
Ok(None) => {
warn!(?event_id, "Could not find pdu mentioned in auth events");
warn!(
?event_id,
"Could not find pdu mentioned in auth events"
);
}
Err(error) => {
error!(?event_id, ?error, "Could not load event in auth chain");
error!(
?event_id,
?error,
"Could not load event in auth chain"
);
}
}
}

View file

@ -1,11 +1,15 @@
use crate::Result;
use std::{collections::HashSet, sync::Arc};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn get_cached_eventid_authchain(
&self,
shorteventid: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>)
-> Result<()>;
fn cache_auth_chain(
&self,
shorteventid: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()>;
}

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
/// Adds the room to the public room directory
fn set_public(&self, room_id: &RoomId) -> Result<()>;
@ -12,5 +13,7 @@ pub(crate) trait Data: Send + Sync {
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
/// Returns the unsorted public room directory
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
fn public_rooms<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
}

View file

@ -1,5 +1,8 @@
use ruma::{
events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId,
};
use crate::Result;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
pub(crate) trait Data: Send + Sync {
/// Replaces the previous read receipt.
@ -10,7 +13,8 @@ pub(crate) trait Data: Send + Sync {
event: ReceiptEvent,
) -> Result<()>;
/// Returns an iterator over the most recent read receipts in a room that happened after the event with id `since`.
/// Returns an iterator over the most recent read receipts in a room that
/// happened after the event with id `since`.
#[allow(clippy::type_complexity)]
fn readreceipts_since<'a>(
&'a self,
@ -27,11 +31,24 @@ pub(crate) trait Data: Send + Sync {
>;
/// Sets a private read marker at `count`.
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;
fn private_read_set(
&self,
room_id: &RoomId,
user_id: &UserId,
count: u64,
) -> Result<()>;
/// Returns the private read marker.
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>;
fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<u64>>;
/// Returns the count of the last typing update in this room.
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<u64>;
}

View file

@ -1,8 +1,9 @@
use std::collections::BTreeMap;
use ruma::{
events::{typing::TypingEventContent, SyncEphemeralRoomEvent},
OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use std::collections::BTreeMap;
use tokio::sync::{broadcast, RwLock};
use tracing::trace;
@ -10,15 +11,16 @@ use crate::{services, utils, Result};
pub(crate) struct Service {
// u64 is unix timestamp of timeout
pub(crate) typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>,
pub(crate) typing:
RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>,
// timestamp of the last change to typing users
pub(crate) last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>,
pub(crate) typing_update_sender: broadcast::Sender<OwnedRoomId>,
}
impl Service {
/// Sets a user as typing until the timeout timestamp is reached or `roomtyping_remove` is
/// called.
/// Sets a user as typing until the timeout timestamp is reached or
/// `roomtyping_remove` is called.
pub(crate) async fn typing_add(
&self,
user_id: &UserId,
@ -36,13 +38,20 @@ impl Service {
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
trace!(
"receiver found what it was looking for and is no longer \
interested"
);
}
Ok(())
}
/// Removes a user from typing before the timeout is reached.
pub(crate) async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(crate) async fn typing_remove(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
self.typing
.write()
.await
@ -54,7 +63,10 @@ impl Service {
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
trace!(
"receiver found what it was looking for and is no longer \
interested"
);
}
Ok(())
}
@ -97,14 +109,20 @@ impl Service {
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
trace!(
"receiver found what it was looking for and is no longer \
interested"
);
}
}
Ok(())
}
/// Returns the count of the last typing update in this room.
pub(crate) async fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
pub(crate) async fn last_typing_update(
&self,
room_id: &RoomId,
) -> Result<u64> {
self.typings_maintain(room_id).await?;
Ok(self
.last_typing_update

File diff suppressed because it is too large Load diff

View file

@ -5,16 +5,19 @@ pub(crate) use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use crate::Result;
use super::timeline::PduCount;
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
#[allow(clippy::type_complexity)]
pub(crate) lazy_load_waiting:
Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
pub(crate) lazy_load_waiting: Mutex<
HashMap<
(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount),
HashSet<OwnedUserId>,
>,
>,
}
impl Service {
@ -26,8 +29,7 @@ impl Service {
room_id: &RoomId,
ll_user: &UserId,
) -> Result<bool> {
self.db
.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
}
#[tracing::instrument(skip(self))]

View file

@ -1,6 +1,7 @@
use crate::Result;
use ruma::{DeviceId, RoomId, UserId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn lazy_load_was_sent_before(
&self,

View file

@ -1,10 +1,13 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
/// Checks if a room exists.
fn exists(&self, room_id: &RoomId) -> Result<bool>;
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
fn iter_ids<'a>(
&'a self,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
fn is_disabled(&self, room_id: &RoomId) -> Result<bool>;
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>;
}

View file

@ -4,8 +4,15 @@ use crate::{PduEvent, Result};
pub(crate) trait Data: Send + Sync {
/// Returns the pdu from the outlier tree.
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
/// Append the PDU as an outlier.
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>;
fn add_pdu_outlier(
&self,
event_id: &EventId,
pdu: &CanonicalJsonObject,
) -> Result<()>;
}

View file

@ -9,9 +9,8 @@ use ruma::{
};
use serde::Deserialize;
use crate::{services, PduEvent, Result};
use super::timeline::PduCount;
use crate::{services, PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
@ -29,9 +28,15 @@ struct ExtractRelatesToEventId {
impl Service {
#[tracing::instrument(skip(self, from, to))]
pub(crate) fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
pub(crate) fn add_relation(
&self,
from: PduCount,
to: PduCount,
) -> Result<()> {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
(PduCount::Normal(f), PduCount::Normal(t)) => {
self.db.add_relation(f, t)
}
_ => {
// TODO: Relations with backfilled pdus
@ -42,6 +47,7 @@ impl Service {
#[allow(
clippy::too_many_arguments,
clippy::too_many_lines,
// Allowed because this function uses `services()`
clippy::unused_self,
)]
@ -68,15 +74,17 @@ impl Service {
.relations_until(sender_user, room_id, target, from)?
.filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t)
&& if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(
pdu.content.get(),
)
{
filter_rel_type
.as_ref()
.map_or(true, |r| &&content.relates_to.rel_type == r)
filter_event_type
.as_ref()
.map_or(true, |t| &&pdu.kind == t)
&& if let Ok(content) = serde_json::from_str::<
ExtractRelatesToEventId,
>(
pdu.content.get()
) {
filter_rel_type.as_ref().map_or(true, |r| {
&&content.relates_to.rel_type == r
})
} else {
false
}
@ -88,13 +96,18 @@ impl Service {
services()
.rooms
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.user_can_see_event(
sender_user,
room_id,
&pdu.event_id,
)
.unwrap_or(false)
})
.take_while(|&(k, _)| Some(k) != to)
.collect();
next_token = events_after.last().map(|(count, _)| count).copied();
next_token =
events_after.last().map(|(count, _)| count).copied();
// Reversed because relations are always most recent first
let events_after: Vec<_> = events_after
@ -116,15 +129,17 @@ impl Service {
.relations_until(sender_user, room_id, target, from)?
.filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &&pdu.kind == t)
&& if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(
pdu.content.get(),
)
{
filter_rel_type
.as_ref()
.map_or(true, |r| &&content.relates_to.rel_type == r)
filter_event_type
.as_ref()
.map_or(true, |t| &&pdu.kind == t)
&& if let Ok(content) = serde_json::from_str::<
ExtractRelatesToEventId,
>(
pdu.content.get()
) {
filter_rel_type.as_ref().map_or(true, |r| {
&&content.relates_to.rel_type == r
})
} else {
false
}
@ -136,13 +151,18 @@ impl Service {
services()
.rooms
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.user_can_see_event(
sender_user,
room_id,
&pdu.event_id,
)
.unwrap_or(false)
})
.take_while(|&(k, _)| Some(k) != to)
.collect();
next_token = events_before.last().map(|(count, _)| count).copied();
next_token =
events_before.last().map(|(count, _)| count).copied();
let events_before: Vec<_> = events_before
.into_iter()
@ -165,7 +185,8 @@ impl Service {
target: &'a EventId,
until: PduCount,
) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
let room_id =
services().rooms.short.get_or_create_shortroomid(room_id)?;
let target = match services().rooms.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
@ -185,17 +206,27 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub(crate) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
pub(crate) fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
self.db.is_event_referenced(room_id, event_id)
}
#[tracing::instrument(skip(self))]
pub(crate) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
pub(crate) fn mark_event_soft_failed(
&self,
event_id: &EventId,
) -> Result<()> {
self.db.mark_event_soft_failed(event_id)
}
#[tracing::instrument(skip(self))]
pub(crate) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
pub(crate) fn is_event_soft_failed(
&self,
event_id: &EventId,
) -> Result<bool> {
self.db.is_event_soft_failed(event_id)
}
}

View file

@ -1,8 +1,9 @@
use std::sync::Arc;
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
use ruma::{EventId, RoomId, UserId};
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
pub(crate) trait Data: Send + Sync {
fn add_relation(&self, from: u64, to: u64) -> Result<()>;
#[allow(clippy::type_complexity)]
@ -13,8 +14,16 @@ pub(crate) trait Data: Send + Sync {
target: u64,
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>;
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>;
fn mark_as_referenced(
&self,
room_id: &RoomId,
event_ids: &[Arc<EventId>],
) -> Result<()>;
fn is_event_referenced(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool>;
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>;
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>;
}

View file

@ -1,8 +1,14 @@
use crate::Result;
use ruma::RoomId;
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
fn index_pdu(
&self,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()>;
#[allow(clippy::type_complexity)]
fn search_pdus<'a>(

View file

@ -1,8 +1,9 @@
use std::sync::Arc;
use crate::Result;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
@ -18,12 +19,19 @@ pub(crate) trait Data: Send + Sync {
state_key: &str,
) -> Result<u64>;
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>;
fn get_eventid_from_short(&self, shorteventid: u64)
-> Result<Arc<EventId>>;
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>;
fn get_statekey_from_short(
&self,
shortstatekey: u64,
) -> Result<(StateEventType, String)>;
/// Returns `(shortstatehash, already_existed)`
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>;
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;

Some files were not shown because too many files have changed in this diff Show more