SSS: implement state filtering

This commit is contained in:
Lambda 2025-03-30 13:15:33 +00:00
parent ede95dcee5
commit 891eb410cc

View file

@ -45,9 +45,111 @@ use crate::{
services, Ar, Error, Ra, Result,
};
/// The set of state keys wanted for one particular state event type.
#[derive(Debug)]
enum RequiredStateKeys {
    /// Every state key of the event type is wanted.
    All,
    /// Only the listed state keys are wanted.
    Selected(BTreeSet<String>),
}

impl RequiredStateKeys {
    /// Adds `key` to the wanted state keys.
    ///
    /// The wildcard key `"*"` upgrades a `Selected` set to `All`; once the
    /// value is `All`, further keys change nothing.
    fn merge(&mut self, key: String) {
        match self {
            // Already covering every key, nothing can be added.
            Self::All => {}
            // A wildcard supersedes whatever was selected so far.
            Self::Selected(_) if key == "*" => *self = Self::All,
            Self::Selected(selected) => {
                selected.insert(key);
            }
        }
    }
}
/// Accumulated `required_state` filter.
///
/// `RequiredState::matches` looks up the event type in `filters` first and
/// only falls back to `all_events` when the type has no entry.
#[derive(Debug)]
struct RequiredState {
/// Indicates that a `("*", "*")` tuple was present in `required_state`.
/// When `true`, all state events are sent by default, except for state
/// event types that are present in `filters`, for which only the
/// requested state keys are sent.
all_events: bool,
/// Per-event-type state key selection. An entry always decides for its
/// event type: with `all_events == false` it adds events to the result,
/// with `all_events == true` it restricts the otherwise complete result.
filters: BTreeMap<StateEventType, RequiredStateKeys>,
}
impl RequiredState {
fn update(
&mut self,
required_state: Vec<(StateEventType, String)>,
sender_user: &UserId,
) {
let contains_wildcard = required_state
.iter()
.any(|(typ, key)| typ.to_string() == "*" && key == "*");
let mut old_filters = None;
if contains_wildcard {
if self.all_events {
// filters already contains existing negative filters, remember
// them and only apply new filters that were
// already there previously
old_filters = Some(std::mem::take(&mut self.filters));
} else {
// clear existing positive filters
self.filters = BTreeMap::new();
}
self.all_events = true;
} else if self.all_events {
// all events were requested previously, don't add any additional
// positive filters
return;
}
for (typ, mut key) in required_state {
if typ.to_string() == "*" {
continue;
}
if key == "$ME" {
key = sender_user.to_string();
}
if let Some(old_filters) = old_filters.as_mut() {
// re-insert the old negative filter if it matches the new
// negative filter exactly
if let Some(old_filter) = old_filters.remove(&typ) {
if let RequiredStateKeys::Selected(state_keys) = &old_filter
{
if state_keys.len() == 1 && state_keys.contains(&key) {
self.filters.insert(typ, old_filter);
}
}
}
} else {
// add the key to the filter for this event type
self.filters
.entry(typ)
.or_insert_with(|| {
RequiredStateKeys::Selected(BTreeSet::new())
})
.merge(key);
}
}
}
fn matches(&self, typ: &StateEventType, key: &str) -> bool {
match self.filters.get(typ) {
Some(keys) => match keys {
RequiredStateKeys::All => true,
RequiredStateKeys::Selected(keys) => keys.contains(key),
},
None => self.all_events,
}
}
}
/// Per-room bookkeeping collected while handling a sync request.
#[derive(Debug)]
struct TodoRoom {
/// Raw `(event type, state key)` pairs as requested.
// NOTE(review): this view has lost its +/- diff markers; this field looks
// like it is superseded by `required_state` below — confirm against the
// commit before relying on it.
required_state_request: BTreeSet<(StateEventType, String)>,
/// Merged state filter, built up through `RequiredState::update`.
required_state: RequiredState,
/// Timeline limit for the room; the visible update code keeps the largest
/// value seen, with each requested limit capped at 100.
timeline_limit: u64,
/// Starts out as `u64::MAX` in `Default`.
// NOTE(review): presumably the "since" position for this room, with 0
// meaning "unknown" — confirm against the full `TodoRoom::update`.
roomsince: u64,
}
@ -58,8 +160,10 @@ impl TodoRoom {
timeline_limit: UInt,
known_rooms: &BTreeMap<OwnedRoomId, u64>,
room_id: &RoomId,
sender_user: &UserId,
) {
self.required_state_request.extend(required_state);
self.required_state.update(required_state, sender_user);
self.timeline_limit =
self.timeline_limit.max(u64::from(timeline_limit).min(100));
// 0 means unknown because it got out of date
@ -70,7 +174,10 @@ impl TodoRoom {
impl Default for TodoRoom {
fn default() -> Self {
Self {
required_state_request: BTreeSet::new(),
required_state: RequiredState {
all_events: false,
filters: BTreeMap::new(),
},
timeline_limit: 0,
roomsince: u64::MAX,
}
@ -307,6 +414,7 @@ pub(crate) async fn sync_events_v5_route(
&all_room_ids,
&known_rooms,
&mut todo_rooms,
&sender_user,
);
(list_id, rooms)
})
@ -322,6 +430,7 @@ pub(crate) async fn sync_events_v5_route(
room.timeline_limit,
&known_rooms,
room_id,
&sender_user,
);
}
@ -634,6 +743,7 @@ fn rooms_in_list(
all_room_ids: &[&RoomId],
known_rooms: &BTreeMap<OwnedRoomId, u64>,
todo_rooms: &mut BTreeMap<OwnedRoomId, TodoRoom>,
sender_user: &UserId,
) -> sync_events::v5::response::List {
trace!(list_id, ?list, "Collecting rooms in list");
@ -673,6 +783,7 @@ fn rooms_in_list(
list.room_details.timeline_limit,
known_rooms,
room_id,
sender_user,
);
}
}
@ -721,19 +832,104 @@ fn process_room(
let room_events: Vec<_> =
timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect();
let required_state = todo_room
.required_state_request
.iter()
.filter_map(|state| {
services()
.rooms
.state_accessor
.room_state_get(room_id, &state.0, &state.1)
.ok()
.flatten()
.map(|state| state.to_sync_state_event())
})
.collect();
let Some(current_shortstatehash) =
services().rooms.state.get_room_shortstatehash(room_id)?
else {
error!(%room_id, "Room has no state");
return Ok(None);
};
let need_scan = todo_room.required_state.all_events
|| todo_room
.required_state
.filters
.iter()
.any(|(_, keys)| matches!(keys, RequiredStateKeys::All));
let required_state = if need_scan {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(current_shortstatehash)?
.pop()
.expect("there is always one layer")
.full_state;
full_state
.iter()
.filter_map(|compressed| {
let Ok((typ, key)) = services()
.rooms
.short
.get_statekey_from_short(compressed.state)
else {
warn!(
?compressed,
"Failed to get info for shortstatekey, skipping"
);
return None;
};
if !todo_room.required_state.matches(&typ, &key) {
return None;
}
let shorteventid = compressed.event;
let pdu = match services()
.rooms
.short
.get_eventid_from_short(shorteventid)
{
Ok(event_id) => {
services().rooms.timeline.get_pdu(&event_id)
}
Err(error) => {
warn!(
%error,
%typ,
key,
?shorteventid,
"Failed to get event ID from short event ID"
);
return None;
}
};
match pdu {
Ok(Some(pdu)) => Some(pdu.to_sync_state_event()),
Ok(None) => None,
Err(error) => {
warn!(%error, %typ, key, "Failed to get state PDU");
None
}
}
})
.collect()
} else {
todo_room
.required_state
.filters
.iter()
.flat_map(|(typ, keys)| {
let RequiredStateKeys::Selected(keys) = keys else {
panic!(
"wildcard key should have triggered a full state scan"
);
};
keys.iter().filter_map(move |key| {
match services().rooms.state_accessor.state_get(
current_shortstatehash,
typ,
key,
) {
Ok(Some(pdu)) => Some(pdu.to_sync_state_event()),
Ok(None) => None,
Err(error) => {
warn!(%error, %typ, key, "Failed to get state PDU");
None
}
}
})
})
.collect()
};
// Heroes
let heroes = services()