generalize get_room_version

There are other fields of `m.room.create` events that are useful to
individually extract without caring about the values of other fields.
This commit is contained in:
Charles Hall 2024-11-08 18:11:11 -08:00
parent c9c30fba30
commit a4e1522875
No known key found for this signature in database
GPG key ID: 7B8E0645816E07CF
4 changed files with 69 additions and 24 deletions

View file

@ -7,14 +7,14 @@ use std::{
use ruma::{
api::client::error::ErrorKind,
events::{
room::{create::RoomCreateEventContent, member::MembershipState},
AnyStrippedStateEvent, StateEventType, TimelineEventType,
room::member::MembershipState, AnyStrippedStateEvent, StateEventType,
TimelineEventType,
},
serde::Raw,
state_res::{self, StateMap},
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
};
use serde::Deserialize;
use serde::{de::DeserializeOwned, Deserialize};
use tracing::warn;
use super::{short::ShortStateHash, state_compressor::CompressedStateEvent};
@ -31,6 +31,26 @@ mod data;
pub(crate) use data::Data;
pub(crate) trait ExtractCreateContent: DeserializeOwned {
type Extract;
fn extract(self) -> Self::Extract;
}
/// Extract the `room_version` from an `m.room.create` event
#[derive(Deserialize)]
pub(crate) struct ExtractVersion {
room_version: RoomVersionId,
}
impl ExtractCreateContent for ExtractVersion {
type Extract = RoomVersionId;
fn extract(self) -> Self::Extract {
self.room_version
}
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
}
@ -315,22 +335,22 @@ impl Service {
self.db.set_room_state(room_id, shortstatehash)
}
/// Returns the room's version.
/// Returns the value of a field of an `m.room.create` event's `content`.
#[tracing::instrument(skip(self))]
pub(crate) fn get_room_version(
pub(crate) fn get_create_content<T: ExtractCreateContent>(
&self,
room_id: &RoomId,
) -> Result<RoomVersionId> {
) -> Result<T::Extract> {
let create_event = services().rooms.state_accessor.room_state_get(
room_id,
&StateEventType::RoomCreate,
"",
)?;
let create_event_content: RoomCreateEventContent = create_event
let content_field = create_event
.as_ref()
.map(|create_event| {
serde_json::from_str(create_event.content.get()).map_err(
serde_json::from_str::<T>(create_event.content.get()).map_err(
|error| {
warn!(%error, "Invalid create event");
Error::BadDatabase("Invalid create event in db.")
@ -345,7 +365,7 @@ impl Service {
)
})?;
Ok(create_event_content.room_version)
Ok(content_field.extract())
}
#[tracing::instrument(skip(self))]