use std::{ collections::{BTreeSet, HashSet}, sync::{Arc, Mutex}, }; use lru_cache::LruCache; use ruma::{EventId, RoomId, api::client::error::ErrorKind}; use tracing::{debug, error, warn}; use super::short::ShortEventId; use crate::{ Error, Result, observability::{FoundIn, Lookup, METRICS}, services, utils::debug_slice_truncated, }; mod data; pub(crate) use data::Data; pub(crate) struct Service { db: &'static dyn Data, #[allow(clippy::type_complexity)] auth_chain_cache: Option, Arc>>>>, } impl Service { pub(crate) fn new( db: &'static dyn Data, auth_chain_cache_size: usize, ) -> Self { Self { db, auth_chain_cache: (auth_chain_cache_size > 0) .then(|| Mutex::new(LruCache::new(auth_chain_cache_size))), } } pub(crate) fn get_cached_eventid_authchain( &self, key: &[ShortEventId], ) -> Result>>> { let lookup = Lookup::AuthChain; if let Some(cache) = &self.auth_chain_cache { if let Some(result) = cache.lock().unwrap().get_mut(key) { METRICS.record_lookup(lookup, FoundIn::Cache); return Ok(Some(Arc::clone(result))); } } let Some(chain) = self.db.get_cached_eventid_authchain(key)? else { METRICS.record_lookup(lookup, FoundIn::Nothing); return Ok(None); }; METRICS.record_lookup(lookup, FoundIn::Database); let chain = Arc::new(chain); if let Some(cache) = &self.auth_chain_cache { cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain)); } Ok(Some(chain)) } #[tracing::instrument(skip(self))] pub(crate) fn cache_auth_chain( &self, key: Vec, auth_chain: Arc>, ) -> Result<()> { self.db.cache_auth_chain(&key, &auth_chain)?; if let Some(cache) = &self.auth_chain_cache { cache.lock().unwrap().insert(key, auth_chain); } Ok(()) } #[tracing::instrument( skip(self, starting_events), fields(starting_events = debug_slice_truncated(&starting_events, 5)), )] pub(crate) async fn get_auth_chain<'a>( &self, room_id: &RoomId, starting_events: Vec>, ) -> Result> + 'a> { const NUM_BUCKETS: usize = 50; let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; let mut i = 0; for id in starting_events { let short = services().rooms.short.get_or_create_shorteventid(&id)?; // I'm afraid to change this in case there is accidental reliance on // the truncation #[allow(clippy::as_conversions, clippy::cast_possible_truncation)] let bucket_id = (short.get() % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; if i % 100 == 0 { tokio::task::yield_now().await; } } let mut full_auth_chain = HashSet::new(); let mut hits = 0; let mut misses = 0; for chunk in buckets { if chunk.is_empty() { continue; } let chunk_key: Vec<_> = chunk.iter().map(|(short, _)| short).copied().collect(); if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; } misses += 1; let mut chunk_cache = HashSet::new(); let mut hits2 = 0; let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; let auth_chain = Arc::new( self.get_auth_chain_inner(room_id, &event_id)?, ); self.cache_auth_chain( vec![sevent_id], Arc::clone(&auth_chain), )?; debug!( event_id = ?event_id, chain_length = ?auth_chain.len(), "Cache missed event" ); chunk_cache.extend(auth_chain.iter()); i += 1; if i % 100 == 0 { tokio::task::yield_now().await; } }; } debug!( chunk_cache_length = ?chunk_cache.len(), hits = ?hits2, misses = ?misses2, "Chunk missed", ); let chunk_cache = Arc::new(chunk_cache); self.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; full_auth_chain.extend(chunk_cache.iter()); } debug!( chain_length = ?full_auth_chain.len(), hits = ?hits, misses = ?misses, "Auth chain stats", ); Ok(full_auth_chain.into_iter().filter_map(move |sid| { services().rooms.short.get_eventid_from_short(sid).ok() })) } #[tracing::instrument(skip(self))] fn get_auth_chain_inner( &self, room_id: &RoomId, event_id: &EventId, ) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { match services().rooms.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { warn!(bad_room_id = %pdu.room_id, "Event referenced in auth chain has incorrect room id"); return Err(Error::BadRequest( ErrorKind::forbidden(), "Event has incorrect room id", )); } for auth_event in &pdu.auth_events { let sauthevent = services() .rooms .short .get_or_create_shorteventid(auth_event)?; if !found.contains(&sauthevent) { found.insert(sauthevent); todo.push(auth_event.clone()); } } } Ok(None) => { warn!( ?event_id, "Could not find pdu mentioned in auth events" ); } Err(error) => { error!( ?event_id, ?error, "Could not load event in auth chain" ); } } } Ok(found) } }