use async_trait::async_trait; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::discovery::{OldVerifyKey, ServerSigningKeys}, signatures::Ed25519KeyPair, DeviceId, ServerName, UserId, }; use crate::{ database::KeyValueDatabase, observability::prelude::*, service::{self, globals::SigningKeys}, services, utils, Error, Result, }; pub(crate) const COUNTER: &[u8] = b"c"; #[async_trait] impl service::globals::Data for KeyValueDatabase { fn next_count(&self) -> Result { utils::u64_from_bytes(&self.global.increment(COUNTER)?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) } fn current_count(&self) -> Result { self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { utils::u64_from_bytes(&bytes) .map_err(|_| Error::bad_database("Count has invalid bytes.")) }) } #[t::instrument(skip(self))] async fn watch( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); userid_prefix.push(0xFF); let mut userdeviceid_prefix = userid_prefix.clone(); userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.push(0xFF); let mut futures = FuturesUnordered::new(); // Return when *any* user changed their key // TODO: only send for user they share a room with futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push( self.userroomid_notificationcount.watch_prefix(&userid_prefix), ); futures .push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in for room_id in services() .rooms .state_cache .rooms_joined(user_id) .filter_map(Result::ok) { let short_roomid = services() .rooms .short .get_shortroomid(&room_id) .ok() .flatten() .expect("room exists") .get() .to_be_bytes() .to_vec(); let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); roomid_prefix.push(0xFF); // PDUs futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); // EDUs futures.push(Box::pin(async move { let _result = services() .rooms .edus .typing .wait_for_update(&room_id) .await; })); futures.push( self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix), ); // Key changes futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); // Room account data let mut roomuser_prefix = roomid_prefix.clone(); roomuser_prefix.extend_from_slice(&userid_prefix); futures.push( self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix), ); } let mut globaluserdata_prefix = vec![0xFF]; globaluserdata_prefix.extend_from_slice(&userid_prefix); futures.push( self.roomusertype_roomuserdataid .watch_prefix(&globaluserdata_prefix), ); // More key changes (used when user is not joined to any rooms) futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); // One time keys futures .push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(Box::pin(services().globals.rotate.watch())); // Wait until one of them finds something futures.next().await; Ok(()) } fn cleanup(&self) -> Result<()> { self.db.cleanup() } fn load_keypair(&self) -> Result { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { let keypair = utils::generate_keypair(); self.global.insert(b"keypair", &keypair)?; Ok::<_, Error>(keypair) }, |s| Ok(s.clone()), )?; let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); utils::string_from_bytes( // 1. version parts.next().expect("splitn always returns at least one element"), ) .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .and_then(|version| { // 2. key parts .next() .ok_or_else(|| { Error::bad_database("Invalid keypair format in database.") }) .map(|key| (version, key)) }) .and_then(|(version, key)| { Ed25519KeyPair::from_der(key, version).map_err(|_| { Error::bad_database("Private or public keys are invalid.") }) }) } fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } fn add_signing_key_from_trusted_server( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result { let prev_keys = self.server_signingkeys.get(origin.as_bytes())?; Ok( if let Some(mut prev_keys) = prev_keys.and_then(|keys| { serde_json::from_slice::(&keys).ok() }) { let ServerSigningKeys { verify_keys, old_verify_keys, .. } = new_keys; prev_keys.verify_keys.extend(verify_keys); prev_keys.old_verify_keys.extend(old_verify_keys); prev_keys.valid_until_ts = new_keys.valid_until_ts; self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&prev_keys) .expect("serversigningkeys can be serialized"), )?; prev_keys.into() } else { self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&new_keys) .expect("serversigningkeys can be serialized"), )?; new_keys.into() }, ) } fn add_signing_key_from_origin( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result { let prev_keys = self.server_signingkeys.get(origin.as_bytes())?; Ok( if let Some(mut prev_keys) = prev_keys.and_then(|keys| { serde_json::from_slice::(&keys).ok() }) { let ServerSigningKeys { verify_keys, old_verify_keys, .. } = new_keys; // Moving `verify_keys` no longer present to `old_verify_keys` for (key_id, key) in prev_keys.verify_keys { if !verify_keys.contains_key(&key_id) { prev_keys.old_verify_keys.insert( key_id, OldVerifyKey::new( prev_keys.valid_until_ts, key.key, ), ); } } prev_keys.verify_keys = verify_keys; prev_keys.old_verify_keys.extend(old_verify_keys); prev_keys.valid_until_ts = new_keys.valid_until_ts; self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&prev_keys) .expect("serversigningkeys can be serialized"), )?; prev_keys.into() } else { self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&new_keys) .expect("serversigningkeys can be serialized"), )?; new_keys.into() }, ) } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. fn signing_keys_for( &self, origin: &ServerName, ) -> Result> { let signingkeys = self.server_signingkeys.get(origin.as_bytes())?.and_then(|bytes| { serde_json::from_slice::(&bytes).ok() }); Ok(signingkeys) } fn database_version(&self) -> Result { self.global.get(b"version")?.map_or(Ok(0), |version| { utils::u64_from_bytes(&version).map_err(|_| { Error::bad_database("Database version id is invalid.") }) }) } fn bump_database_version(&self, new_version: u64) -> Result<()> { self.global.insert(b"version", &new_version.to_be_bytes())?; Ok(()) } }