diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 640060c7..608c80fb 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,6 +1,7 @@ use std::time::Duration; use axum::response::IntoResponse; +use either::Either; use http::{ header::{CONTENT_DISPOSITION, CONTENT_SECURITY_POLICY, CONTENT_TYPE}, HeaderName, HeaderValue, Method, @@ -20,7 +21,7 @@ use ruma::{ use tracing::{debug, error, info, warn}; use crate::{ - service::media::{FileMeta, MxcData}, + service::media::{FileMeta, MediaCreationToken, MxcData, ThumbnailResult}, services, utils, Ar, Error, Ra, Result, }; @@ -154,10 +155,15 @@ pub(crate) async fn create_content_route( let media_id = utils::random_string(MXC_LENGTH); let mxc = MxcData::new(services().globals.server_name(), &media_id)?; + // this is a fresh MXC, there ought to not be anyone else creating it at the + // same time + let Either::Right(token) = services().media.get(mxc).await? else { + return Err(Error::bad_database("Media for fresh mxc already exists")); + }; services() .media .create( - mxc, + token, body.filename .clone() .map(|filename| ContentDisposition { @@ -319,9 +325,9 @@ async fn get_remote_content_via_legacy_api( #[tracing::instrument] pub(crate) async fn get_remote_content( - mxc: MxcData<'_>, + token: MediaCreationToken, ) -> Result { - let fed_result = get_remote_content_via_federation_api(mxc).await; + let fed_result = get_remote_content_via_federation_api(token.mxc()).await; let response = match fed_result { Ok(response) => { @@ -338,7 +344,7 @@ pub(crate) async fn get_remote_content( back to deprecated API" ); - get_remote_content_via_legacy_api(mxc).await? + get_remote_content_via_legacy_api(token.mxc()).await? } Err(e) => { return Err(e); @@ -348,7 +354,7 @@ pub(crate) async fn get_remote_content( services() .media .create( - mxc, + token, response.content.content_disposition.as_ref(), response.content.content_type.clone(), &response.content.file, @@ -447,37 +453,45 @@ async fn get_content_route_ruma( ) -> Result { let mxc = MxcData::new(&body.server_name, &body.media_id)?; - if let Some(( - FileMeta { - content_type, - .. - }, - file, - )) = services().media.get(mxc).await? - { - Ok(authenticated_media_client::get_content::v1::Response { + let token = match services().media.get(mxc).await? { + Either::Left(( + FileMeta { + content_type, + .. + }, file, - content_disposition: Some(content_disposition_for( - content_type.as_deref(), - None, - )), - content_type, - }) - } else if &*body.server_name != services().globals.server_name() - && allow_remote == AllowRemote::Yes + )) => { + return Ok(authenticated_media_client::get_content::v1::Response { + file, + content_disposition: Some(content_disposition_for( + content_type.as_deref(), + None, + )), + content_type, + }); + } + + Either::Right(token) => token, + }; + + if &*body.server_name == services().globals.server_name() + || allow_remote != AllowRemote::Yes { - let remote_response = get_remote_content(mxc).await?; - Ok(authenticated_media_client::get_content::v1::Response { - file: remote_response.content.file, - content_disposition: Some(content_disposition_for( - remote_response.content.content_type.as_deref(), - None, - )), - content_type: remote_response.content.content_type, - }) - } else { - Err(Error::BadRequest(ErrorKind::NotYetUploaded, "Media not found.")) + return Err(Error::BadRequest( + ErrorKind::NotYetUploaded, + "Media not found.", + )); } + + let remote_response = get_remote_content(token).await?; + Ok(authenticated_media_client::get_content::v1::Response { + file: remote_response.content.file, + content_disposition: Some(content_disposition_for( + remote_response.content.content_type.as_deref(), + None, + )), + content_type: remote_response.content.content_type, + }) } /// # `GET /_matrix/media/r0/download/{serverName}/{mediaId}/{fileName}` @@ -571,30 +585,36 @@ async fn get_content_as_filename_route_ruma( body: Ar, allow_remote: AllowRemote, ) -> Result { + use authenticated_media_client::get_content_as_filename::v1::Response; + let mxc = MxcData::new(&body.server_name, &body.media_id)?; - if let Some(( - FileMeta { - content_type, - .. - }, - file, - )) = services().media.get(mxc).await? - { - Ok(authenticated_media_client::get_content_as_filename::v1::Response { + let token = match services().media.get(mxc).await? { + Either::Left(( + FileMeta { + content_type, + .. + }, file, - content_disposition: Some(content_disposition_for( - content_type.as_deref(), - Some(body.filename.clone()), - )), - content_type, - }) - } else if &*body.server_name != services().globals.server_name() + )) => { + return Ok(Response { + file, + content_disposition: Some(content_disposition_for( + content_type.as_deref(), + Some(body.filename.clone()), + )), + content_type, + }); + } + Either::Right(token) => token, + }; + + if &*body.server_name != services().globals.server_name() && allow_remote == AllowRemote::Yes { - let remote_response = get_remote_content(mxc).await?; + let remote_response = get_remote_content(token).await?; - Ok(authenticated_media_client::get_content_as_filename::v1::Response { + Ok(Response { content_disposition: Some(content_disposition_for( remote_response.content.content_type.as_deref(), Some(body.filename.clone()), @@ -834,77 +854,99 @@ async fn get_content_thumbnail_route_ruma( } }; - if let Some(( + let lookup_result = + services().media.get_thumbnail(mxc, width, height).await?; + + if let ThumbnailResult::Data( FileMeta { content_type, .. }, file, - )) = services().media.get_thumbnail(mxc, width, height).await? + ) = lookup_result { return Ok(make_response(file, content_type)); } - if &*body.server_name != services().globals.server_name() - && allow_remote == AllowRemote::Yes + if &*body.server_name == services().globals.server_name() + || allow_remote != AllowRemote::Yes { - let get_thumbnail_response = get_remote_thumbnail( - &body.server_name, - authenticated_media_fed::get_content_thumbnail::v1::Request { - height: body.height, - width: body.width, - method: body.method.clone(), - media_id: body.media_id.clone(), - timeout_ms: Duration::from_secs(20), - // we don't support animated thumbnails, so don't try requesting - // one - we're allowed to ignore the client's request for an - // animated thumbnail - animated: Some(false), - }, - ) - .await; - - match get_thumbnail_response { - Ok(resp) => { - services() - .media - .upload_thumbnail( - mxc, - None, - resp.content.content_type.clone(), - width, - height, - &resp.content.file, - ) - .await?; - - return Ok(make_response( - resp.content.file, - resp.content.content_type, - )); - } - Err(error) => warn!( - %error, - "Failed to fetch thumbnail via federation, trying to fetch \ - original media and create thumbnail ourselves" - ), - } - - get_remote_content(mxc).await?; - - if let Some(( - FileMeta { - content_type, - .. - }, - file, - )) = services().media.get_thumbnail(mxc, width, height).await? - { - return Ok(make_response(file, content_type)); - } - - error!("Source media doesn't exist even after fetching it from remote"); + return Err(Error::BadRequest( + ErrorKind::NotYetUploaded, + "Media not found.", + )); } + let source_token = match lookup_result { + ThumbnailResult::Data(..) => unreachable!(), + ThumbnailResult::NeedSource(token) => { + // we need to fetch the whole media + token + } + ThumbnailResult::NeedThumbnail(thumbnail_token, source_token) => { + // try to fetch thumbnail from remote, falling back to fetching the + // whole media anyway + + let get_thumbnail_response = get_remote_thumbnail( + &body.server_name, + authenticated_media_fed::get_content_thumbnail::v1::Request { + height: thumbnail_token.height().into(), + width: thumbnail_token.width().into(), + method: body.method.clone(), + media_id: body.media_id.clone(), + timeout_ms: Duration::from_secs(20), + // we don't support animated thumbnails, so don't try + // requesting one - we're allowed to + // ignore the client's request for an + // animated thumbnail + animated: Some(false), + }, + ) + .await; + + match get_thumbnail_response { + Ok(resp) => { + services() + .media + .create_thumbnail( + thumbnail_token, + None, + resp.content.content_type.clone(), + &resp.content.file, + ) + .await?; + + return Ok(make_response( + resp.content.file, + resp.content.content_type, + )); + } + Err(error) => { + warn!( + %error, + "Failed to fetch thumbnail via federation, trying to fetch \ + original media and create thumbnail ourselves" + ); + } + } + + source_token + } + }; + + get_remote_content(source_token).await?; + + if let ThumbnailResult::Data( + FileMeta { + content_type, + .. + }, + file, + ) = services().media.get_thumbnail(mxc, width, height).await? + { + return Ok(make_response(file, content_type)); + } + + error!("Source media doesn't exist even after fetching it from remote"); Err(Error::BadRequest(ErrorKind::NotYetUploaded, "Media not found.")) } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 654e2528..7f655eb6 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -10,6 +10,7 @@ use std::{ use axum::{response::IntoResponse, Json}; use axum_extra::headers::{Authorization, HeaderMapExt}; use base64::Engine as _; +use either::Either; use get_profile_information::v1::ProfileField; use ruma::{ api::{ @@ -71,7 +72,7 @@ use crate::{ api::client_server::{self, claim_keys_helper, get_keys_helper}, observability::{FoundIn, Lookup, METRICS}, service::{ - media::MxcData, + media::{MxcData, ThumbnailResult}, pdu::{gen_event_id_canonical_json, PduBuilder}, }, services, @@ -2047,7 +2048,7 @@ pub(crate) async fn media_download_route( body: Ar, ) -> Result> { let mxc = MxcData::new(services().globals.server_name(), &body.media_id)?; - let Some(( + let Either::Left(( crate::service::media::FileMeta { content_disposition, content_type, @@ -2094,13 +2095,13 @@ pub(crate) async fn media_thumbnail_route( Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid.") })?; - let Some(( + let ThumbnailResult::Data( crate::service::media::FileMeta { content_type, .. }, file, - )) = services().media.get_thumbnail(mxc, width, height).await? + ) = services().media.get_thumbnail(mxc, width, height).await? else { return Err(Error::BadRequest( ErrorKind::NotYetUploaded, diff --git a/src/service.rs b/src/service.rs index 61216a61..4d851dee 100644 --- a/src/service.rs +++ b/src/service.rs @@ -145,9 +145,7 @@ impl Services { account_data: db, admin: admin::Service::build(), key_backups: db, - media: media::Service { - db, - }, + media: media::Service::new(db), sending: sending::Service::build(db, &config), globals: globals::Service::load(db, config, reload_handles)?, diff --git a/src/service/media.rs b/src/service/media.rs index b1b80485..955dadb4 100644 --- a/src/service/media.rs +++ b/src/service/media.rs @@ -1,17 +1,22 @@ use std::{fmt, io::Cursor}; +use either::Either; use image::imageops::FilterType; use ruma::{ api::client::error::ErrorKind, http_headers::ContentDisposition, MxcUri, - MxcUriError, OwnedMxcUri, + MxcUriError, OwnedMxcUri, OwnedServerName, }; use tokio::{ fs::File, io::{AsyncReadExt, AsyncWriteExt}, }; -use tracing::{debug, warn}; +use tracing::{debug, instrument, warn}; -use crate::{services, Error, Result}; +use crate::{ + services, + utils::on_demand_hashmap::{KeyToken, TokenSet}, + Error, Result, +}; mod data; @@ -91,16 +96,158 @@ impl<'a> TryFrom<&'a MxcUri> for MxcData<'a> { } } +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +struct ThumbnailDimensions { + width: u32, + height: u32, + should_crop: bool, +} + +impl ThumbnailDimensions { + /// Returns None when the server should send the original file. + fn new(width: u32, height: u32) -> Option { + match (width, height) { + (0..=32, 0..=32) => Some(Self { + width: 32, + height: 32, + should_crop: true, + }), + (0..=96, 0..=96) => Some(Self { + width: 96, + height: 96, + should_crop: true, + }), + (0..=320, 0..=240) => Some(Self { + width: 320, + height: 240, + should_crop: false, + }), + (0..=640, 0..=480) => Some(Self { + width: 640, + height: 480, + should_crop: false, + }), + (0..=800, 0..=600) => Some(Self { + width: 800, + height: 600, + should_crop: false, + }), + _ => None, + } + } +} + +#[derive(Clone, Eq, Hash, PartialEq)] +struct MediaCreationData { + server_name: OwnedServerName, + media_id: String, + thumbnail_dimensions: Option, +} + +impl MediaCreationData { + fn mxc(&self) -> MxcData<'_> { + MxcData { + server_name: &self.server_name, + media_id: &self.media_id, + } + } +} + +impl fmt::Debug for MediaCreationData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.mxc())?; + if let Some(ThumbnailDimensions { + width, + height, + .. + }) = self.thumbnail_dimensions + { + write!(f, " @ {width}x{height}")?; + } + + Ok(()) + } +} + +/// Used to serialize media creation. +/// +/// There can only be up to one holder of a `MediaCreationToken` for a given mxc +/// at any one time. +/// +/// Returned by [`Service::get()`] when the file is not available locally. +#[derive(Debug)] +pub(crate) struct MediaCreationToken(KeyToken); + +impl MediaCreationToken { + pub(crate) fn mxc(&self) -> MxcData<'_> { + self.0.mxc() + } +} + +/// Used to serialize thumbnail creation. +/// +/// There can only be up to one holder of a `ThumbnailCreationToken` for a given +/// mxc and resolution at any one time. +/// +/// Returned by [`Service::get_thumbnail()`] when the thumbnail is not available +/// locally. +#[derive(Debug)] +pub(crate) struct ThumbnailCreationToken(KeyToken); + +impl ThumbnailCreationToken { + pub(crate) fn width(&self) -> u32 { + self.0 + .thumbnail_dimensions + .as_ref() + .expect("thumbnail creation token should have thumbnail dimensions") + .width + } + + pub(crate) fn height(&self) -> u32 { + self.0 + .thumbnail_dimensions + .as_ref() + .expect("thumbnail creation token should have thumbnail dimensions") + .height + } +} + +pub(crate) enum ThumbnailResult { + /// Thumbnail available, corresponding metadata and file content returned. + Data(FileMeta, Vec), + /// Source file for thumbnail is required. + NeedSource(MediaCreationToken), + /// Thumbnail data is required (or can be generated from source file, once + /// present) + NeedThumbnail(ThumbnailCreationToken, MediaCreationToken), +} + pub(crate) struct Service { pub(crate) db: &'static dyn Data, + + creation_tokens: TokenSet, } impl Service { - /// Uploads a file. - #[tracing::instrument(skip(self, file))] - pub(crate) async fn create( + pub(crate) fn new(db: &'static dyn Data) -> Service { + Self { + db, + creation_tokens: TokenSet::new("media_creation".to_owned()), + } + } + + #[instrument(skip(self))] + async fn lock_creation( &self, - mxc: MxcData<'_>, + creation_data: MediaCreationData, + ) -> KeyToken { + self.creation_tokens.lock_key(creation_data).await + } + + #[tracing::instrument(skip(self, file))] + async fn create_inner( + &self, + creation_data: &MediaCreationData, content_disposition: Option<&ContentDisposition>, content_type: Option, file: &[u8], @@ -110,8 +257,23 @@ impl Service { .map(ContentDisposition::to_string), content_type, }; - // Width, Height = 0 if it's not a thumbnail - let key = self.db.create_file_metadata(mxc, 0, 0, &meta)?; + + let key = if let Some(ThumbnailDimensions { + width, + height, + .. + }) = creation_data.thumbnail_dimensions + { + self.db.create_file_metadata( + creation_data.mxc(), + width, + height, + &meta, + )? + } else { + // Width, Height = 0 if it's not a thumbnail + self.db.create_file_metadata(creation_data.mxc(), 0, 0, &meta)? + }; let path = services().globals.get_media_file(&key); let mut f = File::create(path).await?; @@ -119,65 +281,102 @@ impl Service { Ok(meta) } + /// Uploads or replaces a file. + pub(crate) async fn create( + &self, + MediaCreationToken(creation_data): MediaCreationToken, + content_disposition: Option<&ContentDisposition>, + content_type: Option, + file: &[u8], + ) -> Result { + self.create_inner( + &creation_data, + content_disposition, + content_type, + file, + ) + .await + } + /// Uploads or replaces a file thumbnail. - #[tracing::instrument(skip(self, file))] - pub(crate) async fn upload_thumbnail( + pub(crate) async fn create_thumbnail( &self, - mxc: MxcData<'_>, - content_disposition: Option, + ThumbnailCreationToken(creation_data): ThumbnailCreationToken, + content_disposition: Option<&ContentDisposition>, content_type: Option, - width: u32, - height: u32, file: &[u8], ) -> Result { - let meta = FileMeta { + self.create_inner( + &creation_data, content_disposition, content_type, - }; - let key = self.db.create_file_metadata(mxc, width, height, &meta)?; - - let path = services().globals.get_media_file(&key); - let mut f = File::create(path).await?; - f.write_all(file).await?; - - Ok(meta) + file, + ) + .await } - /// Downloads a file. #[tracing::instrument(skip(self))] + // internal function, and not actually that complex + #[allow(clippy::type_complexity)] + async fn get_inner( + &self, + mxc: MxcData<'_>, + thumbnail_dimensions: Option, + ) -> Result), KeyToken>> { + let try_get = || async { + if let Ok((meta, key)) = self.db.search_file_metadata(mxc, 0, 0) { + let path = services().globals.get_media_file(&key); + let mut file_data = Vec::new(); + let Ok(mut file) = File::open(path).await else { + return Ok::<_, Error>(None); + }; + + file.read_to_end(&mut file_data).await?; + + Ok(Some((meta, file_data))) + } else { + Ok(None) + } + }; + + if let Some(ret) = try_get().await? { + return Ok(Either::Left(ret)); + } + + let token = self + .lock_creation(MediaCreationData { + server_name: mxc.server_name.to_owned(), + media_id: mxc.media_id.to_owned(), + thumbnail_dimensions, + }) + .await; + + // check again in case it has been created in the meantime + if let Some(ret) = try_get().await? { + Ok(Either::Left(ret)) + } else { + Ok(Either::Right(token)) + } + } + + /// Loads a file from local storage. pub(crate) async fn get( &self, mxc: MxcData<'_>, - ) -> Result)>> { - if let Ok((meta, key)) = self.db.search_file_metadata(mxc, 0, 0) { - let path = services().globals.get_media_file(&key); - let mut file_data = Vec::new(); - let Ok(mut file) = File::open(path).await else { - return Ok(None); - }; - - file.read_to_end(&mut file_data).await?; - - Ok(Some((meta, file_data))) - } else { - Ok(None) - } + ) -> Result), MediaCreationToken>> { + Ok(self.get_inner(mxc, None).await?.map_right(MediaCreationToken)) } - /// Returns width, height of the thumbnail and whether it should be cropped. - /// Returns None when the server should send the original file. - fn thumbnail_properties( - width: u32, - height: u32, - ) -> Option<(u32, u32, bool)> { - match (width, height) { - (0..=32, 0..=32) => Some((32, 32, true)), - (0..=96, 0..=96) => Some((96, 96, true)), - (0..=320, 0..=240) => Some((320, 240, false)), - (0..=640, 0..=480) => Some((640, 480, false)), - (0..=800, 0..=600) => Some((800, 600, false)), - _ => None, - } + /// Loads a thumbnail from local storage. + async fn get_thumbnail_local( + &self, + mxc: MxcData<'_>, + dimensions: ThumbnailDimensions, + ) -> Result), ThumbnailCreationToken>> { + Ok(self + .get_inner(mxc, Some(dimensions)) + .await? + .map_right(ThumbnailCreationToken)) } /// Generates a thumbnail from the given image file contents. Returns @@ -262,7 +461,7 @@ impl Service { Ok(Some(thumbnail_bytes)) } - /// Downloads a file's thumbnail. + /// Loads or generates a file's thumbnail. /// /// Here's an example on how it works: /// @@ -281,30 +480,36 @@ impl Service { mxc: MxcData<'_>, width: u32, height: u32, - ) -> Result)>> { - // 0, 0 because that's the original file - let (width, height, crop) = - Self::thumbnail_properties(width, height).unwrap_or((0, 0, false)); - - if let Ok((meta, key)) = - self.db.search_file_metadata(mxc, width, height) - { - debug!("Using saved thumbnail"); - let path = services().globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - - return Ok(Some((meta, file.clone()))); - } - - let Ok((meta, key)) = self.db.search_file_metadata(mxc, 0, 0) else { - debug!("Original image not found, can't generate thumbnail"); - return Ok(None); + ) -> Result { + let Some(dimensions) = ThumbnailDimensions::new(width, height) else { + // image should be used as-is + return match self.get(mxc).await? { + Either::Left((meta, file)) => { + Ok(ThumbnailResult::Data(meta, file)) + } + Either::Right(token) => Ok(ThumbnailResult::NeedSource(token)), + }; }; - let path = services().globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; + let thumbnail_token = + match self.get_thumbnail_local(mxc, dimensions).await? { + Either::Left((meta, file)) => { + debug!("Using saved thumbnail"); + return Ok(ThumbnailResult::Data(meta, file)); + } + Either::Right(token) => token, + }; + + let (meta, file) = match self.get(mxc).await? { + Either::Left(ret) => ret, + Either::Right(media_token) => { + debug!("Original image not found, can't generate thumbnail"); + return Ok(ThumbnailResult::NeedThumbnail( + thumbnail_token, + media_token, + )); + } + }; debug!("Generating thumbnail"); let thumbnail_result = { @@ -313,7 +518,12 @@ impl Service { tokio::task::spawn_blocking(move || { outer_span.in_scope(|| { - Self::generate_thumbnail(&file, width, height, crop) + Self::generate_thumbnail( + &file, + width, + height, + dimensions.should_crop, + ) }) }) .await @@ -322,22 +532,20 @@ impl Service { let Some(thumbnail_bytes) = thumbnail_result? else { debug!("Returning source image as-is"); - return Ok(Some((meta, file))); + return Ok(ThumbnailResult::Data(meta, file)); }; debug!("Saving created thumbnail"); let meta = self - .upload_thumbnail( - mxc, - meta.content_disposition, + .create_thumbnail( + thumbnail_token, + None, meta.content_type, - width, - height, &thumbnail_bytes, ) .await?; - Ok(Some((meta, thumbnail_bytes.clone()))) + Ok(ThumbnailResult::Data(meta, thumbnail_bytes.clone())) } }