diff --git a/src/api/client_server/sync/msc4186.rs b/src/api/client_server/sync/msc4186.rs index f4e579bf..a50dda46 100644 --- a/src/api/client_server/sync/msc4186.rs +++ b/src/api/client_server/sync/msc4186.rs @@ -10,20 +10,38 @@ use std::{ use ruma::{ api::client::{ - sync::sync_events::{self, DeviceLists, UnreadNotificationsCount}, + sync::sync_events::{ + self, v5::request::ListFilters, DeviceLists, + UnreadNotificationsCount, + }, uiaa::UiaaResponse, }, events::{ - room::member::{MembershipState, RoomMemberEventContent}, - StateEventType, TimelineEventType, + direct::DirectEventContent, + room::{ + create::RoomCreateEventContent, + encryption::PossiblyRedactedRoomEncryptionEventContent, + member::{MembershipState, RoomMemberEventContent}, + }, + AnyStrippedStateEvent, PossiblyRedactedStateEventContent, + StateEventType, StrippedStateEvent, TimelineEventType, }, + room::RoomType, + serde::Raw, uint, JsOption, OwnedRoomId, RoomId, UInt, UserId, }; +use serde::de::DeserializeOwned; use tracing::{debug, error, field, trace, warn}; use super::{load_timeline, share_encrypted_room}; use crate::{ - service::{account_data, rooms::timeline::PduCount, users::ConnectionKey}, + service::{ + account_data, + rooms::{ + short::ShortStateHash, state::ExtractType, timeline::PduCount, + }, + users::ConnectionKey, + }, services, Ar, Error, Ra, Result, }; @@ -59,6 +77,163 @@ impl Default for TodoRoom { } } +fn is_dm_room(user: &UserId, room: &RoomId) -> Result { + let Some(event) = + services().account_data.get_global::(user)? + else { + return Ok(false); + }; + + let event = event + .deserialize() + .map_err(|_| Error::bad_database("Invalid m.direct event"))?; + + Ok(event.values().flatten().any(|r| r == room)) +} + +fn is_encrypted_room(current_shortstatehash: ShortStateHash) -> Result { + Ok(services() + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? + .is_some()) +} + +fn get_invite_state( + invite_state: &[Raw], +) -> Option> +where + T: PossiblyRedactedStateEventContent + DeserializeOwned, +{ + invite_state + .iter() + .find_map(|ev| ev.deserialize_as::>().ok()) +} + +#[derive(Debug)] +struct RoomData { + id: OwnedRoomId, + current_shortstatehash: ShortStateHash, + is_dm: bool, + is_encrypted: bool, + is_invite: bool, + room_type: Option, +} + +impl RoomData { + #[tracing::instrument] + fn new( + id: OwnedRoomId, + user: &UserId, + invite_state: Option<&[Raw]>, + ) -> Result { + let current_shortstatehash = services() + .rooms + .state + .get_room_shortstatehash(&id)? + .ok_or_else(|| Error::bad_database("Room has no state"))?; + + let room_type = if let Some(invite_state) = &invite_state { + get_invite_state::(invite_state) + .and_then(|e| e.content.room_type) + } else { + services().rooms.state.get_create_content::(&id)? + }; + + let is_dm = match is_dm_room(user, &id) { + Ok(x) => x, + Err(error) => { + error!(%error, %user, "Invalid m.direct account data event"); + false + } + }; + let is_encrypted = if let Some(invite_state) = &invite_state { + get_invite_state::( + invite_state, + ) + .is_some() + } else { + is_encrypted_room(current_shortstatehash)? + }; + let is_invite = invite_state.is_some(); + + Ok(Self { + id, + current_shortstatehash, + is_dm, + is_encrypted, + is_invite, + room_type, + }) + } + + #[tracing::instrument(skip(self), fields(room_id = self.id.as_str()))] + fn matches_filter(&self, filter_data: &ListFilters) -> Result { + if let Some(is_dm) = filter_data.is_dm { + if self.is_dm != is_dm { + return Ok(false); + } + } + if let Some(is_encrypted) = filter_data.is_encrypted { + if self.is_encrypted != is_encrypted { + return Ok(false); + } + } + if let Some(is_invite) = filter_data.is_invite { + if self.is_invite != is_invite { + return Ok(false); + } + } + + let room_type = self.room_type.clone().into(); + if filter_data.not_room_types.contains(&room_type) { + return Ok(false); + } + if !filter_data.room_types.is_empty() + && !filter_data.room_types.contains(&room_type) + { + return Ok(false); + } + + Ok(true) + } +} + +#[tracing::instrument(skip_all)] +fn joined_rooms_data(sender_user: &UserId) -> Vec { + services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(Result::ok) + .filter_map(move |id| { + RoomData::new(id.clone(), sender_user, None) + .inspect_err(|error| { + error!(%error, room_id = %id, "Failed to get data for room, skipping"); + }) + .ok() + }).collect() +} + +#[tracing::instrument(skip_all)] +fn invited_rooms_data(sender_user: &UserId) -> Vec { + services() + .rooms + .state_cache + .rooms_invited(sender_user) + .filter_map(Result::ok) + .filter_map(move |(id, invite_state)| { + RoomData::new(id.clone(), sender_user, Some(&invite_state)) + .inspect_err(|error| { + error!( + %error, room_id = %id, "Failed to get data for room, skipping" + ); + }) + .ok() + }) + .collect() +} + #[allow(clippy::too_many_lines)] #[tracing::instrument(skip_all, fields( pos, @@ -96,12 +271,7 @@ pub(crate) async fn sync_events_v5_route( let known_rooms = services().users.get_rooms_in_connection(connection_key.clone()); - let all_joined_rooms = services() - .rooms - .state_cache - .rooms_joined(&sender_user) - .filter_map(Result::ok) - .collect::>(); + let all_joined_rooms = joined_rooms_data(&sender_user); if body.extensions.to_device.enabled.unwrap_or(false) { services().users.remove_to_device_events( @@ -118,21 +288,27 @@ pub(crate) async fn sync_events_v5_route( None }; - // and required state + let mut all_rooms = all_joined_rooms; + all_rooms.extend(invited_rooms_data(&sender_user)); + + let all_room_ids: Vec<_> = all_rooms.iter().map(|r| r.id.clone()).collect(); + let all_room_ids: Vec<_> = all_room_ids.iter().map(|id| &**id).collect(); + let mut todo_rooms: BTreeMap = BTreeMap::new(); let lists = body .lists .into_iter() - .filter_map(|(list_id, list)| { + .map(|(list_id, list)| { let rooms = rooms_in_list( &list_id, list, - &all_joined_rooms, + &all_rooms, + &all_room_ids, &known_rooms, &mut todo_rooms, - )?; - Some((list_id, rooms)) + ); + (list_id, rooms) }) .collect(); @@ -241,7 +417,7 @@ pub(crate) async fn sync_events_v5_route( async fn get_e2ee_data( sender_user: &UserId, globalsince: u64, - all_joined_rooms: &[OwnedRoomId], + all_joined_rooms: &[RoomData], ) -> Result { // Users that have left any encrypted rooms the sender was in let mut left_encrypted_users = HashSet::new(); @@ -253,14 +429,13 @@ async fn get_e2ee_data( .filter_map(Result::ok) .collect(); - for room_id in all_joined_rooms { - let Some(current_shortstatehash) = - services().rooms.state.get_room_shortstatehash(room_id)? - else { - error!(%room_id, "Room has no state"); - continue; - }; - + for RoomData { + id: room_id, + current_shortstatehash, + is_encrypted, + .. + } in all_joined_rooms + { let since_shortstatehash = services() .rooms .user @@ -288,19 +463,9 @@ async fn get_e2ee_data( .ok() }); - let encrypted_room = services() - .rooms - .state_accessor - .state_get( - current_shortstatehash, - &StateEventType::RoomEncryption, - "", - )? - .is_some(); - if let Some(since_shortstatehash) = since_shortstatehash { // Skip if there are only timeline changes - if since_shortstatehash == current_shortstatehash { + if since_shortstatehash == *current_shortstatehash { continue; } @@ -316,12 +481,12 @@ async fn get_e2ee_data( }); let new_encrypted_room = - encrypted_room && since_encryption.is_none(); - if encrypted_room { + *is_encrypted && since_encryption.is_none(); + if *is_encrypted { let current_state_ids = services() .rooms .state_accessor - .state_full_ids(current_shortstatehash) + .state_full_ids(*current_shortstatehash) .await?; let since_state_ids = services() .rooms @@ -465,45 +630,59 @@ async fn get_e2ee_data( fn rooms_in_list( list_id: &str, list: sync_events::v5::request::List, - all_joined_rooms: &[OwnedRoomId], + all_rooms: &[RoomData], + all_room_ids: &[&RoomId], known_rooms: &BTreeMap, todo_rooms: &mut BTreeMap, -) -> Option { +) -> sync_events::v5::response::List { trace!(list_id, ?list, "Collecting rooms in list"); - if list.filters.and_then(|f| f.is_invite).unwrap_or(false) { - return None; + let matching_room_ids_buf: Vec<&RoomId>; + let matching_room_ids = if let Some(filters) = list.filters.as_ref() { + matching_room_ids_buf = all_rooms + .iter() + .filter_map(|r| { + match r.matches_filter(filters) { + Ok(pass) => pass.then_some(&*r.id), + Err(error) => { + warn!(%error, ?filters, room_id=r.id.as_str(), "Failed to evaluate list filter, skipping room"); + None + } + } + }) + .collect(); + matching_room_ids_buf.as_slice() + } else { + all_room_ids + }; + + if !matching_room_ids.is_empty() { + let mut list_room_ids: BTreeSet<&RoomId> = BTreeSet::new(); + for (from, to) in list.ranges { + let from = usize::try_from(from) + .unwrap_or(usize::MAX) + .clamp(0, matching_room_ids.len() - 1); + let to = usize::try_from(to) + .unwrap_or(usize::MAX) + .clamp(from, matching_room_ids.len() - 1); + list_room_ids.extend(&matching_room_ids[from..=to]); + } + for room_id in list_room_ids { + todo_rooms.entry(room_id.to_owned()).or_default().update( + list.room_details.required_state.clone(), + list.room_details.timeline_limit, + known_rooms, + room_id, + ); + } } - let mut list_room_ids: BTreeSet = BTreeSet::new(); - for (mut from, mut to) in list.ranges { - from = from.clamp( - uint!(0), - UInt::try_from(all_joined_rooms.len() - 1).unwrap_or(UInt::MAX), - ); - to = to.clamp( - from, - UInt::try_from(all_joined_rooms.len() - 1).unwrap_or(UInt::MAX), - ); - let room_ids = all_joined_rooms[from.try_into().unwrap_or(usize::MAX) - ..=to.try_into().unwrap_or(usize::MAX)] - .to_vec(); - list_room_ids.extend(room_ids); - } - for room_id in &list_room_ids { - todo_rooms.entry(room_id.clone()).or_default().update( - list.room_details.required_state.clone(), - list.room_details.timeline_limit, - known_rooms, - room_id, - ); - } - let num_rooms = list_room_ids.len(); + let num_rooms = matching_room_ids.len(); trace!(list_id, num_rooms, "Done collecting rooms"); - Some(sync_events::v5::response::List { + sync_events::v5::response::List { count: UInt::try_from(num_rooms).unwrap_or(UInt::MAX), - }) + } } #[allow(clippy::too_many_lines)]