diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 993aedd1..bfd338a6 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1372,7 +1372,7 @@ pub(crate) async fn invite_helper( ) })?; - let pdu_id: Vec = services() + let pdu_id = services() .rooms .event_handler .handle_incoming_pdu( diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index 458d4ae2..cd36a13b 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -84,7 +84,9 @@ pub(crate) async fn search_events_route( if let Some(s) = searches .iter_mut() .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) + .max_by_key(|(peek, _)| { + peek.as_ref().map(|id| id.as_bytes().to_vec()) + }) .and_then(|(_, i)| i.next()) { results.push(s); diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 52b58f90..74f7a1b0 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1622,7 +1622,7 @@ async fn create_join_event( .roomid_mutex_federation .lock_key(room_id.to_owned()) .await; - let pdu_id: Vec = services() + let pdu_id = services() .rooms .event_handler .handle_incoming_pdu( diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index b890985f..32689bd4 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -4,7 +4,10 @@ use ruma::{EventId, RoomId, UserId}; use crate::{ database::KeyValueDatabase, - service::{self, rooms::timeline::PduCount}, + service::{ + self, + rooms::timeline::{PduCount, PduId}, + }, services, utils, Error, PduEvent, Result, }; @@ -50,6 +53,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { let mut pduid = shortroomid.to_be_bytes().to_vec(); pduid.extend_from_slice(&from.to_be_bytes()); + let pduid = PduId::new(pduid); let mut pdu = services() .rooms diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 5fb90462..fa48965d 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,6 +1,10 @@ use ruma::RoomId; -use crate::{database::KeyValueDatabase, service, services, utils, Result}; +use crate::{ + database::KeyValueDatabase, + service::{self, rooms::timeline::PduId}, + services, utils, Result, +}; /// Splits a string into tokens used as keys in the search inverted index /// @@ -18,7 +22,7 @@ impl service::rooms::search::Data for KeyValueDatabase { fn index_pdu( &self, shortroomid: u64, - pdu_id: &[u8], + pdu_id: &PduId, message_body: &str, ) -> Result<()> { let mut batch = tokenize(message_body).map(|word| { @@ -26,7 +30,7 @@ impl service::rooms::search::Data for KeyValueDatabase { key.extend_from_slice(word.as_bytes()); key.push(0xFF); // TODO: currently we save the room id a second time here - key.extend_from_slice(pdu_id); + key.extend_from_slice(pdu_id.as_bytes()); (key, Vec::new()) }); @@ -37,7 +41,7 @@ impl service::rooms::search::Data for KeyValueDatabase { fn deindex_pdu( &self, shortroomid: u64, - pdu_id: &[u8], + pdu_id: &PduId, message_body: &str, ) -> Result<()> { let batch = tokenize(message_body).map(|word| { @@ -45,7 +49,7 @@ impl service::rooms::search::Data for KeyValueDatabase { key.extend_from_slice(word.as_bytes()); key.push(0xFF); // TODO: currently we save the room id a second time here - key.extend_from_slice(pdu_id); + key.extend_from_slice(pdu_id.as_bytes()); key }); @@ -62,7 +66,7 @@ impl service::rooms::search::Data for KeyValueDatabase { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a>, Vec)>> + ) -> Result + 'a>, Vec)>> { let prefix = services() .rooms @@ -87,12 +91,14 @@ impl service::rooms::search::Data for KeyValueDatabase { // Newest pdus first .iter_from(&last_possible_id, true) .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) + .map(move |(key, _)| PduId::new(key[prefix3.len()..].to_vec())) }); // We compare b with a because we reversed the iterator earlier let Some(common_elements) = - utils::common_elements(iterators, |a, b| b.cmp(a)) + utils::common_elements(iterators, |a, b| { + b.as_bytes().cmp(a.as_bytes()) + }) else { return Ok(None); }; diff --git a/src/database/key_value/rooms/threads.rs b/src/database/key_value/rooms/threads.rs index 9915198e..bb9a1712 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/database/key_value/rooms/threads.rs @@ -6,8 +6,9 @@ use ruma::{ }; use crate::{ - database::KeyValueDatabase, service, services, utils, Error, PduEvent, - Result, + database::KeyValueDatabase, + service::{self, rooms::timeline::PduId}, + services, utils, Error, PduEvent, Result, }; impl service::rooms::threads::Data for KeyValueDatabase { @@ -42,6 +43,9 @@ impl service::rooms::threads::Data for KeyValueDatabase { "Invalid pduid in threadid_userids.", ) })?; + + let pduid = PduId::new(pduid); + let mut pdu = services() .rooms .timeline @@ -61,7 +65,7 @@ impl service::rooms::threads::Data for KeyValueDatabase { fn update_participants( &self, - root_id: &[u8], + root_id: &PduId, participants: &[OwnedUserId], ) -> Result<()> { let users = participants @@ -70,16 +74,16 @@ impl service::rooms::threads::Data for KeyValueDatabase { .collect::>() .join(&[0xFF][..]); - self.threadid_userids.insert(root_id, &users)?; + self.threadid_userids.insert(root_id.as_bytes(), &users)?; Ok(()) } fn get_participants( &self, - root_id: &[u8], + root_id: &PduId, ) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? { + if let Some(users) = self.threadid_userids.get(root_id.as_bytes())? { Ok(Some( users .split(|b| *b == 0xFF) diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 0633e1d1..951975e5 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -10,7 +10,8 @@ use tracing::error; use crate::{ database::KeyValueDatabase, observability::{FoundIn, Lookup, METRICS}, - service, services, utils, Error, PduEvent, Result, + service::{self, rooms::timeline::PduId}, + services, utils, Error, PduEvent, Result, }; impl service::rooms::timeline::Data for KeyValueDatabase { @@ -102,8 +103,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.eventid_pduid.get(event_id.as_bytes()).map(|x| x.map(PduId::new)) } /// Returns the pdu. @@ -170,8 +171,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + fn get_pdu_from_id(&self, pdu_id: &PduId) -> Result> { + self.pduid_pdu.get(pdu_id.as_bytes())?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) .map_err(|_| Error::bad_database("Invalid PDU in db."))?, @@ -182,9 +183,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { /// Returns the pdu as a `BTreeMap`. fn get_pdu_json_from_id( &self, - pdu_id: &[u8], + pdu_id: &PduId, ) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + self.pduid_pdu.get(pdu_id.as_bytes())?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) .map_err(|_| Error::bad_database("Invalid PDU in db."))?, @@ -194,13 +195,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase { fn append_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, ) -> Result<()> { self.pduid_pdu.insert( - pdu_id, + pdu_id.as_bytes(), &serde_json::to_vec(json) .expect("CanonicalJsonObject is always a valid"), )?; @@ -210,7 +211,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .unwrap() .insert(pdu.room_id.clone(), PduCount::Normal(count)); - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; + self.eventid_pduid + .insert(pdu.event_id.as_bytes(), pdu_id.as_bytes())?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; Ok(()) @@ -218,17 +220,17 @@ impl service::rooms::timeline::Data for KeyValueDatabase { fn prepend_backfill_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, event_id: &EventId, json: &CanonicalJsonObject, ) -> Result<()> { self.pduid_pdu.insert( - pdu_id, + pdu_id.as_bytes(), &serde_json::to_vec(json) .expect("CanonicalJsonObject is always a valid"), )?; - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id.as_bytes())?; self.eventid_outlierpdu.remove(event_id.as_bytes())?; Ok(()) @@ -237,13 +239,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase { /// Removes a pdu and creates a new one with the same id. fn replace_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent, ) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { + if self.pduid_pdu.get(pdu_id.as_bytes())?.is_some() { self.pduid_pdu.insert( - pdu_id, + pdu_id.as_bytes(), &serde_json::to_vec(pdu_json) .expect("CanonicalJsonObject is always a valid"), )?; diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 89c8c7ca..c620b40f 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -4,6 +4,7 @@ use crate::{ database::KeyValueDatabase, service::{ self, + rooms::timeline::PduId, sending::{Destination, RequestKey, SendingEventType}, }, services, utils, Error, Result, @@ -61,7 +62,7 @@ impl service::sending::Data for KeyValueDatabase { for (destination, event) in requests { let mut key = destination.get_prefix(); if let SendingEventType::Pdu(value) = &event { - key.extend_from_slice(value); + key.extend_from_slice(value.as_bytes()); } else { key.extend_from_slice( &services().globals.next_count()?.to_be_bytes(), @@ -202,7 +203,7 @@ fn parse_servercurrentevent( Ok(( destination, if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) + SendingEventType::Pdu(PduId::new(event.to_vec())) } else { SendingEventType::Edu(value) }, diff --git a/src/service/rooms/event_handler.rs b/src/service/rooms/event_handler.rs index 62c6470b..6bf29e8e 100644 --- a/src/service/rooms/event_handler.rs +++ b/src/service/rooms/event_handler.rs @@ -37,7 +37,7 @@ use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore}; use tracing::{debug, error, info, trace, warn}; -use super::state_compressor::CompressedStateEvent; +use super::{state_compressor::CompressedStateEvent, timeline::PduId}; use crate::{ service::{globals::SigningKeys, pdu}, services, @@ -89,7 +89,7 @@ impl Service { value: CanonicalJsonObject, is_timeline_event: bool, pub_key_map: &'a RwLock>, - ) -> Result>> { + ) -> Result> { // 0. Check the server is in the room if !services().rooms.metadata.exists(room_id)? { return Err(Error::BadRequest( @@ -565,7 +565,7 @@ impl Service { origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>, - ) -> Result>> { + ) -> Result> { // Skip the PDU if we already have it as a timeline event if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index efc503fe..515b7d04 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,19 +1,19 @@ use ruma::RoomId; -use crate::Result; +use crate::{service::rooms::timeline::PduId, Result}; pub(crate) trait Data: Send + Sync { fn index_pdu( &self, shortroomid: u64, - pdu_id: &[u8], + pdu_id: &PduId, message_body: &str, ) -> Result<()>; fn deindex_pdu( &self, shortroomid: u64, - pdu_id: &[u8], + pdu_id: &PduId, message_body: &str, ) -> Result<()>; @@ -22,5 +22,5 @@ pub(crate) trait Data: Send + Sync { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a>, Vec)>>; + ) -> Result + 'a>, Vec)>>; } diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 8a1607db..384c23c8 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -3,7 +3,7 @@ use ruma::{ UserId, }; -use crate::{PduEvent, Result}; +use crate::{service::rooms::timeline::PduId, PduEvent, Result}; pub(crate) trait Data: Send + Sync { #[allow(clippy::type_complexity)] @@ -17,11 +17,11 @@ pub(crate) trait Data: Send + Sync { fn update_participants( &self, - root_id: &[u8], + root_id: &PduId, participants: &[OwnedUserId], ) -> Result<()>; fn get_participants( &self, - root_id: &[u8], + root_id: &PduId, ) -> Result>>; } diff --git a/src/service/rooms/timeline.rs b/src/service/rooms/timeline.rs index bb268255..8df9c410 100644 --- a/src/service/rooms/timeline.rs +++ b/src/service/rooms/timeline.rs @@ -44,6 +44,23 @@ use crate::{ Error, PduEvent, Result, }; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct PduId { + inner: Vec, +} + +impl PduId { + pub(crate) fn new(inner: Vec) -> Self { + Self { + inner, + } + } + + pub(crate) fn as_bytes(&self) -> &[u8] { + &self.inner + } +} + #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] pub(crate) enum PduCount { Backfilled(u64), @@ -146,7 +163,7 @@ impl Service { pub(crate) fn get_pdu_id( &self, event_id: &EventId, - ) -> Result>> { + ) -> Result> { self.db.get_pdu_id(event_id) } @@ -165,7 +182,7 @@ impl Service { /// This does __NOT__ check the outliers `Tree`. pub(crate) fn get_pdu_from_id( &self, - pdu_id: &[u8], + pdu_id: &PduId, ) -> Result> { self.db.get_pdu_from_id(pdu_id) } @@ -173,7 +190,7 @@ impl Service { /// Returns the pdu as a `BTreeMap`. pub(crate) fn get_pdu_json_from_id( &self, - pdu_id: &[u8], + pdu_id: &PduId, ) -> Result> { self.db.get_pdu_json_from_id(pdu_id) } @@ -182,7 +199,7 @@ impl Service { #[tracing::instrument(skip(self))] pub(crate) fn replace_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent, ) -> Result<()> { @@ -202,7 +219,7 @@ impl Service { mut pdu_json: CanonicalJsonObject, leaves: Vec, room_id: &KeyToken, - ) -> Result> { + ) -> Result { assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed"); let shortroomid = services() @@ -282,6 +299,7 @@ impl Service { let count2 = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); + let pdu_id = PduId::new(pdu_id); // Insert pdu self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; @@ -1106,7 +1124,7 @@ impl Service { state_ids_compressed: Arc>, soft_fail: bool, room_id: &KeyToken, - ) -> Result>> { + ) -> Result> { assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed"); // We append to state before appending the pdu, so we don't have a @@ -1344,6 +1362,7 @@ impl Service { let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); + let pdu_id = PduId::new(pdu_id); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 5ea5ae03..acd0b33c 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; use super::PduCount; -use crate::{PduEvent, Result}; +use crate::{service::rooms::timeline::PduId, PduEvent, Result}; pub(crate) trait Data: Send + Sync { fn last_timeline_count( @@ -28,7 +28,7 @@ pub(crate) trait Data: Send + Sync { ) -> Result>; /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>>; + fn get_pdu_id(&self, event_id: &EventId) -> Result>; /// Returns the pdu. /// @@ -46,18 +46,18 @@ pub(crate) trait Data: Send + Sync { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; + fn get_pdu_from_id(&self, pdu_id: &PduId) -> Result>; /// Returns the pdu as a `BTreeMap`. fn get_pdu_json_from_id( &self, - pdu_id: &[u8], + pdu_id: &PduId, ) -> Result>; /// Adds a new pdu to the timeline fn append_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, @@ -66,7 +66,7 @@ pub(crate) trait Data: Send + Sync { // Adds a new pdu to the backfilled timeline fn prepend_backfill_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, event_id: &EventId, json: &CanonicalJsonObject, ) -> Result<()>; @@ -74,7 +74,7 @@ pub(crate) trait Data: Send + Sync { /// Removes a pdu and creates a new one with the same id. fn replace_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent, ) -> Result<()>; diff --git a/src/service/sending.rs b/src/service/sending.rs index 832380c9..9aa1d684 100644 --- a/src/service/sending.rs +++ b/src/service/sending.rs @@ -37,6 +37,7 @@ use tokio::{ }; use tracing::{debug, error, warn, Span}; +use super::rooms::timeline::PduId; use crate::{ api::{appservice_server, server_server}, services, @@ -83,7 +84,7 @@ impl Destination { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) enum SendingEventType { // pduid - Pdu(Vec), + Pdu(PduId), // pdu json Edu(Vec), } @@ -565,7 +566,7 @@ impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey))] pub(crate) fn send_push_pdu( &self, - pdu_id: &[u8], + pdu_id: &PduId, user: &UserId, pushkey: String, ) -> Result<()> { @@ -589,7 +590,7 @@ impl Service { pub(crate) fn send_pdu>( &self, servers: I, - pdu_id: &[u8], + pdu_id: &PduId, ) -> Result<()> { let requests = servers .into_iter() @@ -644,7 +645,7 @@ impl Service { pub(crate) fn send_pdu_appservice( &self, appservice_id: String, - pdu_id: Vec, + pdu_id: PduId, ) -> Result<()> { let destination = Destination::Appservice(appservice_id); let event_type = SendingEventType::Pdu(pdu_id); @@ -758,8 +759,8 @@ async fn handle_appservice_event( &events .iter() .map(|e| match e { - SendingEventType::Edu(b) - | SendingEventType::Pdu(b) => &**b, + SendingEventType::Edu(b) => &**b, + SendingEventType::Pdu(b) => b.as_bytes(), }) .collect::>(), )) @@ -905,8 +906,8 @@ async fn handle_federation_event( &events .iter() .map(|e| match e { - SendingEventType::Edu(b) - | SendingEventType::Pdu(b) => &**b, + SendingEventType::Edu(b) => &**b, + SendingEventType::Pdu(b) => b.as_bytes(), }) .collect::>(), )) diff --git a/src/utils.rs b/src/utils.rs index c1e2d6a6..42c15f88 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -113,14 +113,14 @@ pub(crate) fn calculate_hash(keys: &[&[u8]]) -> Vec { hash.as_ref().to_owned() } -pub(crate) fn common_elements( +pub(crate) fn common_elements( mut iterators: I, check_order: F, -) -> Option>> +) -> Option> where I: Iterator, - I::Item: Iterator>, - F: Fn(&[u8], &[u8]) -> Ordering, + I::Item: Iterator, + F: Fn(&T, &T) -> Ordering, { let first_iterator = iterators.next()?; let mut other_iterators =