From 891eb410cc15291d17505d41039daf39e204873b Mon Sep 17 00:00:00 2001 From: Lambda Date: Sun, 30 Mar 2025 13:15:33 +0000 Subject: [PATCH] SSS: implement state filtering --- src/api/client_server/sync/msc4186.rs | 228 ++++++++++++++++++++++++-- 1 file changed, 212 insertions(+), 16 deletions(-) diff --git a/src/api/client_server/sync/msc4186.rs b/src/api/client_server/sync/msc4186.rs index a50dda46..a3726faa 100644 --- a/src/api/client_server/sync/msc4186.rs +++ b/src/api/client_server/sync/msc4186.rs @@ -45,9 +45,111 @@ use crate::{ services, Ar, Error, Ra, Result, }; +#[derive(Debug)] +enum RequiredStateKeys { + All, + Selected(BTreeSet), +} +impl RequiredStateKeys { + fn merge(&mut self, key: String) { + match self { + RequiredStateKeys::All => { + // nothing to do, we're already getting all keys + } + RequiredStateKeys::Selected(keys) => { + if key == "*" { + *self = RequiredStateKeys::All; + } else { + keys.insert(key); + } + } + } + } +} + +#[derive(Debug)] +struct RequiredState { + /// Indicates that a `("*", "*")` tuple was present in `required_state`. + /// When `true`, all state events are sent by default, except for state + /// event types that are present in `filters`, for which only the + /// request state keys are sent. + all_events: bool, + filters: BTreeMap, +} +impl RequiredState { + fn update( + &mut self, + required_state: Vec<(StateEventType, String)>, + sender_user: &UserId, + ) { + let contains_wildcard = required_state + .iter() + .any(|(typ, key)| typ.to_string() == "*" && key == "*"); + + let mut old_filters = None; + if contains_wildcard { + if self.all_events { + // filters already contains existing negative filters, remember + // them and only apply new filters that were + // already there previously + old_filters = Some(std::mem::take(&mut self.filters)); + } else { + // clear existing positive filters + self.filters = BTreeMap::new(); + } + + self.all_events = true; + } else if self.all_events { + // all events were requested previously, don't add any additional + // positive filters + return; + } + + for (typ, mut key) in required_state { + if typ.to_string() == "*" { + continue; + } + if key == "$ME" { + key = sender_user.to_string(); + } + + if let Some(old_filters) = old_filters.as_mut() { + // re-insert the old negative filter if it matches the new + // negative filter exactly + if let Some(old_filter) = old_filters.remove(&typ) { + if let RequiredStateKeys::Selected(state_keys) = &old_filter + { + if state_keys.len() == 1 && state_keys.contains(&key) { + self.filters.insert(typ, old_filter); + } + } + } + } else { + // add the key to the filter for this event type + self.filters + .entry(typ) + .or_insert_with(|| { + RequiredStateKeys::Selected(BTreeSet::new()) + }) + .merge(key); + } + } + } + + fn matches(&self, typ: &StateEventType, key: &str) -> bool { + match self.filters.get(typ) { + Some(keys) => match keys { + RequiredStateKeys::All => true, + RequiredStateKeys::Selected(keys) => keys.contains(key), + }, + None => self.all_events, + } + } +} + #[derive(Debug)] struct TodoRoom { - required_state_request: BTreeSet<(StateEventType, String)>, + required_state: RequiredState, timeline_limit: u64, roomsince: u64, } @@ -58,8 +160,10 @@ impl TodoRoom { timeline_limit: UInt, known_rooms: &BTreeMap, room_id: &RoomId, + sender_user: &UserId, ) { - self.required_state_request.extend(required_state); + self.required_state.update(required_state, sender_user); + self.timeline_limit = self.timeline_limit.max(u64::from(timeline_limit).min(100)); // 0 means unknown because it got out of date @@ -70,7 +174,10 @@ impl TodoRoom { impl Default for TodoRoom { fn default() -> Self { Self { - required_state_request: BTreeSet::new(), + required_state: RequiredState { + all_events: false, + filters: BTreeMap::new(), + }, timeline_limit: 0, roomsince: u64::MAX, } @@ -307,6 +414,7 @@ pub(crate) async fn sync_events_v5_route( &all_room_ids, &known_rooms, &mut todo_rooms, + &sender_user, ); (list_id, rooms) }) @@ -322,6 +430,7 @@ pub(crate) async fn sync_events_v5_route( room.timeline_limit, &known_rooms, room_id, + &sender_user, ); } @@ -634,6 +743,7 @@ fn rooms_in_list( all_room_ids: &[&RoomId], known_rooms: &BTreeMap, todo_rooms: &mut BTreeMap, + sender_user: &UserId, ) -> sync_events::v5::response::List { trace!(list_id, ?list, "Collecting rooms in list"); @@ -673,6 +783,7 @@ fn rooms_in_list( list.room_details.timeline_limit, known_rooms, room_id, + sender_user, ); } } @@ -721,19 +832,104 @@ fn process_room( let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); - let required_state = todo_room - .required_state_request - .iter() - .filter_map(|state| { - services() - .rooms - .state_accessor - .room_state_get(room_id, &state.0, &state.1) - .ok() - .flatten() - .map(|state| state.to_sync_state_event()) - }) - .collect(); + let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + else { + error!(%room_id, "Room has no state"); + return Ok(None); + }; + + let need_scan = todo_room.required_state.all_events + || todo_room + .required_state + .filters + .iter() + .any(|(_, keys)| matches!(keys, RequiredStateKeys::All)); + let required_state = if need_scan { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(current_shortstatehash)? + .pop() + .expect("there is always one layer") + .full_state; + full_state + .iter() + .filter_map(|compressed| { + let Ok((typ, key)) = services() + .rooms + .short + .get_statekey_from_short(compressed.state) + else { + warn!( + ?compressed, + "Failed to get info for shortstatekey, skipping" + ); + return None; + }; + + if !todo_room.required_state.matches(&typ, &key) { + return None; + } + + let shorteventid = compressed.event; + let pdu = match services() + .rooms + .short + .get_eventid_from_short(shorteventid) + { + Ok(event_id) => { + services().rooms.timeline.get_pdu(&event_id) + } + Err(error) => { + warn!( + %error, + %typ, + key, + ?shorteventid, + "Failed to get event ID from short event ID" + ); + return None; + } + }; + match pdu { + Ok(Some(pdu)) => Some(pdu.to_sync_state_event()), + Ok(None) => None, + Err(error) => { + warn!(%error, %typ, key, "Failed to get state PDU"); + None + } + } + }) + .collect() + } else { + todo_room + .required_state + .filters + .iter() + .flat_map(|(typ, keys)| { + let RequiredStateKeys::Selected(keys) = keys else { + panic!( + "wildcard key should have triggered a full state scan" + ); + }; + keys.iter().filter_map(move |key| { + match services().rooms.state_accessor.state_get( + current_shortstatehash, + typ, + key, + ) { + Ok(Some(pdu)) => Some(pdu.to_sync_state_event()), + Ok(None) => None, + Err(error) => { + warn!(%error, %typ, key, "Failed to get state PDU"); + None + } + } + }) + }) + .collect() + }; // Heroes let heroes = services()