add type-safe accessors to account_data service

This commit is contained in:
Olivia Lee 2025-03-23 15:17:19 -07:00
parent b82458a460
commit 88ad596e8d
No known key found for this signature in database
GPG key ID: 54D568A15B9CD1F9
12 changed files with 202 additions and 179 deletions

View file

@ -4,7 +4,9 @@ use ruma::{
events::{
AnyGlobalAccountDataEvent, AnyGlobalAccountDataEventContent,
AnyRoomAccountDataEvent, AnyRoomAccountDataEventContent,
GlobalAccountDataEventType, RoomAccountDataEventType,
GlobalAccountDataEventContent, GlobalAccountDataEventType,
RoomAccountDataEventContent, RoomAccountDataEventType,
StaticEventContent,
},
serde::Raw,
RoomId, UserId,
@ -101,9 +103,29 @@ impl Service {
}
/// Places one event in the global account data of the user and removes the
/// previous entry.
/// previous entry, with a static event type.
#[tracing::instrument(skip(self, user_id, content))]
pub(crate) fn update_global(
pub(crate) fn update_global<T>(
&self,
user_id: &UserId,
content: &Raw<T>,
) -> Result<()>
where
T: GlobalAccountDataEventContent + StaticEventContent,
{
let event_type = T::TYPE.into();
let content = content.cast_ref::<AnyGlobalAccountDataEventContent>();
let event = raw_global_event_from_parts(&event_type, content);
self.db.update(None, user_id, &event_type.to_string(), event.json())
}
/// Places one event in the global account data of the user and removes the
/// previous entry, with a dynamic event type.
///
/// If the event type is known statically, [`Service::update_global`] should
/// be perferred for better type-safety.
#[tracing::instrument(skip(self, user_id, content))]
pub(crate) fn update_global_any(
&self,
user_id: &UserId,
event_type: &GlobalAccountDataEventType,
@ -114,9 +136,35 @@ impl Service {
}
/// Places one event in the room account data of the user and removes the
/// previous entry for that room.
/// previous entry for that room, with a static event type.
#[tracing::instrument(skip(self, room_id, user_id, content))]
pub(crate) fn update_room(
pub(crate) fn update_room<T>(
&self,
room_id: &RoomId,
user_id: &UserId,
content: &Raw<T>,
) -> Result<()>
where
T: RoomAccountDataEventContent + StaticEventContent,
{
let event_type = T::TYPE.into();
let content = content.cast_ref::<AnyRoomAccountDataEventContent>();
let event = raw_room_event_from_parts(&event_type, content);
self.db.update(
Some(room_id),
user_id,
&event_type.to_string(),
event.json(),
)
}
/// Places one event in the room account data of the user and removes the
/// previous entry for that room, with a dynamic event type.
///
/// If the event type is known statically, [`Service::update_room`] should
/// be perferred for better type-safety.
#[tracing::instrument(skip(self, room_id, user_id, content))]
pub(crate) fn update_room_any(
&self,
room_id: &RoomId,
user_id: &UserId,
@ -132,9 +180,31 @@ impl Service {
)
}
/// Searches the global account data for a specific kind.
/// Searches the global account data for a specific static event type.
#[tracing::instrument(skip(self, user_id))]
pub(crate) fn get_global<T>(
&self,
user_id: &UserId,
) -> Result<Option<Raw<T>>>
where
T: GlobalAccountDataEventContent + StaticEventContent,
{
let Some(event) = self.db.get(None, user_id, T::TYPE)? else {
return Ok(None);
};
let event = Raw::<AnyGlobalAccountDataEvent>::from_json(event);
let (_, content) = raw_global_event_to_parts(&event).map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
Ok(Some(content.cast::<T>()))
}
/// Searches the global account data for a specific dynamic event type.
///
/// If the event type is known statically, [`Service::get_global`] should
/// be perferred for better type-safety.
#[tracing::instrument(skip(self, user_id, event_type))]
pub(crate) fn get_global(
pub(crate) fn get_global_any(
&self,
user_id: &UserId,
event_type: &GlobalAccountDataEventType,
@ -151,9 +221,32 @@ impl Service {
Ok(Some(content))
}
/// Searches the room account data for a specific kind.
/// Searches the room account data for a specific static event type.
#[tracing::instrument(skip(self, room_id, user_id))]
pub(crate) fn get_room<T>(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<Raw<T>>>
where
T: RoomAccountDataEventContent + StaticEventContent,
{
let Some(event) = self.db.get(Some(room_id), user_id, T::TYPE)? else {
return Ok(None);
};
let event = Raw::<AnyRoomAccountDataEvent>::from_json(event);
let (_, content) = raw_room_event_to_parts(&event).map_err(|_| {
Error::bad_database("Invalid account data event in db.")
})?;
Ok(Some(content.cast::<T>()))
}
/// Searches the room account data for a specific dynamic event type.
///
/// If the event type is known statically, [`Service::get_room`] should
/// be perferred for better type-safety.
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
pub(crate) fn get_room(
pub(crate) fn get_room_any(
&self,
room_id: &RoomId,
user_id: &UserId,

View file

@ -773,15 +773,9 @@ impl Service {
// Initial account data
services().account_data.update_global(
&user_id,
&ruma::events::GlobalAccountDataEventType::PushRules,
&Raw::new(
&PushRulesEventContent {
global: ruma::push::Ruleset::server_default(
&user_id,
),
}
.into(),
)
&Raw::new(&PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id),
})
.expect("json serialization should always succeed"),
)?;

View file

@ -18,7 +18,7 @@ use ruma::{
api::federation::discovery::ServerSigningKeys,
events::{
push_rules::PushRulesEventContent,
room::message::RoomMessageEventContent, GlobalAccountDataEventType,
room::message::RoomMessageEventContent,
},
push::Ruleset,
serde::{Base64, Raw},
@ -493,13 +493,9 @@ impl Service {
services().account_data.update_global(
admin_bot,
&GlobalAccountDataEventType::PushRules,
&Raw::new(
&PushRulesEventContent {
global: ruleset,
}
.into(),
)
&Raw::new(&PushRulesEventContent {
global: ruleset,
})
.expect("json serialization should always succeed"),
)?;

View file

@ -5,10 +5,10 @@ use std::{
use ruma::{
events::{
direct::DirectEventContent,
ignored_user_list::IgnoredUserListEventContent,
room::member::MembershipState, AnyGlobalAccountDataEventContent,
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType,
room::member::MembershipState, tag::TagEventContent,
AnyStrippedStateEvent, AnySyncStateEvent,
},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
@ -96,14 +96,9 @@ impl Service {
// We want to know if the sender is ignored by the receiver
let is_ignored = services()
.account_data
// Ignored users are in global account data
.get_global(
// Receiver
user_id,
&GlobalAccountDataEventType::IgnoredUserList,
)?
.get_global::<IgnoredUserListEventContent>(user_id)?
.map(|event| {
event.deserialize_as::<IgnoredUserListEventContent>()
event.deserialize()
.map_err(|error| {
warn!(
%error,
@ -192,20 +187,15 @@ impl Service {
from_room_id: &RoomId,
to_room_id: &RoomId,
) -> Result<()> {
let Some(event) = services().account_data.get_room(
from_room_id,
user_id,
&RoomAccountDataEventType::Tag,
)?
let Some(event) = services()
.account_data
.get_room::<TagEventContent>(from_room_id, user_id)?
else {
return Ok(());
};
if let Err(error) = services().account_data.update_room(
to_room_id,
user_id,
&RoomAccountDataEventType::Tag,
&event,
) {
if let Err(error) =
services().account_data.update_room(to_room_id, user_id, &event)
{
warn!(%error, "error writing m.tag account data to upgraded room");
}
@ -223,7 +213,7 @@ impl Service {
) -> Result<()> {
let Some(event_content) = services()
.account_data
.get_global(user_id, &GlobalAccountDataEventType::Direct)?
.get_global::<DirectEventContent>(user_id)?
else {
return Ok(());
};
@ -270,10 +260,9 @@ impl Service {
if event_updated {
if let Err(error) = services().account_data.update_global(
user_id,
&GlobalAccountDataEventType::Direct,
&Raw::new(&event_content)
.expect("json serialization should always succeed")
.cast::<AnyGlobalAccountDataEventContent>(),
.cast::<DirectEventContent>(),
) {
warn!(%error, "error writing m.direct account data event after upgrading room");
}

View file

@ -16,7 +16,7 @@ use ruma::{
power_levels::RoomPowerLevelsEventContent,
redaction::RoomRedactionEventContent,
},
GlobalAccountDataEventType, StateEventType, TimelineEventType,
StateEventType, TimelineEventType,
},
push::{Action, Ruleset, Tweak},
state_res::{self, Event},
@ -417,15 +417,11 @@ impl Service {
let rules_for_user = services()
.account_data
.get_global(user, &GlobalAccountDataEventType::PushRules)?
.get_global::<PushRulesEventContent>(user)?
.map(|event| {
event.deserialize_as::<PushRulesEventContent>().map_err(
|_| {
Error::bad_database(
"Invalid push rules event in db.",
)
},
)
event.deserialize().map_err(|_| {
Error::bad_database("Invalid push rules event in db.")
})
})
.transpose()?
.map_or_else(|| Ruleset::server_default(user), |ev| ev.global);

View file

@ -23,7 +23,7 @@ use ruma::{
device_id,
events::{
push_rules::PushRulesEventContent, receipt::ReceiptType,
AnySyncEphemeralRoomEvent, GlobalAccountDataEventType,
AnySyncEphemeralRoomEvent,
},
push,
serde::Raw,
@ -857,11 +857,9 @@ async fn handle_push_event(
let rules_for_user = services()
.account_data
.get_global(userid, &GlobalAccountDataEventType::PushRules)
.get_global::<PushRulesEventContent>(userid)
.unwrap_or_default()
.and_then(|event| {
event.deserialize_as::<PushRulesEventContent>().ok()
})
.and_then(|event| event.deserialize().ok())
.map_or_else(
|| push::Ruleset::server_default(userid),
|ev| ev.global,