From 66210bc32d100812d958e5981ff978a53046e290 Mon Sep 17 00:00:00 2001 From: Olivia Lee Date: Sun, 23 Mar 2025 02:08:38 -0700 Subject: [PATCH] separate account_data service methods for room vs global events Previously we were mashing everything together as RoomAccountDataEvent, even the global events. This technically worked, because of the hidden custom fields on the ruma event types, but it's confusing and easy to mess up. Separate methods with appropriate types are preferable. --- src/api/client_server/account.rs | 17 ++-- src/api/client_server/config.rs | 78 +++++++--------- src/api/client_server/push.rs | 123 ++++++++++--------------- src/api/client_server/read_marker.rs | 25 ++--- src/api/client_server/sync/msc3575.rs | 13 +-- src/api/client_server/sync/v3.rs | 26 +----- src/api/client_server/tag.rs | 47 +++++----- src/database.rs | 56 +++++------ src/database/key_value/account_data.rs | 70 +++++++------- src/service/account_data.rs | 105 +++++++++++++++++---- src/service/account_data/data.rs | 35 ++++--- src/service/admin.rs | 15 ++- src/service/globals.rs | 16 ++-- src/service/rooms/state_cache.rs | 58 +++++------- src/service/rooms/timeline.rs | 16 +--- src/service/sending.rs | 10 +- 16 files changed, 349 insertions(+), 361 deletions(-) diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index d660c6ba..035b1788 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -12,9 +12,12 @@ use ruma::{ uiaa::{AuthFlow, AuthType, UiaaInfo}, }, events::{ - room::message::RoomMessageEventContent, GlobalAccountDataEventType, + room::message::RoomMessageEventContent, AnyGlobalAccountDataEvent, + GlobalAccountDataEventType, }, - push, UserId, + push, + serde::Raw, + UserId, }; use tracing::{info, warn}; @@ -234,16 +237,16 @@ pub(crate) async fn register_route( services().users.set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data - services().account_data.update( - None, + services().account_data.update_global( &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + &GlobalAccountDataEventType::PushRules, + &Raw::new(&ruma::events::push_rules::PushRulesEvent { content: ruma::events::push_rules::PushRulesEventContent { global: push::Ruleset::server_default(&user_id), }, }) - .expect("to json always works"), + .expect("constructed event should be valid") + .cast::(), )?; // Inhibit login does not work for guests diff --git a/src/api/client_server/config.rs b/src/api/client_server/config.rs index 8d9a0a84..76bf3395 100644 --- a/src/api/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -7,12 +7,13 @@ use ruma::{ error::ErrorKind, }, events::{ - AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent, + AnyGlobalAccountDataEvent, AnyGlobalAccountDataEventContent, + AnyRoomAccountDataEvent, AnyRoomAccountDataEventContent, }, serde::Raw, }; use serde::Deserialize; -use serde_json::{json, value::RawValue as RawJsonValue}; +use serde_json::json; use crate::{services, Ar, Error, Ra, Result}; @@ -24,21 +25,17 @@ pub(crate) async fn set_global_account_data_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Data is invalid.") - })?; + let event = Raw::new(&json!({ + "type": &body.event_type, + "content": &body.data, + })) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))? + .cast::(); - let event_type = body.event_type.to_string(); - - services().account_data.update( - None, + services().account_data.update_global( sender_user, - event_type.clone().into(), - &json!({ - "type": event_type, - "content": data, - }), + &body.event_type, + &event, )?; Ok(Ra(set_global_account_data::v3::Response {})) @@ -52,21 +49,18 @@ pub(crate) async fn set_room_account_data_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Data is invalid.") - })?; + let event = Raw::new(&json!({ + "type": &body.event_type, + "content": &body.data, + })) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))? + .cast::(); - let event_type = body.event_type.to_string(); - - services().account_data.update( - Some(&body.room_id), + services().account_data.update_room( + &body.room_id, sender_user, - event_type.clone().into(), - &json!({ - "type": event_type, - "content": data, - }), + &body.event_type, + &event, )?; Ok(Ra(set_room_account_data::v3::Response {})) @@ -80,17 +74,15 @@ pub(crate) async fn get_global_account_data_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services() + let event = services() .account_data - .get(None, sender_user, body.event_type.to_string().into())? + .get_global(sender_user, &body.event_type)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = - serde_json::from_str::(event.get()) - .map_err(|_| { - Error::bad_database("Invalid account data event in db.") - })? - .content; + let account_data = event + .deserialize_as::() + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; Ok(Ra(get_global_account_data::v3::Response { account_data, @@ -105,17 +97,15 @@ pub(crate) async fn get_room_account_data_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services() + let event = services() .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone())? + .get_room(&body.room_id, sender_user, &body.event_type)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = - serde_json::from_str::(event.get()) - .map_err(|_| { - Error::bad_database("Invalid account data event in db.") - })? - .content; + let account_data = event + .deserialize_as::() + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; Ok(Ra(get_room_account_data::v3::Response { account_data, diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index 9d2b5956..b0e12cec 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -7,8 +7,12 @@ use ruma::{ set_pushrule_actions, set_pushrule_enabled, }, }, - events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, + events::{ + push_rules::PushRulesEvent, AnyGlobalAccountDataEvent, + GlobalAccountDataEventType, + }, push::{AnyPushRuleRef, InsertPushRuleError, RemovePushRuleError}, + serde::Raw, }; use crate::{services, Ar, Error, Ra, Result}; @@ -23,17 +27,14 @@ pub(crate) async fn get_pushrules_all_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let account_data = serde_json::from_str::(event.get()) + let account_data = event + .deserialize_as::() .map_err(|_| Error::bad_database("Invalid account data event in db."))? .content; @@ -52,17 +53,14 @@ pub(crate) async fn get_pushrule_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let account_data = serde_json::from_str::(event.get()) + let account_data = event + .deserialize_as::() .map_err(|_| Error::bad_database("Invalid account data event in db."))? .content; @@ -91,18 +89,14 @@ pub(crate) async fn set_pushrule_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| { + let mut account_data = + event.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") })?; @@ -142,12 +136,12 @@ pub(crate) async fn set_pushrule_route( return Err(err); } - services().account_data.update( - None, + services().account_data.update_global( sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json event serialization should always succeed") + .cast::(), )?; Ok(Ra(set_pushrule::v3::Response {})) @@ -163,17 +157,14 @@ pub(crate) async fn get_pushrule_actions_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let account_data = serde_json::from_str::(event.get()) + let account_data = event + .deserialize_as::() .map_err(|_| Error::bad_database("Invalid account data event in db."))? .content; @@ -201,18 +192,14 @@ pub(crate) async fn set_pushrule_actions_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| { + let mut account_data = + event.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") })?; @@ -228,12 +215,12 @@ pub(crate) async fn set_pushrule_actions_route( )); } - services().account_data.update( - None, + services().account_data.update_global( sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json event serialization should always suceed") + .cast::(), )?; Ok(Ra(set_pushrule_actions::v3::Response {})) @@ -249,18 +236,14 @@ pub(crate) async fn get_pushrule_enabled_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| { + let account_data = + event.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") })?; @@ -288,18 +271,14 @@ pub(crate) async fn set_pushrule_enabled_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| { + let mut account_data = + event.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") })?; @@ -315,12 +294,12 @@ pub(crate) async fn set_pushrule_enabled_route( )); } - services().account_data.update( - None, + services().account_data.update_global( sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json event serialization should always succeed") + .cast::(), )?; Ok(Ra(set_pushrule_enabled::v3::Response {})) @@ -336,18 +315,14 @@ pub(crate) async fn delete_pushrule_route( let event = services() .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(sender_user, &GlobalAccountDataEventType::PushRules)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "PushRules event not found.", ))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| { + let mut account_data = + event.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") })?; @@ -368,12 +343,12 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - services().account_data.update( - None, + services().account_data.update_global( sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json event serialization should always suceed") + .cast::(), )?; Ok(Ra(delete_pushrule::v3::Response {})) diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index 2daabf73..0f6655a4 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -6,8 +6,9 @@ use ruma::{ }, events::{ receipt::{ReceiptThread, ReceiptType}, - RoomAccountDataEventType, + AnyRoomAccountDataEvent, RoomAccountDataEventType, }, + serde::Raw, MilliSecondsSinceUnixEpoch, }; @@ -33,12 +34,13 @@ pub(crate) async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services().account_data.update( - Some(&body.room_id), + services().account_data.update_room( + &body.room_id, sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event) - .expect("to json value always works"), + &RoomAccountDataEventType::FullyRead, + &Raw::new(&fully_read_event) + .expect("json event serialization should always suceed") + .cast::(), )?; } @@ -129,12 +131,13 @@ pub(crate) async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services().account_data.update( - Some(&body.room_id), + services().account_data.update_room( + &body.room_id, sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event) - .expect("to json value always works"), + &RoomAccountDataEventType::FullyRead, + &Raw::new(&fully_read_event) + .expect("json event serialization should always succeed") + .cast::(), )?; } create_receipt::v3::ReceiptType::Read => { diff --git a/src/api/client_server/sync/msc3575.rs b/src/api/client_server/sync/msc3575.rs index 1de1c7d1..20ee6379 100644 --- a/src/api/client_server/sync/msc3575.rs +++ b/src/api/client_server/sync/msc3575.rs @@ -644,17 +644,8 @@ pub(crate) async fn sync_events_v4_route( { services() .account_data - .changes_since(None, &sender_user, globalsince)? - .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| { - Error::bad_database( - "Invalid account event in database.", - ) - }) - .ok() - }) + .global_changes_since(&sender_user, globalsince)? + .into_values() .collect() } else { Vec::new() diff --git a/src/api/client_server/sync/v3.rs b/src/api/client_server/sync/v3.rs index 65830053..21bb0004 100644 --- a/src/api/client_server/sync/v3.rs +++ b/src/api/client_server/sync/v3.rs @@ -234,17 +234,8 @@ pub(crate) async fn sync_events_route( account_data: GlobalAccountData { events: services() .account_data - .changes_since(None, ctx.sender_user, ctx.since)? - .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| { - Error::bad_database( - "Invalid account event in database.", - ) - }) - .ok() - }) + .global_changes_since(ctx.sender_user, ctx.since)? + .into_values() .collect(), }, device_lists: DeviceLists { @@ -880,17 +871,8 @@ async fn load_joined_room( account_data: RoomAccountData { events: services() .account_data - .changes_since(Some(room_id), ctx.sender_user, ctx.since)? - .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| { - Error::bad_database( - "Invalid account event in database.", - ) - }) - .ok() - }) + .room_changes_since(ctx.sender_user, room_id, ctx.since)? + .into_values() .collect(), }, summary: RoomSummary { diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 17367300..12a3e37d 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -4,8 +4,9 @@ use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ tag::{TagEvent, TagEventContent}, - RoomAccountDataEventType, + AnyRoomAccountDataEvent, RoomAccountDataEventType, }, + serde::Raw, }; use crate::{services, Ar, Error, Ra, Result}; @@ -20,10 +21,10 @@ pub(crate) async fn update_tag_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), + let event = services().account_data.get_room( + &body.room_id, sender_user, - RoomAccountDataEventType::Tag, + &RoomAccountDataEventType::Tag, )?; let mut tags_event = event.map_or_else( @@ -35,7 +36,7 @@ pub(crate) async fn update_tag_route( }) }, |e| { - serde_json::from_str(e.get()).map_err(|_| { + e.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") }) }, @@ -46,11 +47,13 @@ pub(crate) async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services().account_data.update( - Some(&body.room_id), + services().account_data.update_room( + &body.room_id, sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), + &RoomAccountDataEventType::Tag, + &Raw::new(&tags_event) + .expect("json event serialization should always suceed") + .cast::(), )?; Ok(Ra(create_tag::v3::Response {})) @@ -66,10 +69,10 @@ pub(crate) async fn delete_tag_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), + let event = services().account_data.get_room( + &body.room_id, sender_user, - RoomAccountDataEventType::Tag, + &RoomAccountDataEventType::Tag, )?; let mut tags_event = event.map_or_else( @@ -81,7 +84,7 @@ pub(crate) async fn delete_tag_route( }) }, |e| { - serde_json::from_str(e.get()).map_err(|_| { + e.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") }) }, @@ -89,11 +92,13 @@ pub(crate) async fn delete_tag_route( tags_event.content.tags.remove(&body.tag.clone().into()); - services().account_data.update( - Some(&body.room_id), + services().account_data.update_room( + &body.room_id, sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), + &RoomAccountDataEventType::Tag, + &Raw::new(&tags_event) + .expect("json value serialization should always succeed") + .cast::(), )?; Ok(Ra(delete_tag::v3::Response {})) @@ -109,10 +114,10 @@ pub(crate) async fn get_tags_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), + let event = services().account_data.get_room( + &body.room_id, sender_user, - RoomAccountDataEventType::Tag, + &RoomAccountDataEventType::Tag, )?; let tags_event = event.map_or_else( @@ -124,7 +129,7 @@ pub(crate) async fn get_tags_route( }) }, |e| { - serde_json::from_str(e.get()).map_err(|_| { + e.deserialize_as::().map_err(|_| { Error::bad_database("Invalid account data event in db.") }) }, diff --git a/src/database.rs b/src/database.rs index f7cb1352..b374f361 100644 --- a/src/database.rs +++ b/src/database.rs @@ -7,8 +7,12 @@ use std::{ }; use ruma::{ - events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, + events::{ + push_rules::PushRulesEvent, AnyGlobalAccountDataEvent, + GlobalAccountDataEventType, + }, push::Ruleset, + serde::Raw, EventId, OwnedRoomId, RoomId, UserId, }; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -859,20 +863,15 @@ impl KeyValueDatabase { let raw_rules_list = services() .account_data - .get( - None, + .get_global( &user, - GlobalAccountDataEventType::PushRules - .to_string() - .into(), + &GlobalAccountDataEventType::PushRules, ) .unwrap() .expect("Username is invalid"); - let mut account_data = - serde_json::from_str::( - raw_rules_list.get(), - ) + let mut account_data = raw_rules_list + .deserialize_as::() .unwrap(); let rules_list = &mut account_data.content.global; @@ -927,14 +926,12 @@ impl KeyValueDatabase { } } - services().account_data.update( - None, + services().account_data.update_global( &user, - GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json serialization should always succeed") + .cast::(), )?; } Ok(()) @@ -961,20 +958,15 @@ impl KeyValueDatabase { let raw_rules_list = services() .account_data - .get( - None, + .get_global( &user, - GlobalAccountDataEventType::PushRules - .to_string() - .into(), + &GlobalAccountDataEventType::PushRules, ) .unwrap() .expect("Username is invalid"); - let mut account_data = - serde_json::from_str::( - raw_rules_list.get(), - ) + let mut account_data = raw_rules_list + .deserialize_as::() .unwrap(); let user_default_rules = Ruleset::server_default(&user); @@ -983,14 +975,12 @@ impl KeyValueDatabase { .global .update_with_server_default(user_default_rules); - services().account_data.update( - None, + services().account_data.update_global( &user, - GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(account_data) - .expect("to json value always works"), + &GlobalAccountDataEventType::PushRules, + &Raw::new(&account_data) + .expect("json serialization should always succeed") + .cast::(), )?; } Ok(()) diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index dfabb322..49231a2b 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,11 +1,8 @@ use std::collections::HashMap; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; +use ruma::{api::client::error::ErrorKind, RoomId, UserId}; +use serde::Deserialize; +use serde_json::value::RawValue; use crate::{ database::KeyValueDatabase, service, services, utils, Error, Result, @@ -16,9 +13,19 @@ impl service::account_data::Data for KeyValueDatabase { &self, room_id: Option<&RoomId>, user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, + event_type: &str, + data: &RawValue, ) -> Result<()> { + // Allowed because we just use this type to validate the schema, and + // don't read the fields. + #[allow(dead_code)] + #[derive(Deserialize)] + struct ExtractEventFields<'a> { + #[serde(rename = "type")] + event_type: &'a str, + content: &'a RawValue, + } + let mut prefix = room_id .map(ToString::to_string) .unwrap_or_default() @@ -32,23 +39,20 @@ impl service::account_data::Data for KeyValueDatabase { roomuserdataid .extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); + roomuserdataid.extend_from_slice(event_type.as_bytes()); let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); + key.extend_from_slice(event_type.as_bytes()); - if data.get("type").is_none() || data.get("content").is_none() { + if serde_json::from_str::>(data.get()).is_err() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Account data doesn't have all required fields.", )); } - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data) - .expect("to_vec always works on json values"), - )?; + self.roomuserdataid_accountdata + .insert(&roomuserdataid, data.get().as_bytes())?; let prev = self.roomusertype_roomuserdataid.get(&key)?; @@ -66,8 +70,8 @@ impl service::account_data::Data for KeyValueDatabase { &self, room_id: Option<&RoomId>, user_id: &UserId, - kind: RoomAccountDataEventType, - ) -> Result>> { + kind: &str, + ) -> Result>> { let mut key = room_id .map(ToString::to_string) .unwrap_or_default() @@ -76,7 +80,7 @@ impl service::account_data::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); + key.extend_from_slice(kind.as_bytes()); self.roomusertype_roomuserdataid .get(&key)? @@ -96,8 +100,7 @@ impl service::account_data::Data for KeyValueDatabase { room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>> - { + ) -> Result>> { let mut userdata = HashMap::new(); let mut prefix = room_id @@ -119,28 +122,23 @@ impl service::account_data::Data for KeyValueDatabase { .take_while(move |(k, _)| k.starts_with(&prefix)) .map(|(k, v)| { Ok::<_, Error>(( - RoomAccountDataEventType::from( - utils::string_from_bytes( - k.rsplit(|&b| b == 0xFF).next().ok_or_else( - || { - Error::bad_database( - "RoomUserData ID in db is invalid.", - ) - }, - )?, - ) - .map_err(|_| { + utils::string_from_bytes( + k.rsplit(|&b| b == 0xFF).next().ok_or_else(|| { Error::bad_database( "RoomUserData ID in db is invalid.", ) })?, - ), - serde_json::from_slice::>(&v) - .map_err(|_| { + ) + .map_err(|_| { + Error::bad_database("RoomUserData ID in db is invalid.") + })?, + serde_json::from_slice::>(&v).map_err( + |_| { Error::bad_database( "Database contains invalid account data.", ) - })?, + }, + )?, )) }) { diff --git a/src/service/account_data.rs b/src/service/account_data.rs index 7b51b421..014afcfd 100644 --- a/src/service/account_data.rs +++ b/src/service/account_data.rs @@ -1,7 +1,10 @@ use std::collections::HashMap; use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + events::{ + AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, + GlobalAccountDataEventType, RoomAccountDataEventType, + }, serde::Raw, RoomId, UserId, }; @@ -23,42 +26,104 @@ impl Service { } } - /// Places one event in the account data of the user and removes the + /// Places one event in the global account data of the user and removes the /// previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub(crate) fn update( + #[tracing::instrument(skip(self, user_id, event))] + pub(crate) fn update_global( &self, - room_id: Option<&RoomId>, user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, + event_type: &GlobalAccountDataEventType, + event: &Raw, ) -> Result<()> { - self.db.update(room_id, user_id, event_type, data) + self.db.update(None, user_id, &event_type.to_string(), event.json()) } - /// Searches the account data for a specific kind. - #[tracing::instrument(skip(self, room_id, user_id, event_type))] - pub(crate) fn get( + /// Places one event in the room account data of the user and removes the + /// previous entry for that room. + #[tracing::instrument(skip(self, room_id, user_id, event))] + pub(crate) fn update_room( &self, - room_id: Option<&RoomId>, + room_id: &RoomId, user_id: &UserId, - event_type: RoomAccountDataEventType, - ) -> Result>> { - self.db.get(room_id, user_id, event_type) + event_type: &RoomAccountDataEventType, + event: &Raw, + ) -> Result<()> { + self.db.update( + Some(room_id), + user_id, + &event_type.to_string(), + event.json(), + ) } - /// Returns all changes to the account data that happened after `since`. + /// Searches the global account data for a specific kind. + #[tracing::instrument(skip(self, user_id, event_type))] + pub(crate) fn get_global( + &self, + user_id: &UserId, + event_type: &GlobalAccountDataEventType, + ) -> Result>> { + Ok(self + .db + .get(None, user_id, &event_type.to_string())? + .map(Raw::from_json)) + } + + /// Searches the room account data for a specific kind. + #[tracing::instrument(skip(self, room_id, user_id, event_type))] + pub(crate) fn get_room( + &self, + room_id: &RoomId, + user_id: &UserId, + event_type: &RoomAccountDataEventType, + ) -> Result>> { + Ok(self + .db + .get(Some(room_id), user_id, &event_type.to_string())? + .map(Raw::from_json)) + } + + /// Returns all changes to global account data that happened after `since`. + /// + /// When there have been multiple changes to the same event type, returned + /// map contains the most recent value. + #[tracing::instrument(skip(self, user_id, since))] + pub(crate) fn global_changes_since( + &self, + user_id: &UserId, + since: u64, + ) -> Result< + HashMap>, + > { + Ok(self + .db + .changes_since(None, user_id, since)? + .into_iter() + .map(|(event_type, event)| { + (event_type.into(), Raw::from_json(event)) + }) + .collect()) + } + + /// Returns all changes to room account data that happened after `since`. /// /// When there have been multiple changes to the same event type, returned /// map contains the most recent value. #[tracing::instrument(skip(self, room_id, user_id, since))] - pub(crate) fn changes_since( + pub(crate) fn room_changes_since( &self, - room_id: Option<&RoomId>, user_id: &UserId, + room_id: &RoomId, since: u64, - ) -> Result>> + ) -> Result>> { - self.db.changes_since(room_id, user_id, since) + Ok(self + .db + .changes_since(Some(room_id), user_id, since)? + .into_iter() + .map(|(event_type, event)| { + (event_type.into(), Raw::from_json(event)) + }) + .collect()) } } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 868881c8..9db322ad 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,40 +1,51 @@ use std::collections::HashMap; -use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; +use ruma::{RoomId, UserId}; +use serde_json::value::RawValue; use crate::Result; +/// Unlike the service-level API, the database API for account data does not +/// distinguish between global and room events. Because there are no ruma types +/// that cover both, we use strings for the event types and raw json values for +/// the contents. pub(crate) trait Data: Send + Sync { /// Places one event in the account data of the user and removes the /// previous entry. + /// + /// If `room_id` is `None`, set a global event, otherwise set a room event + /// in the specified room. fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, + event_type: &str, + data: &RawValue, ) -> Result<()>; /// Searches the account data for a specific kind. + /// + /// If `room_id` is `None`, search global events, otherwise search room + /// events in the specified room. fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, - kind: RoomAccountDataEventType, - ) -> Result>>; + kind: &str, + ) -> Result>>; /// Returns all changes to the account data that happened after `since`. /// - /// When there have been multiple changes to the same event type, returned - /// map contains the most recent value. + /// If `room_id` is `None`, read global events, otherwise read room events + /// in the specified room. + /// + /// Returned as a map from event type to event objects (containing both a + /// `type` and a `content` key). When there have been multiple changes to + /// the same event type, returned map contains the most recent value. fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>>; + ) -> Result>>; } diff --git a/src/service/admin.rs b/src/service/admin.rs index 5e3ee770..f5fb39aa 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -20,8 +20,9 @@ use ruma::{ power_levels::RoomPowerLevelsEventContent, topic::RoomTopicEventContent, }, - TimelineEventType, + AnyGlobalAccountDataEvent, TimelineEventType, }, + serde::Raw, signatures::verify_json, EventId, MilliSecondsSinceUnixEpoch, OwnedMxcUri, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, @@ -770,20 +771,18 @@ impl Service { .set_displayname(&user_id, Some(displayname))?; // Initial account data - services().account_data.update( - None, + services().account_data.update_global( &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(PushRulesEvent { + &ruma::events::GlobalAccountDataEventType::PushRules, + &Raw::new(&PushRulesEvent { content: PushRulesEventContent { global: ruma::push::Ruleset::server_default( &user_id, ), }, }) - .expect("to json value always works"), + .expect("json serialization should always succeed") + .cast::(), )?; // we dont add a device since we're not the user, just the diff --git a/src/service/globals.rs b/src/service/globals.rs index 5f8e4c27..c1922fef 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -18,11 +18,11 @@ use ruma::{ api::federation::discovery::ServerSigningKeys, events::{ push_rules::PushRulesEventContent, - room::message::RoomMessageEventContent, GlobalAccountDataEvent, - GlobalAccountDataEventType, + room::message::RoomMessageEventContent, AnyGlobalAccountDataEvent, + GlobalAccountDataEvent, GlobalAccountDataEventType, }, push::Ruleset, - serde::Base64, + serde::{Base64, Raw}, DeviceId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UInt, UserId, @@ -492,16 +492,16 @@ impl Service { None => (Ruleset::new(), Ok(false)), }; - services().account_data.update( - None, + services().account_data.update_global( admin_bot, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { + &GlobalAccountDataEventType::PushRules, + &Raw::new(&GlobalAccountDataEvent { content: PushRulesEventContent { global: ruleset, }, }) - .expect("to json value always works"), + .expect("json serialization should always succeed") + .cast::(), )?; res diff --git a/src/service/rooms/state_cache.rs b/src/service/rooms/state_cache.rs index ea6721d9..6fe100cb 100644 --- a/src/service/rooms/state_cache.rs +++ b/src/service/rooms/state_cache.rs @@ -6,8 +6,8 @@ use std::{ use ruma::{ events::{ ignored_user_list::IgnoredUserListEvent, room::member::MembershipState, - AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, - RoomAccountDataEventType, + AnyGlobalAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, + GlobalAccountDataEventType, RoomAccountDataEventType, }, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -92,29 +92,21 @@ impl Service { self.db.mark_as_joined(user_id, room_id)?; } MembershipState::Invite => { - let event_kind = RoomAccountDataEventType::from( - GlobalAccountDataEventType::IgnoredUserList.to_string(), - ); - // We want to know if the sender is ignored by the receiver let is_ignored = services() .account_data - .get( - // Ignored users are in global account data - None, + // Ignored users are in global account data + .get_global( // Receiver user_id, - event_kind.clone(), + &GlobalAccountDataEventType::IgnoredUserList, )? .map(|event| { - serde_json::from_str::( - event.get(), - ) + event.deserialize_as::() .map_err(|error| { warn!( %error, - %event_kind, - "Invalid account data event", + "Invalid m.ignored_user_list account data event", ); Error::BadDatabase("Invalid account data event.") }) @@ -200,20 +192,18 @@ impl Service { from_room_id: &RoomId, to_room_id: &RoomId, ) -> Result<()> { - let Some(event) = services().account_data.get( - Some(from_room_id), + let Some(event) = services().account_data.get_room( + from_room_id, user_id, - RoomAccountDataEventType::Tag, + &RoomAccountDataEventType::Tag, )? else { return Ok(()); }; - let event = serde_json::from_str::(event.get()) - .expect("RawValue -> Value should always succeed"); - if let Err(error) = services().account_data.update( - Some(to_room_id), + if let Err(error) = services().account_data.update_room( + to_room_id, user_id, - RoomAccountDataEventType::Tag, + &RoomAccountDataEventType::Tag, &event, ) { warn!(%error, "error writing m.tag account data to upgraded room"); @@ -231,16 +221,15 @@ impl Service { from_room_id: &RoomId, to_room_id: &RoomId, ) -> Result<()> { - let event_kind = RoomAccountDataEventType::from( - GlobalAccountDataEventType::Direct.to_string(), - ); - let Some(event) = - services().account_data.get(None, user_id, event_kind.clone())? + let Some(event) = services() + .account_data + .get_global(user_id, &GlobalAccountDataEventType::Direct)? else { return Ok(()); }; - let mut event = serde_json::from_str::(event.get()) + let mut event = event + .deserialize_as::() .expect("RawValue -> Value should always succeed"); // As a server, we should try not to assume anything about the schema @@ -285,13 +274,14 @@ impl Service { } if event_updated { - if let Err(error) = services().account_data.update( - None, + if let Err(error) = services().account_data.update_global( user_id, - event_kind.clone(), - &event, + &GlobalAccountDataEventType::Direct, + &Raw::new(&event) + .expect("json serialization should always succeed") + .cast::(), ) { - warn!(%event_kind, %error, "error writing account data event after upgrading room"); + warn!(%error, "error writing m.direct account data event after upgrading room"); } } Ok(()) diff --git a/src/service/rooms/timeline.rs b/src/service/rooms/timeline.rs index d8cfa368..3a422830 100644 --- a/src/service/rooms/timeline.rs +++ b/src/service/rooms/timeline.rs @@ -417,19 +417,11 @@ impl Service { let rules_for_user = services() .account_data - .get( - None, - user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? + .get_global(user, &GlobalAccountDataEventType::PushRules)? .map(|event| { - serde_json::from_str::(event.get()).map_err( - |_| { - Error::bad_database( - "Invalid push rules event in db.", - ) - }, - ) + event.deserialize_as::().map_err(|_| { + Error::bad_database("Invalid push rules event in db.") + }) }) .transpose()? .map_or_else( diff --git a/src/service/sending.rs b/src/service/sending.rs index 1b7e8d01..0a5ff541 100644 --- a/src/service/sending.rs +++ b/src/service/sending.rs @@ -857,15 +857,9 @@ async fn handle_push_event( let rules_for_user = services() .account_data - .get( - None, - userid, - GlobalAccountDataEventType::PushRules.to_string().into(), - ) + .get_global(userid, &GlobalAccountDataEventType::PushRules) .unwrap_or_default() - .and_then(|event| { - serde_json::from_str::(event.get()).ok() - }) + .and_then(|event| event.deserialize_as::().ok()) .map_or_else( || push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global,