use std::{ cell::RefCell, future::Future, path::{Path, PathBuf}, pin::Pin, sync::Arc, }; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use thread_local::ThreadLocal; use tracing::debug; use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{database::Config, Result}; thread_local! { static READ_CONNECTION: RefCell> = const { RefCell::new(None) }; static READ_CONNECTION_ITERATOR: RefCell> = const { RefCell::new(None) }; } struct PreparedStatementIterator<'a> { pub(crate) iterator: Box + 'a>, pub(crate) _statement_ref: AliasableBox>, } impl Iterator for PreparedStatementIterator<'_> { type Item = TupleOfBytes; fn next(&mut self) -> Option { self.iterator.next() } } struct AliasableBox(*mut T); impl Drop for AliasableBox { fn drop(&mut self) { // SAFETY: This is cursed and relies on non-local reasoning. // // In order for this to be safe: // // * All aliased references to this value must have been dropped first, // for example by coming after its referrers in struct fields, because // struct fields are automatically dropped in order from top to bottom // in the absence of an explicit Drop impl. Otherwise, the referrers // may read into deallocated memory. // * This type must not be copyable or cloneable. Otherwise, double-free // can occur. // // These conditions are met, but again, note that changing safe code in // this module can result in unsoundness if any of these constraints are // violated. unsafe { drop(Box::from_raw(self.0)) } } } pub(crate) struct Engine { writer: Mutex, read_conn_tls: ThreadLocal, read_iterator_conn_tls: ThreadLocal, path: PathBuf, cache_size_per_thread: u32, } impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(path)?; conn.pragma_update(Some(Main), "page_size", 2048)?; conn.pragma_update(Some(Main), "journal_mode", "WAL")?; conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; conn.pragma_update( Some(Main), "cache_size", -i64::from(cache_size_kb), )?; conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; Ok(conn) } fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } fn read_lock(&self) -> &Connection { self.read_conn_tls.get_or(|| { Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap() }) } fn read_lock_iterator(&self) -> &Connection { self.read_iterator_conn_tls.get_or(|| { Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap() }) } pub(crate) fn flush_wal(self: &Arc) -> Result<()> { self.write_lock().pragma_update( Some(Main), "wal_checkpoint", "RESTART", )?; Ok(()) } } impl KeyValueDatabaseEngine for Arc { fn open(config: &Config) -> Result { let path = Path::new(&config.database_path).join(format!( "{}.db", if config.conduit_compat { "conduit" } else { "grapevine" } )); // calculates cache-size per permanent connection // 1. convert MB to KiB // 2. divide by permanent connections + permanent iter connections + // write connection // 3. round down to nearest integer #[allow( clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_precision_loss, clippy::cast_sign_loss )] let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get() as f64 * 2.0) + 1.0)) as u32; let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); let arc = Arc::new(Engine { writer, read_conn_tls: ThreadLocal::new(), read_iterator_conn_tls: ThreadLocal::new(), path, cache_size_per_thread, }); Ok(arc) } fn open_tree(&self, name: &str) -> Result> { self.write_lock().execute( &format!( "CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY \ KEY, \"value\" BLOB NOT NULL )" ), [], )?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), name: name.to_owned(), watchers: Watchers::default(), })) } fn cleanup(&self) -> Result<()> { self.flush_wal() } } pub(crate) struct SqliteTable { engine: Arc, name: String, watchers: Watchers, } type TupleOfBytes = (Vec, Vec); impl SqliteTable { fn get_with_guard( &self, guard: &Connection, key: &[u8], ) -> Result>> { Ok(guard .prepare( format!("SELECT value FROM {} WHERE key = ?", self.name) .as_str(), )? .query_row([key], |row| row.get(0)) .optional()?) } fn insert_with_guard( &self, guard: &Connection, key: &[u8], value: &[u8], ) -> Result<()> { guard.execute( format!( "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name ) .as_str(), [key, value], )?; Ok(()) } pub(crate) fn iter_with_guard<'a>( &'a self, guard: &'a Connection, ) -> Box + 'a> { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} ORDER BY key ASC", &self.name )) .unwrap(), )); let statement_ref = AliasableBox(statement); let iterator = Box::new( statement .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } } impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) } fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); self.insert_with_guard(&guard, key, value)?; drop(guard); self.watchers.wake(key); Ok(()) } fn insert_batch( &self, iter: &mut dyn Iterator, Vec)>, ) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for (key, value) in iter { self.insert_with_guard(&guard, &key, &value)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } fn increment_batch( &self, iter: &mut dyn Iterator>, ) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for key in iter { let old = self.get_with_guard(&guard, &key)?; let new = crate::utils::increment(old.as_deref()); self.insert_with_guard(&guard, &key, &new)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); guard.execute( format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key], )?; Ok(()) } fn iter<'a>(&'a self) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); self.iter_with_guard(guard) } fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, ) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); // TODO change interface? let from = from.to_vec(); if backwards { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key <= ? ORDER BY \ key DESC", &self.name )) .unwrap(), )); let statement_ref = AliasableBox(statement); let iterator = Box::new( statement .query_map([from], |row| { Ok((row.get_unwrap(0), row.get_unwrap(1))) }) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } else { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key >= ? ORDER BY \ key ASC", &self.name )) .unwrap(), )); let statement_ref = AliasableBox(statement); let iterator = Box::new( statement .query_map([from], |row| { Ok((row.get_unwrap(0), row.get_unwrap(1))) }) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } } fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); let old = self.get_with_guard(&guard, key)?; let new = crate::utils::increment(old.as_deref()); self.insert_with_guard(&guard, key, &new)?; Ok(new) } fn scan_prefix<'a>( &'a self, prefix: Vec, ) -> Box + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), ) } fn watch_prefix<'a>( &'a self, prefix: &[u8], ) -> Pin + Send + 'a>> { self.watchers.watch(prefix) } fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine .write_lock() .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } }