Properly type stored EDUs

This commit is contained in:
Lambda 2024-08-26 17:10:43 +00:00
parent 26322d5a95
commit cce83beedb
3 changed files with 29 additions and 20 deletions

View file

@ -5,6 +5,7 @@ use ruma::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
federation::{self, transactions::edu::DirectDeviceContent}, federation::{self, transactions::edu::DirectDeviceContent},
}, },
serde::Raw,
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
}; };
@ -40,7 +41,7 @@ pub(crate) async fn send_event_to_device_route(
services().sending.send_reliable_edu( services().sending.send_reliable_edu(
target_user_id.server_name(), target_user_id.server_name(),
serde_json::to_vec( Raw::new(
&federation::transactions::edu::Edu::DirectToDevice( &federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent { DirectDeviceContent {
sender: sender_user.clone(), sender: sender_user.clone(),

View file

@ -1,4 +1,4 @@
use ruma::{ServerName, UserId}; use ruma::{serde::Raw, ServerName, UserId};
use crate::{ use crate::{
database::KeyValueDatabase, database::KeyValueDatabase,
@ -69,7 +69,7 @@ impl service::sending::Data for KeyValueDatabase {
); );
} }
let value = if let SendingEventType::Edu(value) = &event { let value = if let SendingEventType::Edu(value) = &event {
&**value value.json().get().as_bytes()
} else { } else {
&[] &[]
}; };
@ -100,7 +100,7 @@ impl service::sending::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
for (e, key) in events { for (e, key) in events {
let value = if let SendingEventType::Edu(value) = &e { let value = if let SendingEventType::Edu(value) = &e {
&**value value.json().get().as_bytes()
} else { } else {
&[] &[]
}; };
@ -205,7 +205,13 @@ fn parse_servercurrentevent(
if value.is_empty() { if value.is_empty() {
SendingEventType::Pdu(PduId::new(event.to_vec())) SendingEventType::Pdu(PduId::new(event.to_vec()))
} else { } else {
SendingEventType::Edu(value) SendingEventType::Edu(
Raw::from_json_string(
String::from_utf8(value)
.expect("EDU content in database should be a string"),
)
.expect("EDU content in database should be valid JSON"),
)
}, },
)) ))
} }

View file

@ -28,8 +28,10 @@ use ruma::{
push_rules::PushRulesEvent, receipt::ReceiptType, push_rules::PushRulesEvent, receipt::ReceiptType,
AnySyncEphemeralRoomEvent, GlobalAccountDataEventType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType,
}, },
push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, push,
ServerName, UInt, UserId, serde::Raw,
uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName,
UInt, UserId,
}; };
use tokio::{ use tokio::{
select, select,
@ -81,12 +83,12 @@ impl Destination {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug)]
pub(crate) enum SendingEventType { pub(crate) enum SendingEventType {
// pduid // pduid
Pdu(PduId), Pdu(PduId),
// pdu json // pdu json
Edu(Vec<u8>), Edu(Raw<Edu>),
} }
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
@ -444,7 +446,7 @@ impl Service {
pub(crate) fn select_edus( pub(crate) fn select_edus(
&self, &self,
server_name: &ServerName, server_name: &ServerName,
) -> Result<(Vec<Vec<u8>>, u64)> { ) -> Result<(Vec<Raw<Edu>>, u64)> {
// u64: count of last edu // u64: count of last edu
let since = self.db.get_latest_educount(server_name)?; let since = self.db.get_latest_educount(server_name)?;
let mut events = Vec::new(); let mut events = Vec::new();
@ -532,7 +534,7 @@ impl Service {
}; };
events.push( events.push(
serde_json::to_vec(&federation_event) Raw::new(&federation_event)
.expect("json can be serialized"), .expect("json can be serialized"),
); );
@ -555,9 +557,7 @@ impl Service {
keys: None, keys: None,
}); });
events.push( events.push(Raw::new(&edu).expect("json can be serialized"));
serde_json::to_vec(&edu).expect("json can be serialized"),
);
} }
Ok((events, max_edu_count)) Ok((events, max_edu_count))
@ -622,7 +622,7 @@ impl Service {
pub(crate) fn send_reliable_edu( pub(crate) fn send_reliable_edu(
&self, &self,
server: &ServerName, server: &ServerName,
serialized: Vec<u8>, serialized: Raw<Edu>,
id: u64, id: u64,
) -> Result<()> { ) -> Result<()> {
let destination = Destination::Normal(server.to_owned()); let destination = Destination::Normal(server.to_owned());
@ -759,7 +759,9 @@ async fn handle_appservice_event(
&events &events
.iter() .iter()
.map(|e| match e { .map(|e| match e {
SendingEventType::Edu(b) => &**b, SendingEventType::Edu(b) => {
b.json().get().as_bytes()
}
SendingEventType::Pdu(b) => b.as_bytes(), SendingEventType::Pdu(b) => b.as_bytes(),
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@ -885,9 +887,7 @@ async fn handle_federation_event(
)); ));
} }
SendingEventType::Edu(edu) => { SendingEventType::Edu(edu) => {
if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(edu.clone());
edu_jsons.push(raw);
}
} }
} }
} }
@ -906,7 +906,9 @@ async fn handle_federation_event(
&events &events
.iter() .iter()
.map(|e| match e { .map(|e| match e {
SendingEventType::Edu(b) => &**b, SendingEventType::Edu(b) => {
b.json().get().as_bytes()
}
SendingEventType::Pdu(b) => b.as_bytes(), SendingEventType::Pdu(b) => b.as_bytes(),
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>(),