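//! Lookup and caching of event auth chains.
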
use std::{
    collections::{BTreeSet, HashSet},
    sync::{Arc, Mutex},
};

use lru_cache::LruCache;
use ruma::{api::client::error::ErrorKind, EventId, RoomId};
use tracing::{debug, error, warn};

use super::short::ShortEventId;
use crate::{
    observability::{FoundIn, Lookup, METRICS},
    services,
    utils::debug_slice_truncated,
    Error, Result,
};

mod data;

pub(crate) use data::Data;

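/// Service for computing and caching the auth chains of events.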
pub(crate) struct Service {
    db: &'static dyn Data,
    #[allow(clippy::type_complexity)]
    auth_chain_cache:
        Option<Mutex<LruCache<Vec<ShortEventId>, Arc<HashSet<ShortEventId>>>>>,
}

impl Service {
    pub(crate) fn new(
        db: &'static dyn Data,
        auth_chain_cache_size: usize,
    ) -> Self {
        Self {
            db,
            auth_chain_cache: (auth_chain_cache_size > 0)
                .then(|| Mutex::new(LruCache::new(auth_chain_cache_size))),
        }
    }

    pub(crate) fn get_cached_eventid_authchain(
        &self,
        key: &[ShortEventId],
    ) -> Result<Option<Arc<HashSet<ShortEventId>>>> {
        let lookup = Lookup::AuthChain;

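        // Check the in-memory LRU cache first, if one is configured.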
        if let Some(cache) = &self.auth_chain_cache {
            if let Some(result) = cache.lock().unwrap().get_mut(key) {
                METRICS.record_lookup(lookup, FoundIn::Cache);
                return Ok(Some(Arc::clone(result)));
            }
        }

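        // Fall back to the cached auth chains stored in the database.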
        let Some(chain) = self.db.get_cached_eventid_authchain(key)? else {
            METRICS.record_lookup(lookup, FoundIn::Nothing);
            return Ok(None);
        };

        METRICS.record_lookup(lookup, FoundIn::Database);
        let chain = Arc::new(chain);

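        // Also store the result in the in-memory cache, keyed by the first
        // shorteventid.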
        if let Some(cache) = &self.auth_chain_cache {
            cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));
        }

        Ok(Some(chain))
    }

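    /// Stores an auth chain in the database and the in-memory cache.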
    #[tracing::instrument(skip(self))]
    pub(crate) fn cache_auth_chain(
        &self,
        key: Vec<ShortEventId>,
        auth_chain: Arc<HashSet<ShortEventId>>,
    ) -> Result<()> {
        self.db.cache_auth_chain(&key, &auth_chain)?;
        if let Some(cache) = &self.auth_chain_cache {
            cache.lock().unwrap().insert(key, auth_chain);
        }
        Ok(())
    }

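    /// Returns the combined auth chain of `starting_events`, reusing cached
    /// per-event and per-chunk results where possible.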
    #[tracing::instrument(
        skip(self, starting_events),
        fields(starting_events = debug_slice_truncated(&starting_events, 5)),
    )]
    pub(crate) async fn get_auth_chain<'a>(
        &self,
        room_id: &RoomId,
        starting_events: Vec<Arc<EventId>>,
    ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
        const NUM_BUCKETS: usize = 50;

        let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];

        let mut i = 0;
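        // Distribute the starting events into buckets keyed by their
        // shorteventid so that whole buckets can be cached together.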
        for id in starting_events {
            let short =
                services().rooms.short.get_or_create_shorteventid(&id)?;
            // I'm afraid to change this in case there is accidental reliance on
            // the truncation
            #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
            let bucket_id = (short.get() % NUM_BUCKETS as u64) as usize;
            buckets[bucket_id].insert((short, id.clone()));
            i += 1;
            if i % 100 == 0 {
                tokio::task::yield_now().await;
            }
        }

        let mut full_auth_chain = HashSet::new();

        let mut hits = 0;
        let mut misses = 0;
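        // Resolve each non-empty bucket: first try the cache for the whole
        // chunk, then fall back to per-event lookups.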
        for chunk in buckets {
            if chunk.is_empty() {
                continue;
            }

            let chunk_key: Vec<_> =
                chunk.iter().map(|(short, _)| short).copied().collect();
            if let Some(cached) =
                self.get_cached_eventid_authchain(&chunk_key)?
            {
                hits += 1;
                full_auth_chain.extend(cached.iter().copied());
                continue;
            }
            misses += 1;

            let mut chunk_cache = HashSet::new();
            let mut hits2 = 0;
            let mut misses2 = 0;
            let mut i = 0;
            for (sevent_id, event_id) in chunk {
                if let Some(cached) =
                    self.get_cached_eventid_authchain(&[sevent_id])?
                {
                    hits2 += 1;
                    chunk_cache.extend(cached.iter().copied());
                } else {
                    misses2 += 1;
                    let auth_chain = Arc::new(
                        self.get_auth_chain_inner(room_id, &event_id)?,
                    );
                    self.cache_auth_chain(
                        vec![sevent_id],
                        Arc::clone(&auth_chain),
                    )?;
                    debug!(
                        event_id = ?event_id,
                        chain_length = ?auth_chain.len(),
                        "Cache missed event"
                    );
                    chunk_cache.extend(auth_chain.iter());

                    i += 1;
                    if i % 100 == 0 {
                        tokio::task::yield_now().await;
                    }
                }
            }
            debug!(
                chunk_cache_length = ?chunk_cache.len(),
                hits = ?hits2,
                misses = ?misses2,
                "Chunk missed",
            );
            let chunk_cache = Arc::new(chunk_cache);
            self.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
            full_auth_chain.extend(chunk_cache.iter());
        }

        debug!(
            chain_length = ?full_auth_chain.len(),
            hits = ?hits,
            misses = ?misses,
            "Auth chain stats",
        );

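        // Map the short event ids back to full event ids, skipping any that
        // can no longer be resolved.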
        Ok(full_auth_chain.into_iter().filter_map(move |sid| {
            services().rooms.short.get_eventid_from_short(sid).ok()
        }))
    }

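    /// Computes the auth chain of a single event by following `auth_events`
    /// references transitively.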
    #[tracing::instrument(skip(self))]
    fn get_auth_chain_inner(
        &self,
        room_id: &RoomId,
        event_id: &EventId,
    ) -> Result<HashSet<ShortEventId>> {
        let mut todo = vec![Arc::from(event_id)];
        let mut found = HashSet::new();

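        // Walk the auth events with an explicit stack, recording each
        // shorteventid the first time it is seen.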
        while let Some(event_id) = todo.pop() {
            match services().rooms.timeline.get_pdu(&event_id) {
                Ok(Some(pdu)) => {
                    if pdu.room_id != room_id {
                        warn!(bad_room_id = %pdu.room_id, "Event referenced in auth chain has incorrect room id");
                        return Err(Error::BadRequest(
                            ErrorKind::forbidden(),
                            "Event has incorrect room id",
                        ));
                    }
                    for auth_event in &pdu.auth_events {
                        let sauthevent = services()
                            .rooms
                            .short
                            .get_or_create_shorteventid(auth_event)?;

                        if !found.contains(&sauthevent) {
                            found.insert(sauthevent);
                            todo.push(auth_event.clone());
                        }
                    }
                }
                Ok(None) => {
                    warn!(
                        ?event_id,
                        "Could not find pdu mentioned in auth events"
                    );
                }
                Err(error) => {
                    error!(
                        ?event_id,
                        ?error,
                        "Could not load event in auth chain"
                    );
                }
            }
        }

        Ok(found)
    }
}