diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index be8df36bb5..3e5e71d2a5 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -13,13 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(feature = "e2e-encryption")] -use std::ops::Deref; use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt, iter, - sync::Arc, }; +#[cfg(feature = "e2e-encryption")] +use std::{ops::Deref, sync::Arc}; use eyeball::{SharedObservable, Subscriber}; #[cfg(not(target_arch = "wasm32"))] @@ -71,7 +70,7 @@ use crate::RoomMemberships; use crate::{ deserialized_responses::{RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent}, error::{Error, Result}, - event_cache_store::DynEventCacheStore, + event_cache_store::EventCacheStoreWrapper, rooms::{ normal::{RoomInfoNotableUpdate, RoomInfoNotableUpdateReasons}, Room, RoomInfo, RoomState, @@ -93,7 +92,7 @@ pub struct BaseClient { /// Database pub(crate) store: Store, /// The store used by the event cache. - event_cache_store: Arc, + event_cache_store: EventCacheStoreWrapper, /// The store used for encryption. /// /// This field is only meant to be used for `OlmMachine` initialization. @@ -147,7 +146,7 @@ impl BaseClient { BaseClient { store: Store::new(config.state_store), - event_cache_store: config.event_cache_store, + event_cache_store: EventCacheStoreWrapper::new(config.event_cache_store), #[cfg(feature = "e2e-encryption")] crypto_store: config.crypto_store, #[cfg(feature = "e2e-encryption")] @@ -222,8 +221,8 @@ impl BaseClient { } /// Get a reference to the event cache store. - pub fn event_cache_store(&self) -> &DynEventCacheStore { - &*self.event_cache_store + pub fn event_cache_store(&self) -> &EventCacheStoreWrapper { + &self.event_cache_store } /// Is the client logged in. diff --git a/crates/matrix-sdk-base/src/event_cache_store/integration_tests.rs b/crates/matrix-sdk-base/src/event_cache_store/integration_tests.rs index 2a0cc30faf..dfc5049bec 100644 --- a/crates/matrix-sdk-base/src/event_cache_store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/event_cache_store/integration_tests.rs @@ -14,12 +14,15 @@ //! Trait and macro of integration tests for `EventCacheStore` implementations. +use std::time::Duration; + use async_trait::async_trait; use ruma::{ - api::client::media::get_content_thumbnail::v3::Method, events::room::MediaSource, mxc_uri, uint, + api::client::media::get_content_thumbnail::v3::Method, events::room::MediaSource, mxc_uri, + owned_mxc_uri, time::SystemTime, uint, }; -use super::DynEventCacheStore; +use super::{DynEventCacheStore, MediaRetentionPolicy}; use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSettings}; /// `EventCacheStore` integration tests. @@ -31,6 +34,18 @@ use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSettings}; pub trait EventCacheStoreIntegrationTests { /// Test media content storage. async fn test_media_content(&self); + + /// Test media retention policy storage. + async fn test_store_media_retention_policy(&self); + + /// Test media content's retention policy max file size. + async fn test_media_max_file_size(&self); + + /// Test media content's retention policy max file size. + async fn test_media_max_cache_size(&self); + + /// Test media content's retention policy expiry. + async fn test_media_expiry(&self); } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -49,32 +64,34 @@ impl EventCacheStoreIntegrationTests for DynEventCacheStore { )), }; - let other_uri = mxc_uri!("mxc://localhost/media-other"); - let request_other_file = MediaRequest { - source: MediaSource::Plain(other_uri.to_owned()), - format: MediaFormat::File, - }; + let other_uri = owned_mxc_uri!("mxc://localhost/media-other"); + let request_other_file = + MediaRequest { source: MediaSource::Plain(other_uri), format: MediaFormat::File }; let content: Vec = "hello".into(); let thumbnail_content: Vec = "world".into(); let other_content: Vec = "foo".into(); + let time = SystemTime::now(); + let policy = MediaRetentionPolicy::empty(); // Media isn't present in the cache. assert!( - self.get_media_content(&request_file).await.unwrap().is_none(), + self.get_media_content(&request_file, time).await.unwrap().is_none(), "unexpected media found" ); assert!( - self.get_media_content(&request_thumbnail).await.unwrap().is_none(), + self.get_media_content(&request_thumbnail, time).await.unwrap().is_none(), "media not found" ); // Let's add the media. - self.add_media_content(&request_file, content.clone()).await.expect("adding media failed"); + self.add_media_content(&request_file, content.clone(), time, policy) + .await + .expect("adding media failed"); // Media is present in the cache. assert_eq!( - self.get_media_content(&request_file).await.unwrap().as_ref(), + self.get_media_content(&request_file, time).await.unwrap().as_ref(), Some(&content), "media not found though added" ); @@ -84,41 +101,41 @@ impl EventCacheStoreIntegrationTests for DynEventCacheStore { // Media isn't present in the cache. assert!( - self.get_media_content(&request_file).await.unwrap().is_none(), + self.get_media_content(&request_file, time).await.unwrap().is_none(), "media still there after removing" ); // Let's add the media again. - self.add_media_content(&request_file, content.clone()) + self.add_media_content(&request_file, content.clone(), time, policy) .await .expect("adding media again failed"); assert_eq!( - self.get_media_content(&request_file).await.unwrap().as_ref(), + self.get_media_content(&request_file, time).await.unwrap().as_ref(), Some(&content), "media not found after adding again" ); // Let's add the thumbnail media. - self.add_media_content(&request_thumbnail, thumbnail_content.clone()) + self.add_media_content(&request_thumbnail, thumbnail_content.clone(), time, policy) .await .expect("adding thumbnail failed"); // Media's thumbnail is present. assert_eq!( - self.get_media_content(&request_thumbnail).await.unwrap().as_ref(), + self.get_media_content(&request_thumbnail, time).await.unwrap().as_ref(), Some(&thumbnail_content), "thumbnail not found" ); // Let's add another media with a different URI. - self.add_media_content(&request_other_file, other_content.clone()) + self.add_media_content(&request_other_file, other_content.clone(), time, policy) .await .expect("adding other media failed"); // Other file is present. assert_eq!( - self.get_media_content(&request_other_file).await.unwrap().as_ref(), + self.get_media_content(&request_other_file, time).await.unwrap().as_ref(), Some(&other_content), "other file not found" ); @@ -127,23 +144,411 @@ impl EventCacheStoreIntegrationTests for DynEventCacheStore { self.remove_media_content_for_uri(uri).await.expect("removing all media for uri failed"); assert!( - self.get_media_content(&request_file).await.unwrap().is_none(), + self.get_media_content(&request_file, time).await.unwrap().is_none(), "media wasn't removed" ); assert!( - self.get_media_content(&request_thumbnail).await.unwrap().is_none(), + self.get_media_content(&request_thumbnail, time).await.unwrap().is_none(), "thumbnail wasn't removed" ); assert!( - self.get_media_content(&request_other_file).await.unwrap().is_some(), + self.get_media_content(&request_other_file, time).await.unwrap().is_some(), "other media was removed" ); } + + async fn test_store_media_retention_policy(&self) { + let stored = self.media_retention_policy().await.unwrap(); + assert!(stored.is_none()); + + let policy = MediaRetentionPolicy::default(); + self.set_media_retention_policy(policy).await.unwrap(); + + let stored = self.media_retention_policy().await.unwrap(); + assert_eq!(stored, Some(policy)); + } + + async fn test_media_max_file_size(&self) { + let time = SystemTime::now(); + + // 256 bytes content. + let content_big = vec![0; 256]; + let uri_big = owned_mxc_uri!("mxc://localhost/big-media"); + let request_big = + MediaRequest { source: MediaSource::Plain(uri_big), format: MediaFormat::File }; + + // 128 bytes content. + let content_avg = vec![0; 128]; + let uri_avg = owned_mxc_uri!("mxc://localhost/average-media"); + let request_avg = + MediaRequest { source: MediaSource::Plain(uri_avg), format: MediaFormat::File }; + + // 64 bytes content. + let content_small = vec![0; 64]; + let uri_small = owned_mxc_uri!("mxc://localhost/small-media"); + let request_small = + MediaRequest { source: MediaSource::Plain(uri_small), format: MediaFormat::File }; + + // First, with a policy that doesn't accept the big media. + let policy = MediaRetentionPolicy::empty().with_max_file_size(Some(200)); + + self.add_media_content(&request_big, content_big.clone(), time, policy).await.unwrap(); + self.add_media_content(&request_avg, content_avg.clone(), time, policy).await.unwrap(); + self.add_media_content(&request_small, content_small, time, policy).await.unwrap(); + + // The big content was NOT cached but the others were. + let stored = self.get_media_content(&request_big, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_some()); + let stored = self.get_media_content(&request_small, time).await.unwrap(); + assert!(stored.is_some()); + + // A cleanup doesn't have any effect. + self.clean_up_media_cache(policy, time).await.unwrap(); + + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_some()); + let stored = self.get_media_content(&request_small, time).await.unwrap(); + assert!(stored.is_some()); + + // Change to a policy that doesn't accept the average media. + let policy = MediaRetentionPolicy::empty().with_max_file_size(Some(100)); + + // The cleanup removes the average media. + self.clean_up_media_cache(policy, time).await.unwrap(); + + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_small, time).await.unwrap(); + assert!(stored.is_some()); + + // Caching big and average media doesn't work. + self.add_media_content(&request_big, content_big.clone(), time, policy).await.unwrap(); + self.add_media_content(&request_avg, content_avg.clone(), time, policy).await.unwrap(); + + let stored = self.get_media_content(&request_big, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_none()); + + // If there are both a cache size and a file size, the minimum value is used. + let policy = MediaRetentionPolicy::empty() + .with_max_cache_size(Some(200)) + .with_max_file_size(Some(1000)); + + // Caching big doesn't work. + self.add_media_content(&request_big, content_big.clone(), time, policy).await.unwrap(); + self.add_media_content(&request_avg, content_avg.clone(), time, policy).await.unwrap(); + + let stored = self.get_media_content(&request_big, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_some()); + + // Change to a policy that doesn't accept the average media. + let policy = MediaRetentionPolicy::empty() + .with_max_cache_size(Some(100)) + .with_max_file_size(Some(1000)); + + // The cleanup removes the average media. + self.clean_up_media_cache(policy, time).await.unwrap(); + + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_small, time).await.unwrap(); + assert!(stored.is_some()); + + // Caching big and average media doesn't work. + self.add_media_content(&request_big, content_big, time, policy).await.unwrap(); + self.add_media_content(&request_avg, content_avg, time, policy).await.unwrap(); + + let stored = self.get_media_content(&request_big, time).await.unwrap(); + assert!(stored.is_none()); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_none()); + } + + async fn test_media_max_cache_size(&self) { + // 256 bytes content. + let content_big = vec![0; 256]; + let uri_big = owned_mxc_uri!("mxc://localhost/big-media"); + let request_big = + MediaRequest { source: MediaSource::Plain(uri_big), format: MediaFormat::File }; + + // 128 bytes content. + let content_avg = vec![0; 128]; + let uri_avg = owned_mxc_uri!("mxc://localhost/average-media"); + let request_avg = + MediaRequest { source: MediaSource::Plain(uri_avg), format: MediaFormat::File }; + + // 64 bytes content. + let content_small = vec![0; 64]; + let uri_small_1 = owned_mxc_uri!("mxc://localhost/small-media-1"); + let request_small_1 = + MediaRequest { source: MediaSource::Plain(uri_small_1), format: MediaFormat::File }; + let uri_small_2 = owned_mxc_uri!("mxc://localhost/small-media-2"); + let request_small_2 = + MediaRequest { source: MediaSource::Plain(uri_small_2), format: MediaFormat::File }; + let uri_small_3 = owned_mxc_uri!("mxc://localhost/small-media-3"); + let request_small_3 = + MediaRequest { source: MediaSource::Plain(uri_small_3), format: MediaFormat::File }; + let uri_small_4 = owned_mxc_uri!("mxc://localhost/small-media-4"); + let request_small_4 = + MediaRequest { source: MediaSource::Plain(uri_small_4), format: MediaFormat::File }; + let uri_small_5 = owned_mxc_uri!("mxc://localhost/small-media-5"); + let request_small_5 = + MediaRequest { source: MediaSource::Plain(uri_small_5), format: MediaFormat::File }; + + // A policy that doesn't accept the big media. + let policy = MediaRetentionPolicy::empty().with_max_cache_size(Some(200)); + + // Try to add all the content at different times. + let mut time = SystemTime::UNIX_EPOCH; + self.add_media_content(&request_big, content_big, time, policy).await.unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_1, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_2, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_3, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_4, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_5, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_avg, content_avg, time, policy).await.unwrap(); + + // The big content was NOT cached but the others were. + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_big, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_1, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_2, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_4, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_5, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_some()); + + // Cleanup removes the oldest content first. + time += Duration::from_secs(1); + self.clean_up_media_cache(policy, time).await.unwrap(); + + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_1, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_2, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_3, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_4, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_5, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_some()); + + // Reinsert the small medias that were removed. + time += Duration::from_secs(1); + self.add_media_content(&request_small_1, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_2, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_3, content_small.clone(), time, policy) + .await + .unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_small_4, content_small, time, policy).await.unwrap(); + + // Check that they are cached. + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_1, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_2, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_4, time).await.unwrap(); + assert!(stored.is_some()); + + // Access small_5 too so its last access is updated too. + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_5, time).await.unwrap(); + assert!(stored.is_some()); + + // Cleanup still removes the oldest content first, which is not the same as + // before. + time += Duration::from_secs(1); + self.clean_up_media_cache(policy, time).await.unwrap(); + + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_1, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_2, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_4, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_small_5, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_avg, time).await.unwrap(); + assert!(stored.is_none()); + } + + async fn test_media_expiry(&self) { + // 64 bytes content. + let content = vec![0; 64]; + + let uri_1 = owned_mxc_uri!("mxc://localhost/media-1"); + let request_1 = + MediaRequest { source: MediaSource::Plain(uri_1), format: MediaFormat::File }; + let uri_2 = owned_mxc_uri!("mxc://localhost/media-2"); + let request_2 = + MediaRequest { source: MediaSource::Plain(uri_2), format: MediaFormat::File }; + let uri_3 = owned_mxc_uri!("mxc://localhost/media-3"); + let request_3 = + MediaRequest { source: MediaSource::Plain(uri_3), format: MediaFormat::File }; + let uri_4 = owned_mxc_uri!("mxc://localhost/media-4"); + let request_4 = + MediaRequest { source: MediaSource::Plain(uri_4), format: MediaFormat::File }; + let uri_5 = owned_mxc_uri!("mxc://localhost/media-5"); + let request_5 = + MediaRequest { source: MediaSource::Plain(uri_5), format: MediaFormat::File }; + + // A policy with 30 seconds expiry. + let policy = + MediaRetentionPolicy::empty().with_last_access_expiry(Some(Duration::from_secs(30))); + + // Add all the content at different times. + let mut time = SystemTime::UNIX_EPOCH; + self.add_media_content(&request_1, content.clone(), time, policy).await.unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_2, content.clone(), time, policy).await.unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_3, content.clone(), time, policy).await.unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_4, content.clone(), time, policy).await.unwrap(); + time += Duration::from_secs(1); + self.add_media_content(&request_5, content, time, policy).await.unwrap(); + + // The content was cached. + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_1, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_2, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_4, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_5, time).await.unwrap(); + assert!(stored.is_some()); + + // We are now at UNIX_EPOCH + 10 seconds, the oldest content was accessed 5 + // seconds ago. + time += Duration::from_secs(1); + assert_eq!(time, SystemTime::UNIX_EPOCH + Duration::from_secs(10)); + + // Cleanup has no effect, nothing has expired. + self.clean_up_media_cache(policy, time).await.unwrap(); + + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_1, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_2, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_4, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_5, time).await.unwrap(); + assert!(stored.is_some()); + + // We are now at UNIX_EPOCH + 16 seconds, the oldest content was accessed 5 + // seconds ago. + time += Duration::from_secs(1); + assert_eq!(time, SystemTime::UNIX_EPOCH + Duration::from_secs(16)); + + // Jump 26 seconds in the future, so the 2 first media contents are expired. + time += Duration::from_secs(26); + + // Cleanup removes the two oldest media contents. + self.clean_up_media_cache(policy, time).await.unwrap(); + + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_1, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_2, time).await.unwrap(); + assert!(stored.is_none()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_3, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_4, time).await.unwrap(); + assert!(stored.is_some()); + time += Duration::from_secs(1); + let stored = self.get_media_content(&request_5, time).await.unwrap(); + assert!(stored.is_some()); + } } /// Macro building to allow your `EventCacheStore` implementation to run the /// entire tests suite locally. /// +/// Can be run with the `with_media_size_tests` argument to include more tests +/// about the media cache retention policy based on content size. It is not +/// recommended to run those in encrypted stores because the size of the +/// encrypted content may vary compared to what the tests expect. +/// /// You need to provide a `async fn get_event_cache_store() -> /// EventCacheStoreResult` providing a fresh event cache /// store on the same level you invoke the macro. @@ -171,19 +576,53 @@ impl EventCacheStoreIntegrationTests for DynEventCacheStore { #[allow(unused_macros, unused_extern_crates)] #[macro_export] macro_rules! event_cache_store_integration_tests { - () => { + (with_media_size_tests) => { mod event_cache_store_integration_tests { - use matrix_sdk_test::async_test; - use $crate::event_cache_store::{EventCacheStoreIntegrationTests, IntoEventCacheStore}; + $crate::event_cache_store_integration_tests!(@inner); - use super::get_event_cache_store; + #[async_test] + async fn test_media_max_file_size() { + let event_cache_store = get_event_cache_store().await.unwrap().into_event_cache_store(); + event_cache_store.test_media_max_file_size().await; + } #[async_test] - async fn test_media_content() { - let event_cache_store = - get_event_cache_store().await.unwrap().into_event_cache_store(); - event_cache_store.test_media_content().await; + async fn test_media_max_cache_size() { + let event_cache_store = get_event_cache_store().await.unwrap().into_event_cache_store(); + event_cache_store.test_media_max_cache_size().await; } } }; + + () => { + mod event_cache_store_integration_tests { + $crate::event_cache_store_integration_tests!(@inner); + } + }; + + (@inner) => { + use matrix_sdk_test::async_test; + use $crate::event_cache_store::{EventCacheStoreIntegrationTests, IntoEventCacheStore}; + + use super::get_event_cache_store; + + #[async_test] + async fn test_media_content() { + let event_cache_store = + get_event_cache_store().await.unwrap().into_event_cache_store(); + event_cache_store.test_media_content().await; + } + + #[async_test] + async fn test_store_media_retention_policy() { + let event_cache_store = get_event_cache_store().await.unwrap().into_event_cache_store(); + event_cache_store.test_store_media_retention_policy().await; + } + + #[async_test] + async fn test_media_expiry() { + let event_cache_store = get_event_cache_store().await.unwrap().into_event_cache_store(); + event_cache_store.test_media_expiry().await; + } + }; } diff --git a/crates/matrix-sdk-base/src/event_cache_store/memory_store.rs b/crates/matrix-sdk-base/src/event_cache_store/memory_store.rs index 381c52cbc1..0ad0b22869 100644 --- a/crates/matrix-sdk-base/src/event_cache_store/memory_store.rs +++ b/crates/matrix-sdk-base/src/event_cache_store/memory_store.rs @@ -16,9 +16,9 @@ use std::{num::NonZeroUsize, sync::RwLock as StdRwLock}; use async_trait::async_trait; use matrix_sdk_common::ring_buffer::RingBuffer; -use ruma::{MxcUri, OwnedMxcUri}; +use ruma::{time::SystemTime, MxcUri, OwnedMxcUri}; -use super::{EventCacheStore, EventCacheStoreError, Result}; +use super::{EventCacheStore, EventCacheStoreError, MediaRetentionPolicy, Result}; use crate::media::{MediaRequest, UniqueKey as _}; /// In-memory, non-persistent implementation of the `EventCacheStore`. @@ -27,15 +27,41 @@ use crate::media::{MediaRequest, UniqueKey as _}; #[allow(clippy::type_complexity)] #[derive(Debug)] pub struct MemoryStore { - media: StdRwLock)>>, + inner: StdRwLock, +} + +#[derive(Debug)] +struct MemoryStoreInner { + /// The media retention policy to use on cleanups. + media_retention_policy: Option, + /// Media content. + media: RingBuffer, } // SAFETY: `new_unchecked` is safe because 20 is not zero. const NUMBER_OF_MEDIAS: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(20) }; +/// A media content. +#[derive(Debug, Clone)] +struct MediaContent { + /// The Matrix URI of the media. + uri: OwnedMxcUri, + /// The unique key of the media request. + key: String, + /// The content of the media. + data: Vec, + /// The last access time of the media. + last_access: SystemTime, +} + impl Default for MemoryStore { fn default() -> Self { - Self { media: StdRwLock::new(RingBuffer::new(NUMBER_OF_MEDIAS)) } + let inner = MemoryStoreInner { + media_retention_policy: Default::default(), + media: RingBuffer::new(NUMBER_OF_MEDIAS), + }; + + Self { inner: StdRwLock::new(inner) } } } @@ -51,53 +77,178 @@ impl MemoryStore { impl EventCacheStore for MemoryStore { type Error = EventCacheStoreError; - async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + async fn media_retention_policy(&self) -> Result, Self::Error> { + Ok(self.inner.read().unwrap().media_retention_policy) + } + + async fn set_media_retention_policy( + &self, + policy: MediaRetentionPolicy, + ) -> Result<(), Self::Error> { + let mut inner = self.inner.write().unwrap(); + inner.media_retention_policy = Some(policy); + + Ok(()) + } + + async fn add_media_content( + &self, + request: &MediaRequest, + data: Vec, + current_time: SystemTime, + policy: MediaRetentionPolicy, + ) -> Result<()> { // Avoid duplication. Let's try to remove it first. self.remove_media_content(request).await?; + + if policy.exceeds_max_file_size(data.len()) { + // The content is too big to be cached. + return Ok(()); + } + // Now, let's add it. - self.media.write().unwrap().push((request.uri().to_owned(), request.unique_key(), data)); + let content = MediaContent { + uri: request.uri().to_owned(), + key: request.unique_key(), + data, + last_access: current_time, + }; + self.inner.write().unwrap().media.push(content); Ok(()) } - async fn get_media_content(&self, request: &MediaRequest) -> Result>> { - let media = self.media.read().unwrap(); + async fn get_media_content( + &self, + request: &MediaRequest, + current_time: SystemTime, + ) -> Result>> { + let mut inner = self.inner.write().unwrap(); let expected_key = request.unique_key(); - Ok(media.iter().find_map(|(_media_uri, media_key, media_content)| { - (media_key == &expected_key).then(|| media_content.to_owned()) - })) + // First get the content out of the buffer. + let Some(index) = inner.media.iter().position(|media| media.key == expected_key) else { + return Ok(None); + }; + let Some(mut content) = inner.media.remove(index) else { + return Ok(None); + }; + + // Clone the data. + let data = content.data.clone(); + + // Update the last access time. + content.last_access = current_time; + + // Put it back in the buffer. + inner.media.push(content); + + Ok(Some(data)) } async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { - let mut media = self.media.write().unwrap(); + let mut inner = self.inner.write().unwrap(); + let expected_key = request.unique_key(); - let Some(index) = media - .iter() - .position(|(_media_uri, media_key, _media_content)| media_key == &expected_key) - else { + let Some(index) = inner.media.iter().position(|media| media.key == expected_key) else { return Ok(()); }; - media.remove(index); + inner.media.remove(index); Ok(()) } async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - let mut media = self.media.write().unwrap(); - let expected_key = uri.to_owned(); - let positions = media + let mut inner = self.inner.write().unwrap(); + let positions = inner + .media .iter() .enumerate() - .filter_map(|(position, (media_uri, _media_key, _media_content))| { - (media_uri == &expected_key).then_some(position) - }) + .filter_map(|(position, media)| (media.uri == uri).then_some(position)) .collect::>(); // Iterate in reverse-order so that positions stay valid after first removals. for position in positions.into_iter().rev() { - media.remove(position); + inner.media.remove(position); + } + + Ok(()) + } + + async fn clean_up_media_cache( + &self, + policy: MediaRetentionPolicy, + current_time: SystemTime, + ) -> Result<(), Self::Error> { + if !policy.has_limitations() { + // We can safely skip all the checks. + return Ok(()); + } + + let mut inner = self.inner.write().unwrap(); + + // First, check media content that exceed the max filesize. + if policy.max_file_size.is_some() || policy.max_cache_size.is_some() { + inner.media.retain(|content| !policy.exceeds_max_file_size(content.data.len())); + } + + // Then, clean up expired media content. + if policy.last_access_expiry.is_some() { + inner + .media + .retain(|content| !policy.has_content_expired(current_time, content.last_access)); + } + + // Finally, if the cache size is too big, remove old items until it fits. + if let Some(max_cache_size) = policy.max_cache_size { + // Reverse the iterator because in case the cache size is overflowing, we want + // to count the number of old items to remove, and old items are at + // the start. + let (cache_size, overflowing_count) = inner.media.iter().rev().fold( + (0usize, 0u8), + |(cache_size, overflowing_count), content| { + if overflowing_count > 0 { + // Assume that all data is overflowing now. Overflowing count cannot + // overflow because the number of items is limited to 20. + (cache_size, overflowing_count + 1) + } else { + match cache_size.checked_add(content.data.len()) { + Some(cache_size) => (cache_size, 0), + // The cache size is overflowing, let's count the number of overflowing + // items to be able to remove them, since the max cache size cannot be + // bigger than usize::MAX. + None => (cache_size, 1), + } + } + }, + ); + + // If the cache size is overflowing, remove the number of old items we counted. + for _position in 0..overflowing_count { + inner.media.pop(); + } + + if cache_size > max_cache_size { + let difference = cache_size - max_cache_size; + + // Count the number of old items to remove to reach the difference. + let mut accumulated_items_size = 0usize; + let mut remove_items_count = 0u8; + for content in inner.media.iter() { + remove_items_count += 1; + // Cannot overflow since we already removed overflowing items. + accumulated_items_size += content.data.len(); + + if accumulated_items_size >= difference { + break; + } + } + + for _position in 0..remove_items_count { + inner.media.pop(); + } + } } Ok(()) @@ -112,5 +263,5 @@ mod tests { Ok(MemoryStore::new()) } - event_cache_store_integration_tests!(); + event_cache_store_integration_tests!(with_media_size_tests); } diff --git a/crates/matrix-sdk-base/src/event_cache_store/mod.rs b/crates/matrix-sdk-base/src/event_cache_store/mod.rs index f6458a580a..87fefe6179 100644 --- a/crates/matrix-sdk-base/src/event_cache_store/mod.rs +++ b/crates/matrix-sdk-base/src/event_cache_store/mod.rs @@ -19,21 +19,25 @@ //! into the event cache for the actual storage. By default this brings an //! in-memory store. -use std::str::Utf8Error; +use std::{str::Utf8Error, time::Duration}; #[cfg(any(test, feature = "testing"))] #[macro_use] pub mod integration_tests; mod memory_store; mod traits; +mod wrapper; pub use matrix_sdk_store_encryption::Error as StoreEncryptionError; +use ruma::time::SystemTime; +use serde::{Deserialize, Serialize}; #[cfg(any(test, feature = "testing"))] pub use self::integration_tests::EventCacheStoreIntegrationTests; pub use self::{ memory_store::MemoryStore, traits::{DynEventCacheStore, EventCacheStore, IntoEventCacheStore}, + wrapper::EventCacheStoreWrapper, }; /// Event cache store specific error type. @@ -83,3 +87,162 @@ impl EventCacheStoreError { /// An `EventCacheStore` specific result type. pub type Result = std::result::Result; + +/// The retention policy for media content used by the `EventCacheStore`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct MediaRetentionPolicy { + /// The maximum authorized size of the overall media cache, in bytes. + /// + /// The cache size is defined as the sum of the sizes of all the (possibly + /// encrypted) media contents in the cache, excluding any metadata + /// associated with them. + /// + /// If this is set and the cache size is bigger than this value, the oldest + /// media contents in the cache will be removed during a cleanup until the + /// cache size is below this threshold. + /// + /// Note that it is possible for the cache size to temporarily exceed this + /// value between two cleanups. + /// + /// Defaults to 400 MiB. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_cache_size: Option, + /// The maximum authorized size of a single media content, in bytes. + /// + /// The size of a media content is the size taken by the content in the + /// database, after it was possibly encrypted, so it might differ from the + /// initial size of the content. + /// + /// The maximum authorized size of a single media content is actually the + /// lowest value between `max_cache_size` and `max_file_size`. + /// + /// If it is set, media content bigger than the maximum size will not be + /// cached. If the maximum size changed after media content that exceeds the + /// new value was cached, the corresponding content will be removed + /// during a cleanup. + /// + /// Defaults to 20 MiB. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_file_size: Option, + /// The duration after which unaccessed media content is considered + /// expired. + /// + /// If this is set, media content whose last access is older than this + /// duration will be removed from the media cache during a cleanup. + /// + /// Defaults to 30 days. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub last_access_expiry: Option, +} + +impl MediaRetentionPolicy { + /// Create a `MediaRetentionPolicy` with the default values. + pub fn new() -> Self { + Self::default() + } + + /// Create an empty `MediaRetentionPolicy`. + /// + /// This means that all media will be cached and cleanups have no effect. + pub fn empty() -> Self { + Self { max_cache_size: None, max_file_size: None, last_access_expiry: None } + } + + /// Set the maximum authorized size of the overall media cache, in bytes. + pub fn with_max_cache_size(mut self, size: Option) -> Self { + self.max_cache_size = size; + self + } + + /// Set the maximum authorized size of a single media content, in bytes. + pub fn with_max_file_size(mut self, size: Option) -> Self { + self.max_file_size = size; + self + } + + /// Set the duration before which unaccessed media content is considered + /// expired. + pub fn with_last_access_expiry(mut self, duration: Option) -> Self { + self.last_access_expiry = duration; + self + } + + /// Whether this policy has limitations. + /// + /// If this policy has no limitations, a cleanup job would have no effect. + /// + /// Returns `true` if at least one limitation is set. + pub fn has_limitations(&self) -> bool { + self.max_cache_size.is_some() + || self.max_file_size.is_some() + || self.last_access_expiry.is_some() + } + + /// Whether the given size exceeds the maximum authorized size of the media + /// cache. + /// + /// # Arguments + /// + /// * `size` - The overall size of the media cache to check, in bytes. + pub fn exceeds_max_cache_size(&self, size: usize) -> bool { + self.max_cache_size.is_some_and(|max_size| size > max_size) + } + + /// The computed maximum authorized size of a single media content, in + /// bytes. + /// + /// This is the lowest value between `max_cache_size` and `max_file_size`. + pub fn computed_max_file_size(&self) -> Option { + match (self.max_cache_size, self.max_file_size) { + (None, None) => None, + (None, Some(size)) => Some(size), + (Some(size), None) => Some(size), + (Some(max_cache_size), Some(max_file_size)) => Some(max_cache_size.min(max_file_size)), + } + } + + /// Whether the given size, in bytes, exceeds the computed maximum + /// authorized size of a single media content. + /// + /// # Arguments + /// + /// * `size` - The size of the media content to check, in bytes. + pub fn exceeds_max_file_size(&self, size: usize) -> bool { + self.computed_max_file_size().is_some_and(|max_size| size > max_size) + } + + /// Whether a content whose last access was at the given time has expired. + /// + /// # Arguments + /// + /// * `current_time` - The current time. + /// + /// * `last_access_time` - The time when the media content to check was last + /// accessed. + pub fn has_content_expired( + &self, + current_time: SystemTime, + last_access_time: SystemTime, + ) -> bool { + self.last_access_expiry.is_some_and(|max_duration| { + current_time + .duration_since(last_access_time) + // If this returns an error, the last access time is newer than the current time. + // This shouldn't happen but in this case the content cannot be expired. + .is_ok_and(|elapsed| elapsed >= max_duration) + }) + } +} + +impl Default for MediaRetentionPolicy { + fn default() -> Self { + Self { + // 400 MiB. + max_cache_size: Some(400 * 1024 * 1024), + // 20 MiB. + max_file_size: Some(20 * 1024 * 1024), + // 30 days. + last_access_expiry: Some(Duration::from_secs(30 * 24 * 60 * 60)), + } + } +} diff --git a/crates/matrix-sdk-base/src/event_cache_store/traits.rs b/crates/matrix-sdk-base/src/event_cache_store/traits.rs index 6c56166177..67384f067d 100644 --- a/crates/matrix-sdk-base/src/event_cache_store/traits.rs +++ b/crates/matrix-sdk-base/src/event_cache_store/traits.rs @@ -16,9 +16,9 @@ use std::{fmt, sync::Arc}; use async_trait::async_trait; use matrix_sdk_common::AsyncTraitDeps; -use ruma::MxcUri; +use ruma::{time::SystemTime, MxcUri}; -use super::EventCacheStoreError; +use super::{EventCacheStoreError, MediaRetentionPolicy}; use crate::media::MediaRequest; /// An abstract trait that can be used to implement different store backends @@ -29,30 +29,54 @@ pub trait EventCacheStore: AsyncTraitDeps { /// The error type used by this event cache store. type Error: fmt::Debug + Into; - /// Add a media file's content in the media store. + /// The retention policy set to cleanup the media cache. + async fn media_retention_policy(&self) -> Result, Self::Error>; + + /// Set the retention policy used to cleanup the media cache. + /// + /// If the store implementation is persistent, this setting should be + /// persisted by the store. + async fn set_media_retention_policy( + &self, + policy: MediaRetentionPolicy, + ) -> Result<(), Self::Error>; + + /// Add a media file's content in the media cache. /// /// # Arguments /// /// * `request` - The `MediaRequest` of the file. /// /// * `content` - The content of the file. + /// + /// * `current_time` - The current time, to be used as the last access time + /// of the media. + /// + /// * `policy` - The media retention policy, to check whether the media is + /// too big to be cached. async fn add_media_content( &self, request: &MediaRequest, content: Vec, + current_time: SystemTime, + policy: MediaRetentionPolicy, ) -> Result<(), Self::Error>; - /// Get a media file's content out of the media store. + /// Get a media file's content out of the media cache. /// /// # Arguments /// /// * `request` - The `MediaRequest` of the file. + /// + /// * `current_time` - The current time, to be used as the last access time + /// of the media. async fn get_media_content( &self, request: &MediaRequest, + current_time: SystemTime, ) -> Result>, Self::Error>; - /// Remove a media file's content from the media store. + /// Remove a media file's content from the media cache. /// /// # Arguments /// @@ -60,12 +84,26 @@ pub trait EventCacheStore: AsyncTraitDeps { async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error>; /// Remove all the media files' content associated to an `MxcUri` from the - /// media store. + /// media cache. /// /// # Arguments /// /// * `uri` - The `MxcUri` of the media files. async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error>; + + /// Clean up the media cache with the given policy. + /// + /// # Arguments + /// + /// * `policy` - The media retention policy to use for the cleanup. The + /// `cleanup_frequency` will be ignored. + /// * `current_time` - The current time, to be used to check for expired + /// content. + async fn clean_up_media_cache( + &self, + policy: MediaRetentionPolicy, + current_time: SystemTime, + ) -> Result<(), Self::Error>; } #[repr(transparent)] @@ -83,19 +121,33 @@ impl fmt::Debug for EraseEventCacheStoreError { impl EventCacheStore for EraseEventCacheStoreError { type Error = EventCacheStoreError; + async fn media_retention_policy(&self) -> Result, Self::Error> { + self.0.media_retention_policy().await.map_err(Into::into) + } + + async fn set_media_retention_policy( + &self, + policy: MediaRetentionPolicy, + ) -> Result<(), Self::Error> { + self.0.set_media_retention_policy(policy).await.map_err(Into::into) + } + async fn add_media_content( &self, request: &MediaRequest, content: Vec, + current_time: SystemTime, + policy: MediaRetentionPolicy, ) -> Result<(), Self::Error> { - self.0.add_media_content(request, content).await.map_err(Into::into) + self.0.add_media_content(request, content, current_time, policy).await.map_err(Into::into) } async fn get_media_content( &self, request: &MediaRequest, + current_time: SystemTime, ) -> Result>, Self::Error> { - self.0.get_media_content(request).await.map_err(Into::into) + self.0.get_media_content(request, current_time).await.map_err(Into::into) } async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> { @@ -105,6 +157,14 @@ impl EventCacheStore for EraseEventCacheStoreError { async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> { self.0.remove_media_content_for_uri(uri).await.map_err(Into::into) } + + async fn clean_up_media_cache( + &self, + policy: MediaRetentionPolicy, + current_time: SystemTime, + ) -> Result<(), Self::Error> { + self.0.clean_up_media_cache(policy, current_time).await.map_err(Into::into) + } } /// A type-erased [`EventCacheStore`]. diff --git a/crates/matrix-sdk-base/src/event_cache_store/wrapper.rs b/crates/matrix-sdk-base/src/event_cache_store/wrapper.rs new file mode 100644 index 0000000000..c6f02328ab --- /dev/null +++ b/crates/matrix-sdk-base/src/event_cache_store/wrapper.rs @@ -0,0 +1,130 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::{Arc, Mutex as StdMutex}; + +use ruma::{time::SystemTime, MxcUri}; +use tokio::sync::Mutex as AsyncMutex; + +use super::{DynEventCacheStore, MediaRetentionPolicy, Result}; +use crate::media::MediaRequest; + +/// A wrapper around an [`EventCacheStore`] implementation to abstract common +/// operations. +/// +/// [`EventCacheStore`]: super::EventCacheStore +#[derive(Debug, Clone)] +pub struct EventCacheStoreWrapper { + /// The inner store implementation. + store: Arc, + /// The media retention policy. + media_retention_policy: Arc>>, + /// Guard to only have a single media cleanup at a time. + media_cleanup_guard: Arc>, +} + +impl EventCacheStoreWrapper { + /// Create a new `EventCacheStoreWrapper` around the given store. + pub(crate) fn new(store: Arc) -> Self { + Self { + store, + media_retention_policy: Default::default(), + media_cleanup_guard: Default::default(), + } + } + + /// The media retention policy. + pub async fn media_retention_policy(&self) -> Result { + if let Some(policy) = *self.media_retention_policy.lock().unwrap() { + return Ok(policy); + } + + let policy = self.store.media_retention_policy().await?.unwrap_or_default(); + *self.media_retention_policy.lock().unwrap() = Some(policy); + + Ok(policy) + } + + /// Set the media retention policy. + pub async fn set_media_retention_policy(&self, policy: MediaRetentionPolicy) -> Result<()> { + self.store.set_media_retention_policy(policy).await?; + + *self.media_retention_policy.lock().unwrap() = Some(policy); + Ok(()) + } + + /// Add a media file's content in the media cache. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + /// * `content` - The content of the file. + pub async fn add_media_content(&self, request: &MediaRequest, content: Vec) -> Result<()> { + let policy = self.media_retention_policy().await?; + + if policy.exceeds_max_file_size(content.len()) { + // The media content should not be cached. + return Ok(()); + } + + // We let the store implementation check the max file size again because the + // size of the content should change if it is encrypted. + self.store.add_media_content(request, content, SystemTime::now(), policy).await + } + + /// Get a media file's content out of the media cache. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + pub async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + self.store.get_media_content(request, SystemTime::now()).await + } + + /// Remove a media file's content from the media cache. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.store.remove_media_content(request).await + } + + /// Remove all the media files' content associated to an `MxcUri` from the + /// media cache. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the media files. + pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + self.store.remove_media_content_for_uri(uri).await + } + + /// Clean up the media cache with the current media retention policy. + /// + /// This is a noop if another cleanup is ongoing. + pub async fn clean_up_media_cache(&self) -> Result<()> { + let Ok(_guard) = self.media_cleanup_guard.try_lock() else { + return Ok(()); + }; + + let policy = self.media_retention_policy().await?; + if !policy.has_limitations() { + // We can safely skip all the checks. + return Ok(()); + } + + self.store.clean_up_media_cache(policy, SystemTime::now()).await + } +} diff --git a/crates/matrix-sdk-sqlite/src/event_cache_store.rs b/crates/matrix-sdk-sqlite/src/event_cache_store.rs index 51afb9e5de..ec68041066 100644 --- a/crates/matrix-sdk-sqlite/src/event_cache_store.rs +++ b/crates/matrix-sdk-sqlite/src/event_cache_store.rs @@ -3,21 +3,28 @@ use std::{borrow::Cow, fmt, path::Path, sync::Arc}; use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime}; use matrix_sdk_base::{ - event_cache_store::EventCacheStore, + event_cache_store::{EventCacheStore, MediaRetentionPolicy}, media::{MediaRequest, UniqueKey}, }; use matrix_sdk_store_encryption::StoreCipher; -use rusqlite::OptionalExtension; +use ruma::time::SystemTime; +use rusqlite::{params_from_iter, OptionalExtension}; use tokio::fs; use tracing::debug; use crate::{ error::{Error, Result}, - utils::{Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt, SqliteKeyValueStoreConnExt}, + utils::{ + repeat_vars, time_to_timestamp, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt, + SqliteKeyValueStoreConnExt, SqliteTransactionExt, + }, OpenStoreError, }; mod keys { + // Entries in Key-value store + pub const MEDIA_RETENTION_POLICY: &str = "media-retention-policy"; + // Tables pub const MEDIA: &str = "media"; } @@ -140,24 +147,62 @@ async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> { impl EventCacheStore for SqliteEventCacheStore { type Error = Error; - async fn add_media_content(&self, request: &MediaRequest, content: Vec) -> Result<()> { + async fn media_retention_policy(&self) -> Result, Self::Error> { + let conn = self.acquire().await?; + let Some(bytes) = conn.get_kv(keys::MEDIA_RETENTION_POLICY).await? else { + return Ok(None); + }; + + Ok(Some(rmp_serde::from_slice(&bytes)?)) + } + + async fn set_media_retention_policy( + &self, + policy: MediaRetentionPolicy, + ) -> Result<(), Self::Error> { + let conn = self.acquire().await?; + let serialized_policy = rmp_serde::to_vec_named(&policy)?; + + conn.set_kv(keys::MEDIA_RETENTION_POLICY, serialized_policy).await?; + Ok(()) + } + + async fn add_media_content( + &self, + request: &MediaRequest, + content: Vec, + current_time: SystemTime, + policy: MediaRetentionPolicy, + ) -> Result<()> { + let data = self.encode_value(content)?; + + if policy.exceeds_max_file_size(data.len()) { + // The content is too big to be cached. + return Ok(()); + } + let uri = self.encode_key(keys::MEDIA, request.source.unique_key()); let format = self.encode_key(keys::MEDIA, request.format.unique_key()); - let data = self.encode_value(content)?; + let timestamp = time_to_timestamp(current_time); let conn = self.acquire().await?; conn.execute( - "INSERT OR REPLACE INTO media (uri, format, data, last_access) VALUES (?, ?, ?, CAST(strftime('%s') as INT))", - (uri, format, data), + "INSERT OR REPLACE INTO media (uri, format, data, last_access) VALUES (?, ?, ?, ?)", + (uri, format, data, timestamp), ) .await?; Ok(()) } - async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + async fn get_media_content( + &self, + request: &MediaRequest, + current_time: SystemTime, + ) -> Result>> { let uri = self.encode_key(keys::MEDIA, request.source.unique_key()); let format = self.encode_key(keys::MEDIA, request.format.unique_key()); + let timestamp = time_to_timestamp(current_time); let conn = self.acquire().await?; let data = conn @@ -166,9 +211,8 @@ impl EventCacheStore for SqliteEventCacheStore { // We need to do this first so the transaction is in write mode right away. // See: https://sqlite.org/lang_transaction.html#read_transactions_versus_write_transactions txn.execute( - "UPDATE media SET last_access = CAST(strftime('%s') as INT) \ - WHERE uri = ? AND format = ?", - (&uri, &format), + "UPDATE media SET last_access = ? WHERE uri = ? AND format = ?", + (timestamp, &uri, &format), )?; txn.query_row::, _, _>( @@ -201,6 +245,106 @@ impl EventCacheStore for SqliteEventCacheStore { Ok(()) } + + async fn clean_up_media_cache( + &self, + policy: MediaRetentionPolicy, + current_time: SystemTime, + ) -> Result<(), Self::Error> { + if !policy.has_limitations() { + // We can safely skip all the checks. + return Ok(()); + } + + let conn = self.acquire().await?; + conn.with_transaction::<_, Error, _>(move |txn| { + // First, check media content that exceed the max filesize. + if let Some(max_file_size) = policy.computed_max_file_size() { + txn.execute("DELETE FROM media WHERE length(data) > ?", (max_file_size,))?; + } + + // Then, clean up expired media content. + if let Some(last_access_expiry) = policy.last_access_expiry { + let current_timestamp = time_to_timestamp(current_time); + let expiry_secs = last_access_expiry.as_secs(); + txn.execute( + "DELETE FROM media WHERE (? - last_access) >= ?", + (current_timestamp, expiry_secs), + )?; + } + + // Finally, if the cache size is too big, remove old items until it fits. + if let Some(max_cache_size) = policy.max_cache_size { + // i64 is the integer type used by SQLite, use it here to avoid usize overflow + // during the conversion of the result. + let cache_size_int = txn + .query_row("SELECT sum(length(data)) FROM media", (), |row| { + // `sum()` returns `NULL` if there are no rows. + row.get::<_, Option>(0) + })? + .unwrap_or_default(); + let cache_size_usize = usize::try_from(cache_size_int); + + // If the cache size is overflowing or bigger than max cache size, clean up. + if cache_size_usize.is_err() + || cache_size_usize.is_ok_and(|cache_size| cache_size > max_cache_size) + { + // Get the sizes of the media contents ordered by last access. + let mut cached_stmt = txn.prepare_cached( + "SELECT rowid, length(data) FROM media ORDER BY last_access DESC", + )?; + let content_sizes = cached_stmt + .query(())? + .mapped(|row| Ok((row.get::<_, i64>(0)?, row.get::<_, usize>(1)?))); + + let mut accumulated_items_size = 0usize; + let mut limit_reached = false; + let mut rows_to_remove = Vec::new(); + + for result in content_sizes { + let (row_id, size) = match result { + Ok(content_size) => content_size, + Err(error) => { + return Err(error.into()); + } + }; + + if limit_reached { + rows_to_remove.push(row_id); + continue; + } + + match accumulated_items_size.checked_add(size) { + Some(acc) if acc > max_cache_size => { + // We can stop accumulating. + limit_reached = true; + rows_to_remove.push(row_id); + } + Some(acc) => accumulated_items_size = acc, + None => { + // The accumulated size is overflowing but the setting cannot be + // bigger than usize::MAX, we can stop accumulating. + limit_reached = true; + rows_to_remove.push(row_id); + } + }; + } + + txn.chunk_large_query_over(rows_to_remove, None, |txn, row_ids| { + let sql_params = repeat_vars(row_ids.len()); + let query = format!("DELETE FROM media WHERE rowid IN ({sql_params})"); + txn.prepare(&query)?.execute(params_from_iter(row_ids))?; + Ok(Vec::<()>::new()) + })?; + } + } + + Ok(()) + }) + .await?; + + Ok(()) + } } #[cfg(test)] @@ -211,13 +355,13 @@ mod tests { }; use matrix_sdk_base::{ - event_cache_store::{EventCacheStore, EventCacheStoreError}, + event_cache_store::{EventCacheStore, EventCacheStoreError, MediaRetentionPolicy}, event_cache_store_integration_tests, media::{MediaFormat, MediaRequest, MediaThumbnailSettings}, }; use matrix_sdk_test::async_test; use once_cell::sync::Lazy; - use ruma::{events::room::MediaSource, media::Method, mxc_uri, uint}; + use ruma::{events::room::MediaSource, media::Method, mxc_uri, time::SystemTime, uint}; use tempfile::{tempdir, TempDir}; use super::SqliteEventCacheStore; @@ -235,7 +379,7 @@ mod tests { Ok(SqliteEventCacheStore::open(tmpdir_path.to_str().unwrap(), None).await.unwrap()) } - event_cache_store_integration_tests!(); + event_cache_store_integration_tests!(with_media_size_tests); async fn get_event_cache_store_content_sorted_by_last_access( event_cache_store: &SqliteEventCacheStore, @@ -266,19 +410,19 @@ mod tests { let content: Vec = "hello world".into(); let thumbnail_content: Vec = "hello…".into(); + let policy = MediaRetentionPolicy::empty(); // Add the media. + let mut time = SystemTime::UNIX_EPOCH; event_cache_store - .add_media_content(&file_request, content.clone()) + .add_media_content(&file_request, content.clone(), time, policy) .await .expect("adding file failed"); - // Since the precision of the timestamp is in seconds, wait so the timestamps - // differ. - tokio::time::sleep(Duration::from_secs(3)).await; - + // Add the thumbnail 3 seconds later. + time = time.checked_add(Duration::from_secs(3)).expect("time should be fine"); event_cache_store - .add_media_content(&thumbnail_request, thumbnail_content.clone()) + .add_media_content(&thumbnail_request, thumbnail_content.clone(), time, policy) .await .expect("adding thumbnail failed"); @@ -290,13 +434,10 @@ mod tests { assert_eq!(contents[0], thumbnail_content, "thumbnail is not last access"); assert_eq!(contents[1], content, "file is not second-to-last access"); - // Since the precision of the timestamp is in seconds, wait so the timestamps - // differ. - tokio::time::sleep(Duration::from_secs(3)).await; - - // Access the file so its last access is more recent. + // Access the file 1 hour later so its last access is more recent. + time = time.checked_add(Duration::from_secs(3600)).expect("time should be fine"); let _ = event_cache_store - .get_media_content(&file_request) + .get_media_content(&file_request, time) .await .expect("getting file failed") .expect("file is missing"); diff --git a/crates/matrix-sdk-sqlite/src/utils.rs b/crates/matrix-sdk-sqlite/src/utils.rs index 1ec9339e7d..f0ac394745 100644 --- a/crates/matrix-sdk-sqlite/src/utils.rs +++ b/crates/matrix-sdk-sqlite/src/utils.rs @@ -19,6 +19,7 @@ use async_trait::async_trait; use deadpool_sqlite::Object as SqliteAsyncConn; use itertools::Itertools; use matrix_sdk_store_encryption::StoreCipher; +use ruma::time::SystemTime; use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction}; use crate::{ @@ -187,27 +188,27 @@ impl SqliteAsyncConnExt for SqliteAsyncConn { } pub(crate) trait SqliteTransactionExt { - fn chunk_large_query_over( + fn chunk_large_query_over( &self, - keys_to_chunk: Vec, + keys_to_chunk: Vec, result_capacity: Option, do_query: Query, ) -> Result> where Res: Send + 'static, - Query: Fn(&Transaction<'_>, Vec) -> Result> + Send + 'static; + Query: Fn(&Transaction<'_>, Vec) -> Result> + Send + 'static; } impl<'a> SqliteTransactionExt for Transaction<'a> { - fn chunk_large_query_over( + fn chunk_large_query_over( &self, - mut keys_to_chunk: Vec, + mut keys_to_chunk: Vec, result_capacity: Option, do_query: Query, ) -> Result> where Res: Send + 'static, - Query: Fn(&Transaction<'_>, Vec) -> Result> + Send + 'static, + Query: Fn(&Transaction<'_>, Vec) -> Result> + Send + 'static, { // Divide by 2 to allow space for more static parameters (not part of // `keys_to_chunk`). @@ -362,6 +363,19 @@ pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display { iter::repeat("?").take(count).format(",") } +/// Convert the given `SystemTime` to a timestamp, as the number of seconds +/// since Unix Epoch. +/// +/// Returns an `i64` as it is the numeric type used by SQLite. +pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 { + time.duration_since(SystemTime::UNIX_EPOCH) + .ok() + .and_then(|d| d.as_secs().try_into().ok()) + // It is unlikely to happen unless the time on the system is seriously wrong, but we always + // need a value. + .unwrap_or(0) +} + #[cfg(test)] mod unit_tests { use super::*; diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 729d5f73ae..46572e0b28 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -33,7 +33,7 @@ use imbl::Vector; #[cfg(feature = "e2e-encryption")] use matrix_sdk_base::crypto::store::LockableCryptoStore; use matrix_sdk_base::{ - event_cache_store::DynEventCacheStore, + event_cache_store::EventCacheStoreWrapper, store::{DynStateStore, ServerCapabilities}, sync::{Notification, RoomUpdates}, BaseClient, RoomInfoNotableUpdate, RoomState, RoomStateFilter, SendOutsideWasm, SessionMeta, @@ -584,7 +584,7 @@ impl Client { } /// Get a reference to the event cache store. - pub(crate) fn event_cache_store(&self) -> &DynEventCacheStore { + pub fn event_cache_store(&self) -> &EventCacheStoreWrapper { self.base_client().event_cache_store() }