From 2c19abc5359b628d3157b7c86864ecdebc22e1a4 Mon Sep 17 00:00:00 2001 From: Lambda Date: Mon, 27 May 2024 18:12:02 +0000 Subject: [PATCH] Add OnDemandHashMap and TokenSet --- src/observability.rs | 26 ++- src/utils.rs | 1 + src/utils/on_demand_hashmap.rs | 304 +++++++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 src/utils/on_demand_hashmap.rs diff --git a/src/observability.rs b/src/observability.rs index 7add9ce4..177001aa 100644 --- a/src/observability.rs +++ b/src/observability.rs @@ -1,7 +1,7 @@ //! Facilities for observing runtime behavior #![warn(missing_docs, clippy::missing_docs_in_private_items)] -use std::{collections::HashSet, fs::File, io::BufWriter}; +use std::{collections::HashSet, fs::File, io::BufWriter, sync::Arc}; use axum::{ extract::{MatchedPath, Request}, @@ -269,6 +269,10 @@ pub(crate) struct Metrics { /// Counts where data is found from lookup: opentelemetry::metrics::Counter<u64>, + + /// Number of entries in an + /// [`OnDemandHashMap`](crate::utils::on_demand_hashmap::OnDemandHashMap) + on_demand_hashmap_size: opentelemetry::metrics::Gauge<u64>, } impl Metrics { @@ -319,10 +323,16 @@ impl Metrics { .with_description("Counts where data is found from") .init(); + let on_demand_hashmap_size = meter + .u64_gauge("on_demand_hashmap_size") + .with_description("Number of entries in OnDemandHashMap") + .init(); + Metrics { otel_state: (registry, provider), http_requests_histogram, lookup, + on_demand_hashmap_size, } } @@ -343,6 +353,20 @@ impl Metrics { ], ); } + + /// Record size of [`OnDemandHashMap`] + /// + /// [`OnDemandHashMap`]: crate::utils::on_demand_hashmap::OnDemandHashMap + pub(crate) fn record_on_demand_hashmap_size( + &self, + name: Arc<str>, + size: usize, + ) { + self.on_demand_hashmap_size.record( + size.try_into().unwrap_or(u64::MAX), + &[KeyValue::new("name", name)], + ); + } } /// Track HTTP metrics by converting this into an [`axum`] layer diff --git a/src/utils.rs b/src/utils.rs index
3bd0319b..52956bf6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,5 @@ pub(crate) mod error; +pub(crate) mod on_demand_hashmap; use std::{ borrow::Cow, diff --git a/src/utils/on_demand_hashmap.rs b/src/utils/on_demand_hashmap.rs new file mode 100644 index 00000000..69e8bac9 --- /dev/null +++ b/src/utils/on_demand_hashmap.rs @@ -0,0 +1,304 @@ +use std::{ + collections::HashMap, + fmt, + hash::Hash, + marker::PhantomData, + ops::Deref, + sync::{Arc, Weak}, +}; + +use tokio::sync::{mpsc, Mutex, OwnedMutexGuard, RwLock}; +use tracing::{trace, warn, Level}; + +use crate::observability::METRICS; + +/// Data shared between [`OnDemandHashMap`] and the cleanup task +/// +/// Importantly it does not contain the `cleanup_sender`, since it getting +/// dropped signals the cleanup task to exit. If the cleanup task had an owned +/// reference to it, the only way for it to exit would be for every [`Entry`] to +/// be dropped, we don't want to rely on that. +struct SharedData { + name: Arc, + /// Values are owned by their [entries][Entry] + entries: RwLock>>, +} + +impl SharedData +where + K: Hash + Eq + Clone + fmt::Debug, +{ + #[tracing::instrument( + level = Level::TRACE, + skip(self), + fields(name = self.name.as_ref()), + )] + async fn try_cleanup_entry(&self, key: K) { + let mut map = self.entries.write().await; + + let Some(weak) = map.get(&key) else { + trace!("Entry has already been cleaned up"); + return; + }; + + if weak.strong_count() != 0 { + trace!("Entry is in use"); + return; + } + + trace!("Cleaning up unused entry"); + map.remove(&key); + METRICS.record_on_demand_hashmap_size(self.name.clone(), map.len()); + } + + #[tracing::instrument(level = Level::TRACE, skip(map))] + fn try_get_live_value( + pass: usize, + map: &HashMap>, + key: &K, + ) -> Option> { + if let Some(value) = map.get(key) { + if let Some(value) = value.upgrade() { + trace!(pass, "Using existing value"); + return Some(value); + } + + trace!( + pass, + "Existing value is stale and needs 
cleanup, creating new" + ); + } else { + trace!(pass, "No existing value, creating new"); + } + + None + } + + /// Either returns an existing live value, or creates a new one and inserts + /// it into the map. + #[tracing::instrument(level = Level::TRACE, skip(self, create))] + async fn get_or_insert_with(&self, key: &K, create: F) -> Arc + where + F: FnOnce() -> V, + { + { + // first, take a read lock and try to get an existing value + + // TODO check if this fast path actually makes it faster, possibly + // make it configurable per OnDemandHashMap depending on contention + // and how expensive create() is + let map = self.entries.read().await; + if let Some(v) = Self::try_get_live_value(1, &map, key) { + return v; + } + } + + // no entry or it has died, create a new one + let value = Arc::new(create()); + let weak = Arc::downgrade(&value); + + // take a write lock, try again, otherwise insert our new value + let mut map = self.entries.write().await; + if let Some(v) = Self::try_get_live_value(2, &map, key) { + // another entry showed up while we had let go of the lock, + // use that + drop(value); + drop(weak); + return v; + } + + map.insert(key.clone(), weak); + METRICS.record_on_demand_hashmap_size(self.name.clone(), map.len()); + + value + } +} + +/// A [`HashMap`] whose entries are automatically removed once they are no +/// longer referenced. +pub(crate) struct OnDemandHashMap { + /// The data shared between the [`OnDemandHashMap`] and the cleanup task. + shared: Arc>, + /// This is the only non-[weak][mpsc::WeakUnboundedSender] `Sender`, which + /// means that dropping the `OnDemandHashMap` causes the cleanup + /// process to exit. + cleanup_sender: mpsc::UnboundedSender, +} + +impl OnDemandHashMap +where + K: Hash + Eq + Clone + fmt::Debug + Send + Sync + 'static, + V: Send + Sync + 'static, +{ + /// Creates a new `OnDemandHashMap`. The `name` is used for metrics and + /// should be unique to this instance. 
+ pub(crate) fn new(name: String) -> Self { + let (cleanup_sender, mut receiver) = mpsc::unbounded_channel(); + + let shared = Arc::new(SharedData { + name: name.into(), + entries: RwLock::new(HashMap::new()), + }); + + { + let shared = Arc::clone(&shared); + tokio::task::spawn(async move { + loop { + let Some(key) = receiver.recv().await else { + trace!( + name = shared.name.as_ref(), + "Channel has died, exiting cleanup task" + ); + return; + }; + + shared.try_cleanup_entry(key).await; + } + }); + } + + Self { + shared, + cleanup_sender, + } + } + + #[tracing::instrument(level = Level::TRACE, skip(self, create))] + pub(crate) async fn get_or_insert_with( + &self, + key: K, + create: F, + ) -> Entry + where + F: FnOnce() -> V, + { + let value = self.shared.get_or_insert_with(&key, create).await; + + Entry { + drop_guard: EntryDropGuard { + cleanup_sender: self.cleanup_sender.downgrade(), + key: Some(key), + }, + value, + } + } +} + +struct EntryDropGuard { + cleanup_sender: mpsc::WeakUnboundedSender, + /// Only `None` during `drop()` + key: Option, +} + +impl Drop for EntryDropGuard { + fn drop(&mut self) { + let Some(cleanup_sender) = self.cleanup_sender.upgrade() else { + trace!("Backing map has already been dropped"); + return; + }; + + if let Err(error) = cleanup_sender + .send(self.key.take().expect("drop should only be called once")) + { + warn!(%error, "Failed to send cleanup message"); + }; + } +} + +/// A wrapper around a key/value pair inside an [`OnDemandHashMap`] +/// +/// If every `Entry` for a specific key is dropped, the value is removed from +/// the map. +pub(crate) struct Entry { + drop_guard: EntryDropGuard, + value: Arc, +} + +impl Deref for Entry { + type Target = V; + + fn deref(&self) -> &Self::Target { + self.value.as_ref() + } +} + +/// Internal zero-sized type used to swallow the [`TokenSet`]'s marker type +struct TokenMarker(PhantomData T>); + +/// A collection of dynamically-created locks, one for each value of `K`. 
+/// +/// A given key can be locked using [`TokenSet::lock_key()`], which will either +/// return an ownership token immediately if the key is not currently locked, or +/// wait until the previous lock has been released. +/// +/// The marker type `M` can be used to disambiguate different `TokenSet` +/// instances to avoid misuse of tokens. +pub(crate) struct TokenSet { + inner: OnDemandHashMap>>, +} + +impl TokenSet +where + K: Hash + Eq + Clone + fmt::Debug + Send + Sync + 'static, + M: 'static, +{ + /// Creates a new `TokenSet`. The `name` is used for metrics and should be + /// unique to this instance. + pub(crate) fn new(name: String) -> Self { + Self { + inner: OnDemandHashMap::new(name), + } + } + + /// Locks this key in the `TokenSet`, returning a token proving + /// unique access. + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) async fn lock_key(&self, key: K) -> KeyToken { + let Entry { + drop_guard, + value, + } = self + .inner + .get_or_insert_with(key, || Mutex::new(TokenMarker(PhantomData))) + .await; + + KeyToken { + drop_guard, + _mutex_guard: value.lock_owned().await, + } + } +} + +/// Unique token for a given key in a [`TokenSet`]. +/// +/// Ownership of this token proves that no other [`KeyToken`] for this key in +/// this [`TokenSet`] currently exists. +/// +/// Access to the underlying key is provided by a [`Deref`] impl. +pub(crate) struct KeyToken { + drop_guard: EntryDropGuard, + _mutex_guard: OwnedMutexGuard>, +} + +impl Deref for KeyToken { + type Target = K; + + fn deref(&self) -> &Self::Target { + self.drop_guard + .key + .as_ref() + .expect("key should only be None during Drop") + } +} + +impl fmt::Debug for KeyToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", &**self) + } +} + +impl fmt::Display for KeyToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", &**self) + } +}