From 029e32971e964260999a9140be52d64790f7d610 Mon Sep 17 00:00:00 2001 From: Charles Hall Date: Fri, 14 Jun 2024 22:30:24 -0700 Subject: [PATCH] add basic typed key value store abstraction Should eliminate a few classes of footgun when it's done. --- Cargo.lock | 53 ++++++++ Cargo.toml | 1 + src/database.rs | 1 + src/database/map.rs | 178 ++++++++++++++++++++++++++ src/database/map/tests.rs | 117 +++++++++++++++++ src/database/map/tests/conversions.rs | 105 +++++++++++++++ 6 files changed, 455 insertions(+) create mode 100644 src/database/map.rs create mode 100644 src/database/map/tests.rs create mode 100644 src/database/map/tests/conversions.rs diff --git a/Cargo.lock b/Cargo.lock index 29beb49a..3fa73d53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,6 +713,58 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "frunk" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11a351b59e12f97b4176ee78497dff72e4276fb1ceb13e19056aca7fa0206287" +dependencies = [ + "frunk_core", + "frunk_derives", + "frunk_proc_macros", +] + +[[package]] +name = "frunk_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af2469fab0bd07e64ccf0ad57a1438f63160c69b2e57f04a439653d68eb558d6" + +[[package]] +name = "frunk_derives" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fa992f1656e1707946bbba340ad244f0814009ef8c0118eb7b658395f19a2e" +dependencies = [ + "frunk_proc_macro_helpers", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macro_helpers" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35b54add839292b743aeda6ebedbd8b11e93404f902c56223e51b9ec18a13d2c" +dependencies = [ + "frunk_core", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macros" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71b85a1d4a9a6b300b41c05e8e13ef2feca03e0334127f29eca9506a7fe13a93" +dependencies = [ + "frunk_core", + "frunk_proc_macro_helpers", + "quote", + "syn", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -840,6 +892,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "frunk", "futures-util", "hmac", "html-escape", diff --git a/Cargo.toml b/Cargo.toml index d7682651..ea96389c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,6 +95,7 @@ axum-server = { version = "0.6.0", features = ["tls-rustls"] } base64 = "0.22.1" bytes = "1.6.0" clap = { version = "4.5.4", default-features = false, features = ["std", "derive", "help", "usage", "error-context", "string", "wrap_help"] } +frunk = "0.4.2" futures-util = { version = "0.3.30", default-features = false } hmac = "0.12.1" html-escape = "0.2.13" diff --git a/src/database.rs b/src/database.rs index 223d807b..d178eee1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -28,6 +28,7 @@ use crate::{ pub(crate) mod abstraction; pub(crate) mod key_value; +mod map; use abstraction::{KeyValueDatabaseEngine, KvTree}; diff --git a/src/database/map.rs b/src/database/map.rs new file mode 100644 index 00000000..a7544cf8 --- /dev/null +++ b/src/database/map.rs @@ -0,0 +1,178 @@ +//! A high-level strongly-typed abstraction over key-value stores + +#![warn(missing_docs, clippy::missing_docs_in_private_items)] + +use std::{ + any::TypeId, + borrow::{Borrow, Cow}, + error::Error, +}; + +use frunk::{HCons, HNil}; + +#[cfg(test)] +mod tests; + +/// Errors that can occur during key-value store operations +// Missing docs are allowed here since that kind of information should be +// encoded in the error messages themselves anyway. +#[allow(clippy::missing_docs_in_private_items, dead_code)] +#[derive(thiserror::Error, Debug)] +pub(crate) enum MapError { + #[cfg(any(feature = "sqlite", feature = "rocksdb"))] + #[error(transparent)] + Database(Box), + + #[error("failed to convert stored value into structured data")] + FromBytes(#[source] Box), +} + +/// A high-level representation of a key-value relation in a key-value store +#[allow(dead_code)] +pub(crate) trait Map { + /// The key type of this relation + type Key: ToBytes + FromBytes; + + /// The value type of this relation + type Value: ToBytes + FromBytes; + + /// Load a value based on its corresponding key + async fn get(&self, key: &K) -> Result, MapError> + where + Self::Key: Borrow, + K: ToBytes + ?Sized; + + /// Insert or update a key-value pair + async fn set(&self, key: &K, value: &V) -> Result<(), MapError> + where + Self::Key: Borrow, + Self::Value: Borrow, + K: ToBytes + ?Sized, + V: ToBytes + ?Sized; + + /// Remove a key-value pair by its key + /// + /// It is not an error to remove a key-value pair that is not present in the + /// store. + async fn del(&self, key: &K) -> Result<(), MapError> + where + Self::Key: Borrow, + K: ToBytes + ?Sized; +} + +/// Convert `Self` into bytes for storage in a key-value store +/// +/// Implementations on types other than `HList`s must not contain `0xFF` bytes +/// in their serialized form. +/// +/// [`FromBytes`] must be the exact inverse of this operation. +#[allow(dead_code)] +pub(crate) trait ToBytes { + /// Perform the conversion + fn to_bytes(&self) -> Cow<'_, [u8]>; +} + +impl ToBytes for () { + fn to_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&[]) + } +} + +impl ToBytes for HNil { + fn to_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&[]) + } +} + +impl ToBytes for HCons +where + H: ToBytes, + T: ToBytes + 'static, +{ + fn to_bytes(&self) -> Cow<'_, [u8]> { + let buf = self.head.to_bytes(); + + if TypeId::of::() == TypeId::of::() { + buf + } else { + let mut buf = buf.into_owned(); + buf.push(0xFF); + buf.extend_from_slice(self.tail.to_bytes().as_ref()); + Cow::Owned(buf) + } + } +} + +impl ToBytes for String { + fn to_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(self.as_bytes()) + } +} + +/// Convert from bytes stored in a key-value store into structured data +/// +/// This should generally only be implemented by owned types. +/// +/// [`ToBytes`] must be the exact inverse of this operation. +#[allow(dead_code)] +pub(crate) trait FromBytes +where + Self: Sized, +{ + /// Perform the conversion + fn from_bytes(bytes: Vec) -> Result>; +} + +impl FromBytes for () { + fn from_bytes(bytes: Vec) -> Result> { + bytes + .is_empty() + .then_some(()) + .ok_or_else(|| "got bytes when none were expected".into()) + } +} + +impl FromBytes for HNil { + fn from_bytes(bytes: Vec) -> Result> { + bytes + .is_empty() + .then_some(HNil) + .ok_or_else(|| "got bytes when none were expected".into()) + } +} + +impl FromBytes for HCons +where + H: FromBytes, + T: FromBytes + 'static, +{ + fn from_bytes(bytes: Vec) -> Result> { + let (head, tail) = if TypeId::of::() == TypeId::of::() { + // There is no spoon. I mean, tail. + (bytes, Vec::new()) + } else { + let boundary = bytes + .iter() + .copied() + .position(|x| x == 0xFF) + .ok_or("map entry is missing a boundary")?; + + // Don't include the boundary in the head or tail + let head = &bytes[..boundary]; + let tail = &bytes[boundary + 1..]; + + (head.to_owned(), tail.to_owned()) + }; + + Ok(HCons { + head: H::from_bytes(head)?, + tail: T::from_bytes(tail)?, + }) + } +} + +impl FromBytes for String { + fn from_bytes(bytes: Vec) -> Result> { + String::from_utf8(bytes).map_err(Into::into) + } +} diff --git a/src/database/map/tests.rs b/src/database/map/tests.rs new file mode 100644 index 00000000..680b5ee3 --- /dev/null +++ b/src/database/map/tests.rs @@ -0,0 +1,117 @@ +use std::{ + borrow::Borrow, collections::BTreeMap, marker::PhantomData, sync::RwLock, +}; + +use frunk::{hlist, HList}; + +use super::{FromBytes, Map, MapError, ToBytes}; + +mod conversions; + +struct TestMap { + storage: RwLock, Vec>>, + types: PhantomData<(K, V)>, +} + +impl TestMap { + fn new() -> Self { + Self { + storage: RwLock::new(BTreeMap::new()), + types: PhantomData, + } + } +} + +impl Map for TestMap +where + K: ToBytes + FromBytes, + V: ToBytes + FromBytes, +{ + type Key = K; + type Value = V; + + async fn get(&self, key: &KB) -> Result, MapError> + where + Self::Key: Borrow, + KB: ToBytes + ?Sized, + { + self.storage + .read() + .expect("lock should not be poisoned") + .get(key.borrow().to_bytes().as_ref()) + .map(|v| { + Self::Value::from_bytes(v.to_owned()) + .map_err(MapError::FromBytes) + }) + .transpose() + } + + async fn set(&self, key: &KB, value: &VB) -> Result<(), MapError> + where + Self::Key: Borrow, + Self::Value: Borrow, + KB: ToBytes + ?Sized, + VB: ToBytes + ?Sized, + { + self.storage.write().expect("lock should not be poisoned").insert( + key.borrow().to_bytes().into_owned(), + value.borrow().to_bytes().into_owned(), + ); + + Ok(()) + } + + async fn del(&self, key: &KB) -> Result<(), MapError> + where + Self::Key: Borrow, + KB: ToBytes + ?Sized, + { + self.storage + .write() + .expect("lock should not be poisoned") + .remove(key.borrow().to_bytes().as_ref()); + + Ok(()) + } +} + +#[tokio::test] +async fn string_to_string() { + let test_map = TestMap::::new(); + + let key = "hello".to_owned(); + let value = "world".to_owned(); + + test_map.set(&key, &value).await.expect("insertion should succed"); + + let actual_value = test_map.get(&key).await.expect("lookup should succeed"); + + assert_eq!(Some(value), actual_value); + + test_map.del(&key).await.expect("deletion should succeed"); + + let actual_value = test_map.get(&key).await.expect("lookup should succeed"); + + assert_eq!(None, actual_value); +} + +#[tokio::test] +async fn hlist_to_hlist() { + let test_map = + TestMap::::new(); + + let key = hlist!["hello".to_owned(), "world".to_owned()]; + let value = hlist!["test".to_owned(), "suite".to_owned()]; + + test_map.set(&key, &value).await.expect("insertion should succed"); + + let actual_value = test_map.get(&key).await.expect("lookup should succeed"); + + assert_eq!(Some(value), actual_value); + + test_map.del(&key).await.expect("deletion should succeed"); + + let actual_value = test_map.get(&key).await.expect("lookup should succeed"); + + assert_eq!(None, actual_value); +} diff --git a/src/database/map/tests/conversions.rs b/src/database/map/tests/conversions.rs new file mode 100644 index 00000000..19469ffe --- /dev/null +++ b/src/database/map/tests/conversions.rs @@ -0,0 +1,105 @@ +use frunk::{hlist, HList}; + +use super::super::{FromBytes, ToBytes}; + +#[test] +pub(crate) fn serialize_hlist_0() { + let expected: &[u8] = &[]; + + let actual = hlist![]; + let actual_bytes = actual.to_bytes(); + + assert_eq!(expected, actual_bytes.as_ref()); +} + +#[test] +pub(crate) fn serialize_hlist_1() { + let expected = + [b"hello"].into_iter().flatten().copied().collect::>(); + + let actual = hlist!["hello".to_owned()]; + let actual_bytes = actual.to_bytes(); + + assert_eq!(expected.as_slice(), actual_bytes.as_ref()); +} + +#[test] +pub(crate) fn serialize_hlist_2() { + let expected = [b"hello", [0xFF].as_slice(), b"world"] + .into_iter() + .flatten() + .copied() + .collect::>(); + + let actual = hlist!["hello".to_owned(), "world".to_owned()]; + let actual_bytes = actual.to_bytes(); + + assert_eq!(expected.as_slice(), actual_bytes.as_ref()); +} + +#[test] +pub(crate) fn serialize_hlist_3() { + let expected = + [b"what's", [0xFF].as_slice(), b"up", [0xFF].as_slice(), b"world"] + .into_iter() + .flatten() + .copied() + .collect::>(); + + let actual = + hlist!["what's".to_owned(), "up".to_owned(), "world".to_owned()]; + let actual_bytes = actual.to_bytes(); + + assert_eq!(expected.as_slice(), actual_bytes.as_ref()); +} + +#[test] +pub(crate) fn deserialize_hlist_0() { + let actual = ::from_bytes(Vec::new()) + .expect("should be able to deserialize"); + + assert_eq!(hlist![], actual); +} + +#[test] +pub(crate) fn deserialize_hlist_1() { + let serialized = + [b"hello"].into_iter().flatten().copied().collect::>(); + + let actual = ::from_bytes(serialized) + .expect("should be able to deserialize"); + + assert_eq!(hlist!["hello".to_owned()], actual); +} + +#[test] +pub(crate) fn deserialize_hlist_2() { + let serialized = [b"hello", [0xFF].as_slice(), b"world"] + .into_iter() + .flatten() + .copied() + .collect::>(); + + let actual = ::from_bytes(serialized) + .expect("should be able to deserialize"); + + assert_eq!(hlist!["hello".to_owned(), "world".to_owned()], actual); +} + +#[test] +pub(crate) fn deserialize_hlist_3() { + let serialized = + [b"what's", [0xFF].as_slice(), b"up", [0xFF].as_slice(), b"world"] + .into_iter() + .flatten() + .copied() + .collect::>(); + + let actual = ::from_bytes(serialized) + .expect("should be able to deserialize"); + + assert_eq!( + hlist!["what's".to_owned(), "up".to_owned(), "world".to_owned()], + actual + ); +}