diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 552fe4bb..79f44481 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -15,7 +15,12 @@ use ruma::{ }; use tracing::error; -use crate::{service::media::FileMeta, services, utils, Ar, Error, Ra, Result}; +use crate::{ + service::media::FileMeta, + services, + utils::{self, MxcData}, + Ar, Error, Ra, Result, +}; const MXC_LENGTH: usize = 32; @@ -133,16 +138,13 @@ pub(crate) async fn get_media_config_route( pub(crate) async fn create_content_route( body: Ar, ) -> Result> { - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(MXC_LENGTH) - ); + let media_id = utils::random_string(MXC_LENGTH); + let mxc = MxcData::new(services().globals.server_name(), &media_id)?; services() .media .create( - mxc.clone(), + mxc.to_string(), body.filename .clone() .map(|filename| ContentDisposition { @@ -163,18 +165,16 @@ pub(crate) async fn create_content_route( #[allow(deprecated)] // unauthenticated media pub(crate) async fn get_remote_content( - mxc: &str, - server_name: &ruma::ServerName, - media_id: String, + mxc: &MxcData<'_>, ) -> Result { let content_response = services() .sending .send_federation_request( - server_name, + mxc.server_name, legacy_media::get_content::v3::Request { allow_remote: false, - server_name: server_name.to_owned(), - media_id, + server_name: mxc.server_name.to_owned(), + media_id: mxc.media_id.to_owned(), timeout_ms: Duration::from_secs(20), allow_redirect: false, }, @@ -184,7 +184,7 @@ pub(crate) async fn get_remote_content( services() .media .create( - mxc.to_owned(), + mxc.to_string(), content_response.content_disposition.as_ref(), content_response.content_type.as_deref(), &content_response.file, @@ -225,13 +225,13 @@ pub(crate) async fn get_content_route( async fn get_content_route_ruma( body: Ar, ) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); + let mxc = MxcData::new(&body.server_name, &body.media_id)?; if let Some(FileMeta { content_type, file, .. - }) = services().media.get(mxc.clone()).await? + }) = services().media.get(mxc.to_string()).await? { Ok(legacy_media::get_content::v3::Response { file, @@ -245,9 +245,7 @@ async fn get_content_route_ruma( } else if &*body.server_name != services().globals.server_name() && body.allow_remote { - let remote_content_response = - get_remote_content(&mxc, &body.server_name, body.media_id.clone()) - .await?; + let remote_content_response = get_remote_content(&mxc).await?; Ok(legacy_media::get_content::v3::Response { file: remote_content_response.file, content_disposition: Some(content_disposition_for( @@ -288,13 +286,13 @@ pub(crate) async fn get_content_as_filename_route( pub(crate) async fn get_content_as_filename_route_ruma( body: Ar, ) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); + let mxc = MxcData::new(&body.server_name, &body.media_id)?; if let Some(FileMeta { content_type, file, .. - }) = services().media.get(mxc.clone()).await? + }) = services().media.get(mxc.to_string()).await? { Ok(legacy_media::get_content_as_filename::v3::Response { file, @@ -308,9 +306,7 @@ pub(crate) async fn get_content_as_filename_route_ruma( } else if &*body.server_name != services().globals.server_name() && body.allow_remote { - let remote_content_response = - get_remote_content(&mxc, &body.server_name, body.media_id.clone()) - .await?; + let remote_content_response = get_remote_content(&mxc).await?; Ok(legacy_media::get_content_as_filename::v3::Response { content_disposition: Some(content_disposition_for( @@ -366,7 +362,7 @@ pub(crate) async fn get_content_thumbnail_route( async fn get_content_thumbnail_route_ruma( body: Ar, ) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); + let mxc = MxcData::new(&body.server_name, &body.media_id)?; if let Some(FileMeta { content_type, @@ -375,7 +371,7 @@ async fn get_content_thumbnail_route_ruma( }) = services() .media .get_thumbnail( - mxc.clone(), + mxc.to_string(), body.width.try_into().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.") })?, @@ -414,7 +410,7 @@ async fn get_content_thumbnail_route_ruma( services() .media .upload_thumbnail( - mxc, + mxc.to_string(), None, get_thumbnail_response.content_type.as_deref(), body.width.try_into().expect("all UInts are valid u32s"), diff --git a/src/utils.rs b/src/utils.rs index 52956bf6..c1e2d6a6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -13,9 +13,12 @@ use cmp::Ordering; use rand::{prelude::*, rngs::OsRng}; use ring::digest; use ruma::{ - canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, + api::client::error::ErrorKind, canonical_json::try_from_json_map, + CanonicalJsonError, CanonicalJsonObject, MxcUri, MxcUriError, OwnedMxcUri, }; +use crate::{Error, Result}; + // Hopefully we have a better chat protocol in 530 years #[allow(clippy::as_conversions, clippy::cast_possible_truncation)] pub(crate) fn millis_since_unix_epoch() -> u64 { @@ -243,6 +246,57 @@ pub(crate) fn dbg_truncate_str(s: &str, mut max_len: usize) -> Cow<'_, str> { } } +/// Data that makes up an `mxc://` URL. +#[derive(Debug, Clone)] +pub(crate) struct MxcData<'a> { + pub(crate) server_name: &'a ruma::ServerName, + pub(crate) media_id: &'a str, +} + +impl<'a> MxcData<'a> { + pub(crate) fn new( + server_name: &'a ruma::ServerName, + media_id: &'a str, + ) -> Result { + if !media_id.bytes().all(|b| { + matches!(b, + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'-' | b'_' + ) + }) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid MXC media id", + )); + } + + Ok(Self { + server_name, + media_id, + }) + } +} + +impl fmt::Display for MxcData<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "mxc://{}/{}", self.server_name, self.media_id) + } +} + +impl From> for OwnedMxcUri { + fn from(value: MxcData<'_>) -> Self { + value.to_string().into() + } +} + +impl<'a> TryFrom<&'a MxcUri> for MxcData<'a> { + type Error = MxcUriError; + + fn try_from(value: &'a MxcUri) -> Result { + Ok(Self::new(value.server_name()?, value.media_id()?) + .expect("validated MxcUri should always be valid MxcData")) + } +} + #[cfg(test)] mod tests { use crate::utils::dbg_truncate_str;