use std::{
    borrow::Cow,
    cmp, fmt,
    fmt::Write,
    io,
    path::{Component, Path, PathBuf},
    str::FromStr,
    time::{SystemTime, UNIX_EPOCH},
};

use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier};
use cmp::Ordering;
use rand::{prelude::*, rngs::OsRng};
use ring::digest;
use ruma::{
    api::client::error::ErrorKind, canonical_json::try_from_json_map,
    CanonicalJsonError, CanonicalJsonObject, MxcUri, MxcUriError, OwnedMxcUri,
};
use tokio::fs;

use crate::{Error, Result};

pub(crate) mod error;
pub(crate) mod on_demand_hashmap;
pub(crate) mod proxy_protocol;
pub(crate) mod room_version;

// Hopefully we have a better chat protocol in 530 years
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
pub(crate) fn millis_since_unix_epoch() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("time is valid")
        .as_millis() as u64
}

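/// Increments a `u64` counter stored as 8 big-endian bytes, returning the new
/// value in the same encoding. Missing or malformed input restarts the counter
/// at 1.
///
/// Rough usage sketch (`ignore`d, not compiled as a doctest):
///
/// ```ignore
/// assert_eq!(increment(None), 1_u64.to_be_bytes());
///
/// let old = 41_u64.to_be_bytes();
/// assert_eq!(increment(Some(&old[..])), 42_u64.to_be_bytes());
/// ```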
#[cfg(any(feature = "rocksdb", feature = "sqlite"))]
pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> {
    let number = match old.map(TryInto::try_into) {
        Some(Ok(bytes)) => {
            let number = u64::from_be_bytes(bytes);
            number + 1
        }
        // Start at one, since 0 should return the first event in the db
        _ => 1,
    };

    number.to_be_bytes().to_vec()
}

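/// Generates a new Ed25519 keypair blob: a random 8-character version string,
/// a `0xFF` separator byte, then the serialized keypair produced by
/// [`ruma::signatures::Ed25519KeyPair::generate`].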
pub(crate) fn generate_keypair() -> Vec<u8> {
    let mut value = random_string(8).as_bytes().to_vec();
    value.push(0xFF);
    value.extend_from_slice(
        &ruma::signatures::Ed25519KeyPair::generate()
            .expect("Ed25519KeyPair generation always works (?)"),
    );
    value
}

/// Parses the bytes into a u64.
pub(crate) fn u64_from_bytes(
    bytes: &[u8],
) -> Result<u64, std::array::TryFromSliceError> {
    let array: [u8; 8] = bytes.try_into()?;
    Ok(u64::from_be_bytes(array))
}

/// Parses the bytes into a string.
pub(crate) fn string_from_bytes(
    bytes: &[u8],
) -> Result<String, std::string::FromUtf8Error> {
    String::from_utf8(bytes.to_vec())
}

pub(crate) fn random_string(length: usize) -> String {
    thread_rng()
        .sample_iter(&rand::distributions::Alphanumeric)
        .take(length)
        .map(char::from)
        .collect()
}

/// Hash the given password
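///
/// Rough usage sketch together with [`verify_password`] (`ignore`d, not
/// compiled as a doctest):
///
/// ```ignore
/// let hash = hash_password("hunter2").expect("hashing should succeed");
/// assert!(verify_password(hash.as_str(), "hunter2"));
/// assert!(!verify_password(hash.as_str(), "wrong password"));
/// ```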
pub(crate) fn hash_password<B>(
    password: B,
) -> Result<password_hash::PasswordHashString, password_hash::Error>
where
    B: AsRef<[u8]>,
{
    Argon2::default()
        .hash_password(
            password.as_ref(),
            &password_hash::SaltString::generate(&mut OsRng),
        )
        .map(|x| x.serialize())
}

/// Compare a password to a hash
///
/// Returns `true` if the password matches the hash, `false` otherwise.
pub(crate) fn verify_password<S, B>(hash: S, password: B) -> bool
where
    S: AsRef<str>,
    B: AsRef<[u8]>,
{
    let Ok(hash) = password_hash::PasswordHash::new(hash.as_ref()) else {
        return false;
    };

    Argon2::default().verify_password(password.as_ref(), &hash).is_ok()
}

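/// Hashes the given keys with SHA-256 after joining them with `0xFF`
/// separator bytes.
///
/// Rough usage sketch (`ignore`d, not compiled as a doctest):
///
/// ```ignore
/// let hash = calculate_hash([b"alice".as_slice(), b"bob".as_slice()]);
/// // SHA-256 output is always 32 bytes.
/// assert_eq!(hash.len(), 32);
/// ```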
#[tracing::instrument(skip(keys))]
pub(crate) fn calculate_hash<'a, I, T>(keys: I) -> Vec<u8>
where
    I: IntoIterator<Item = T>,
    T: AsRef<[u8]>,
{
    let mut bytes = Vec::new();
    for (i, key) in keys.into_iter().enumerate() {
        if i != 0 {
            bytes.push(0xFF);
        }
        bytes.extend_from_slice(key.as_ref());
    }
    let hash = digest::digest(&digest::SHA256, &bytes);
    hash.as_ref().to_owned()
}

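/// Intersects several iterators that are sorted according to `check_order`,
/// yielding only the elements present in every iterator. Returns `None` if
/// `iterators` is empty.
///
/// Rough usage sketch with already-sorted inputs (`ignore`d, not compiled as
/// a doctest):
///
/// ```ignore
/// let common = common_elements(
///     vec![vec![1, 3, 5], vec![3, 4, 5], vec![2, 3, 5]]
///         .into_iter()
///         .map(Vec::into_iter),
///     Ord::cmp,
/// )
/// .expect("at least one iterator was given");
/// assert_eq!(common.collect::<Vec<_>>(), [3, 5]);
/// ```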
pub(crate) fn common_elements<I, T, F>(
    mut iterators: I,
    check_order: F,
) -> Option<impl Iterator<Item = T>>
where
    I: Iterator,
    I::Item: Iterator<Item = T>,
    F: Fn(&T, &T) -> Ordering,
{
    let first_iterator = iterators.next()?;
    let mut other_iterators =
        iterators.map(Iterator::peekable).collect::<Vec<_>>();

    Some(first_iterator.filter(move |target| {
        other_iterators.iter_mut().all(|it| {
            while let Some(element) = it.peek() {
                match check_order(element, target) {
                    // We went too far
                    Ordering::Greater => return false,
                    // Element is in both iters
                    Ordering::Equal => return true,
                    // Keep searching
                    Ordering::Less => {
                        it.next();
                    }
                }
            }
            false
        })
    }))
}

/// Fallible conversion from any value that implements `Serialize` to a
/// `CanonicalJsonObject`.
///
/// `value` must serialize to a `serde_json::Value::Object`.
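///
/// Rough usage sketch (`ignore`d, not compiled as a doctest); `Foo` is a
/// made-up type for illustration:
///
/// ```ignore
/// #[derive(serde::Serialize)]
/// struct Foo {
///     bar: u64,
/// }
///
/// let object = to_canonical_object(Foo { bar: 1 })
///     .expect("Foo serializes to an object");
/// assert!(object.contains_key("bar"));
///
/// // Anything that does not serialize to a JSON object is rejected.
/// assert!(to_canonical_object("just a string").is_err());
/// ```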
pub(crate) fn to_canonical_object<T: serde::Serialize>(
    value: T,
) -> Result<CanonicalJsonObject, CanonicalJsonError> {
    use serde::ser::Error;

    match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? {
        serde_json::Value::Object(map) => try_from_json_map(map),
        _ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom(
            "Value must be an object",
        ))),
    }
}

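/// Deserializes a string by delegating to the target type's [`FromStr`]
/// implementation.
///
/// Rough usage sketch as a serde field attribute (`ignore`d, not compiled as
/// a doctest); the `Config` struct and its field are made up for
/// illustration, and the attribute would need the full module path to this
/// function:
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct Config {
///     // Parsed from a string like "127.0.0.1" via `IpAddr::from_str`.
///     #[serde(deserialize_with = "deserialize_from_str")]
///     listen_addr: std::net::IpAddr,
/// }
/// ```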
pub(crate) fn deserialize_from_str<
    'de,
    D: serde::de::Deserializer<'de>,
    T: FromStr<Err = E>,
    E: fmt::Display,
>(
    deserializer: D,
) -> Result<T, D::Error> {
    struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
    impl<T: FromStr<Err = Err>, Err: fmt::Display> serde::de::Visitor<'_>
        for Visitor<T, Err>
    {
        type Value = T;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(formatter, "a parsable string")
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            v.parse().map_err(serde::de::Error::custom)
        }
    }
    deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}

/// Debug-formats the given slice, but only up to the first `max_len` elements.
/// Any further elements are replaced by an ellipsis.
///
/// See also [`debug_slice_truncated()`].
pub(crate) struct TruncatedDebugSlice<'a, T> {
    inner: &'a [T],
    max_len: usize,
}

impl<T: fmt::Debug> fmt::Debug for TruncatedDebugSlice<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.inner.len() <= self.max_len {
            write!(f, "{:?}", self.inner)
        } else {
            f.debug_list()
                .entries(&self.inner[..self.max_len])
                .entry(&"...")
                .finish()
        }
    }
}

/// See [`TruncatedDebugSlice`]. Useful for `#[instrument]`:
///
/// ```ignore
/// #[tracing::instrument(fields(
///     foos = debug_slice_truncated(foos, N)
/// ))]
/// ```
pub(crate) fn debug_slice_truncated<T: fmt::Debug>(
    slice: &[T],
    max_len: usize,
) -> tracing::field::DebugValue<TruncatedDebugSlice<'_, T>> {
    tracing::field::debug(TruncatedDebugSlice {
        inner: slice,
        max_len,
    })
}

/// Truncates a string to an approximate maximum length, replacing any extra
/// text with an ellipsis.
///
/// Only to be used for informational purposes, exact semantics are unspecified.
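///
/// Illustrative sketch, mirroring the unit tests below (`ignore`d, not
/// compiled as a doctest):
///
/// ```ignore
/// assert_eq!(dbg_truncate_str("short", 10), "short");
/// assert_eq!(dbg_truncate_str("very long string", 10), "very long ...");
/// ```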
pub(crate) fn dbg_truncate_str(s: &str, mut max_len: usize) -> Cow<'_, str> {
    while max_len < s.len() && !s.is_char_boundary(max_len) {
        max_len += 1;
    }

    if s.len() <= max_len {
        s.into()
    } else {
        #[allow(clippy::string_slice)] // we checked it's at a char boundary
        format!("{}...", &s[..max_len]).into()
    }
}

/// Data that makes up an `mxc://` URL.
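///
/// Rough usage sketch (`ignore`d, not compiled as a doctest); the server name
/// and media ID are arbitrary example values:
///
/// ```ignore
/// let server_name: &ruma::ServerName = "example.com".try_into().unwrap();
/// let mxc = MxcData::new(server_name, "abc123").unwrap();
/// assert_eq!(mxc.to_string(), "mxc://example.com/abc123");
///
/// // Media IDs are limited to `[0-9a-zA-Z_-]`.
/// assert!(MxcData::new(server_name, "not/allowed").is_err());
/// ```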
#[derive(Debug, Clone)]
pub(crate) struct MxcData<'a> {
    pub(crate) server_name: &'a ruma::ServerName,
    pub(crate) media_id: &'a str,
}

impl<'a> MxcData<'a> {
    pub(crate) fn new(
        server_name: &'a ruma::ServerName,
        media_id: &'a str,
    ) -> Result<Self> {
        if !media_id.bytes().all(|b| {
            matches!(b,
                b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'-' | b'_'
            )
        }) {
            return Err(Error::BadRequest(
                ErrorKind::InvalidParam,
                "Invalid MXC media id",
            ));
        }

        Ok(Self {
            server_name,
            media_id,
        })
    }
}

impl fmt::Display for MxcData<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "mxc://{}/{}", self.server_name, self.media_id)
    }
}

impl From<MxcData<'_>> for OwnedMxcUri {
    fn from(value: MxcData<'_>) -> Self {
        value.to_string().into()
    }
}

impl<'a> TryFrom<&'a MxcUri> for MxcData<'a> {
    type Error = MxcUriError;

    fn try_from(value: &'a MxcUri) -> Result<Self, Self::Error> {
        Ok(Self::new(value.server_name()?, value.media_id()?)
            .expect("validated MxcUri should always be valid MxcData"))
    }
}

fn curlify_args<T>(req: &http::Request<T>) -> Option<Vec<String>> {
    let mut args =
        vec!["curl".to_owned(), "-X".to_owned(), req.method().to_string()];

    for (name, val) in req.headers() {
        args.extend([
            "-H".to_owned(),
            format!("{name}: {}", val.to_str().ok()?),
        ]);
    }

    let fix_uri = || {
        if req.uri().scheme().is_some() {
            return None;
        }
        if req.uri().authority().is_some() {
            return None;
        }
        let mut parts = req.uri().clone().into_parts();

        parts.scheme = Some(http::uri::Scheme::HTTPS);

        let host =
            req.headers().get(http::header::HOST)?.to_str().ok()?.to_owned();
        parts.authority =
            Some(http::uri::Authority::from_maybe_shared(host).ok()?);

        http::uri::Uri::from_parts(parts).ok()
    };

    let uri = if let Some(new_uri) = fix_uri() {
        Cow::Owned(new_uri)
    } else {
        Cow::Borrowed(req.uri())
    };

    args.push(uri.to_string());

    Some(args)
}

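/// Renders an [`http::Request`] as a roughly equivalent `curl` command line,
/// mainly for logging and debugging.
///
/// Rough usage sketch (`ignore`d, not compiled as a doctest); the request is
/// a made-up example:
///
/// ```ignore
/// let req = http::Request::builder()
///     .method("GET")
///     .uri("https://example.com/_matrix/federation/v1/version")
///     .body(())
///     .unwrap();
///
/// assert_eq!(
///     curlify(&req).unwrap(),
///     "curl -X GET 'https://example.com/_matrix/federation/v1/version'",
/// );
/// ```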
pub(crate) fn curlify<T>(req: &http::Request<T>) -> Option<String> {
    let args = curlify_args(req)?;

    Some(
        args.into_iter()
            .map(|arg| {
                if arg.chars().all(|c| {
                    c.is_alphanumeric() || ['-', '_', ':', '/'].contains(&c)
                }) {
                    arg
                } else {
                    format!("'{}'", arg.replace('\'', "\\'"))
                }
            })
            .collect::<Vec<_>>()
            .join(" "),
    )
}

/// Format a u8 slice as an uppercase hex string
///
/// The output does not contain a leading `0x` nor any non-hex characters (e.g.
/// whitespace or commas do not appear in the output).
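///
/// Illustrative sketch, mirroring the unit tests below (`ignore`d, not
/// compiled as a doctest):
///
/// ```ignore
/// assert_eq!(u8_slice_to_hex(&[0xFF]), "FF");
/// assert_eq!(u8_slice_to_hex(&[1, 2, 3, 4]), "01020304");
/// ```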
pub(crate) fn u8_slice_to_hex(slice: &[u8]) -> String {
    slice.iter().fold(String::new(), |mut acc, x| {
        write!(acc, "{x:02X}").expect("in-memory write should succeed");
        acc
    })
}

/// Canonicalize a path where some components may not exist yet.
///
/// It's assumed that non-existent components will be created as
/// directories. This should match the result of [`fs::canonicalize`]
/// _after_ calling [`fs::create_dir_all`] on `path`.
///
/// This is useful for checking for potential overlap between paths that have
/// not been fully created yet.
#[allow(dead_code)]
pub(crate) async fn partial_canonicalize(path: &Path) -> io::Result<PathBuf> {
    let mut ret = std::env::current_dir()?;

    let mut base_path = Cow::Borrowed(path);
    let mut components = base_path.components();

    while let Some(component) = components.next() {
        match component {
            Component::Prefix(_) | Component::RootDir => {
                let component_path: &Path = component.as_ref();
                component_path.clone_into(&mut ret);
            }
            Component::CurDir => (),
            Component::ParentDir => {
                ret.pop();
            }
            Component::Normal(p) => {
                let component_path = ret.join(p);
                match fs::symlink_metadata(&component_path).await {
                    // path is a symlink
                    Ok(metadata) if metadata.is_symlink() => {
                        let destination =
                            fs::read_link(&component_path).await?;
                        // iterate over the symlink destination components
                        // before continuing with the original path
                        base_path =
                            Cow::Owned(destination.join(components.as_path()));
                        components = base_path.components();
                    }
                    // path exists, not a symlink
                    Ok(_) => {
                        ret.push(p);
                    }
                    // path does not exist
                    Err(error) if error.kind() == io::ErrorKind::NotFound => {
                        // assume a directory will be created here
                        ret.push(p);
                    }
                    Err(error) => return Err(error),
                }
            }
        }
    }

    Ok(ret)
}

#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use tokio::fs;

    use crate::utils::{
        dbg_truncate_str, partial_canonicalize, u8_slice_to_hex,
    };

    #[test]
    fn test_truncate_str() {
        assert_eq!(dbg_truncate_str("short", 10), "short");
        assert_eq!(dbg_truncate_str("very long string", 10), "very long ...");
        assert_eq!(dbg_truncate_str("no info, only dots", 0), "...");
        assert_eq!(dbg_truncate_str("", 0), "");
        assert_eq!(dbg_truncate_str("unicöde", 5), "unicö...");
        let ok_hand = "👌🏽";
        assert_eq!(dbg_truncate_str(ok_hand, 1), "👌...");
        assert_eq!(dbg_truncate_str(ok_hand, ok_hand.len() - 1), "👌🏽");
        assert_eq!(dbg_truncate_str(ok_hand, ok_hand.len()), "👌🏽");
    }

    #[test]
    fn test_slice_to_hex() {
        assert_eq!(u8_slice_to_hex(&[]), "");
        assert_eq!(u8_slice_to_hex(&[0]), "00");
        assert_eq!(u8_slice_to_hex(&[0xFF]), "FF");
        assert_eq!(u8_slice_to_hex(&[1, 2, 3, 4]), "01020304");
        assert_eq!(
            u8_slice_to_hex(&[0x42; 100]),
            "4242424242424242424242424242424242424242\
             4242424242424242424242424242424242424242\
             4242424242424242424242424242424242424242\
             4242424242424242424242424242424242424242\
             4242424242424242424242424242424242424242"
        );
    }

    #[tokio::test]
    async fn test_partial_canonicalize() {
        let tmp_dir =
            TempDir::with_prefix("test_partial_canonicalize").unwrap();
        let path = tmp_dir.path();

        fs::create_dir(&path.join("dir")).await.unwrap();
        fs::symlink(path.join("dir"), path.join("absolute-link-to-dir"))
            .await
            .unwrap();
        fs::symlink("./dir", path.join("relative-link-to-dir")).await.unwrap();

        assert_eq!(partial_canonicalize(path).await.unwrap(), path);
        assert_eq!(partial_canonicalize(&path.join("./")).await.unwrap(), path);
        assert_eq!(
            partial_canonicalize(&path.join("dir/..")).await.unwrap(),
            path
        );
        assert_eq!(
            partial_canonicalize(&path.join("absolute-link-to-dir"))
                .await
                .unwrap(),
            path.join("dir")
        );
        assert_eq!(
            partial_canonicalize(&path.join("relative-link-to-dir"))
                .await
                .unwrap(),
            path.join("dir")
        );
        assert_eq!(
            partial_canonicalize(&path.join("absolute-link-to-dir/new-dir"))
                .await
                .unwrap(),
            path.join("dir/new-dir")
        );
        assert_eq!(
            partial_canonicalize(
                &path.join("absolute-link-to-dir/new-dir/../..")
            )
            .await
            .unwrap(),
            path,
        );

        tmp_dir.close().unwrap();
    }
}