From 20937b96fe7ede92f313ae8bd921cd947b5fd2c0 Mon Sep 17 00:00:00 2001 From: shouya <526598+shouya@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:35:12 +0900 Subject: [PATCH] feat(client): Caching requests to servers for feed (#12) This PR implements a client that caches the get requests to the servers when fetching feeds. It is used in places where `reqwest::Client` is used. Currently there are two places: - the client used to fetch the source (default ttl: 15 minutes) - the full_text filter (default ttl: 12 hours) This mechanism is expected to help in speeding up the fetching of feeds especially for repeated full-text requests. I'm also planning to employ this cache tests to insert custom fixtures for certain web addresses for future feature tests. --------- Co-authored-by: Shou Ya --- Cargo.lock | 21 +++++ Cargo.toml | 2 + src/client.rs | 103 +++++++++++++++++++++++- src/client/cache.rs | 151 ++++++++++++++++++++++++++++++++++++ src/filter/full_text.rs | 26 +++---- src/filter/simplify_html.rs | 3 +- src/server/endpoint.rs | 24 +++--- src/util.rs | 4 + 8 files changed, 305 insertions(+), 29 deletions(-) create mode 100644 src/client/cache.rs diff --git a/Cargo.lock b/Cargo.lock index fd89383..aa8816b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "anstream" version = "0.6.5" @@ -776,6 +782,10 @@ name = "hashbrown" version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -1093,6 +1103,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "lru" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2c024b41519440580066ba82aab04092b333e09066a5eb86c7c4890df31f22" +dependencies = [ + "hashbrown 0.14.3", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -1875,12 +1894,14 @@ dependencies = [ "duration-str", "ego-tree", "either", + "encoding_rs", "futures", "html5ever", "htmlescape", "http 1.0.0", "itertools", "lazy_static", + "lru", "mime", "paste", "readability", diff --git a/Cargo.toml b/Cargo.toml index ba2a41c..66b6d32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,8 @@ either = "1.9.0" # used for returning sum types from the JS runtime # Web client (blocking and async both used, blocking used in the JS runtime) # TODO: upgrade reqwest after its hyper 1.0 upgrade reqwest = { version = "0.11.23", default-features = false, features = ["blocking", "rustls-tls", "trust-dns"] } +encoding_rs = "0.8.33" +lru = "0.12.2" # Used in sanitize filter to remove/replace text contents regex = "1.10.2" diff --git a/src/client.rs b/src/client.rs index cb52c32..e01601f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,24 @@ +mod cache; + use std::time::Duration; use reqwest::header::HeaderMap; use serde::{Deserialize, Serialize}; +use url::Url; use crate::util::Result; +use self::cache::{Response, ResponseCache}; + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ClientConfig { user_agent: Option, accept: Option, set_cookie: Option, referer: Option, + cache_size: Option, + #[serde(deserialize_with = "duration_str::deserialize_option_duration")] + cache_ttl: Option, #[serde(default = "default_timeout")] #[serde(deserialize_with = "duration_str::deserialize_duration")] timeout: Duration, @@ -24,6 +32,8 @@ impl Default for ClientConfig { set_cookie: None, referer: None, timeout: default_timeout(), + cache_size: None, + cache_ttl: None, } } } @@ -67,11 +77,100 @@ impl ClientConfig { builder } - pub fn build(&self) -> Result { - Ok(self.to_builder().build()?) + pub fn build(&self, default_cache_ttl: Duration) -> Result { + let reqwest_client = self.to_builder().build()?; + Ok(Client::new( + self.cache_size.unwrap_or(0), + self.cache_ttl.unwrap_or(default_cache_ttl), + reqwest_client, + )) + } +} + +pub struct Client { + cache: ResponseCache, + client: reqwest::Client, +} + +impl Client { + fn new( + cache_size: usize, + cache_ttl: Duration, + client: reqwest::Client, + ) -> Self { + Self { + cache: ResponseCache::new(cache_size, cache_ttl), + client, + } + } + + pub async fn get(&self, url: &Url) -> Result { + self.get_with(url, |req| req).await + } + + pub async fn get_with( + &self, + url: &Url, + f: impl FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder, + ) -> Result { + if let Some(resp) = self.cache.get_cached(url) { + return Ok(resp); + } + + let resp = f(self.client.get(url.clone())).send().await?; + let resp = Response::from_reqwest_resp(resp).await?; + self.cache.insert(url.clone(), resp.clone()); + Ok(resp) + } + + #[cfg(test)] + pub fn insert(&self, url: Url, response: Response) { + self.cache.insert(url, response); } } fn default_timeout() -> Duration { Duration::from_secs(10) } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_client_cache() { + let client = Client::new(1, Duration::from_secs(1), reqwest::Client::new()); + let url = Url::parse("http://example.com").unwrap(); + let body: Box = "foo".into(); + let response = Response::new( + url.clone(), + reqwest::StatusCode::OK, + HeaderMap::new(), + body.into(), + ); + + client.insert(url.clone(), response.clone()); + let actual = client.get(&url).await.unwrap(); + let expected = response; + + assert_eq!(actual.url(), expected.url()); + assert_eq!(actual.status(), expected.status()); + assert_eq!(actual.headers(), expected.headers()); + assert_eq!(actual.body(), expected.body()); + } + + const YT_SCISHOW_FEED_URL: &str = "https://www.youtube.com/feeds/videos.xml?channel_id=UCZYTClx2T1of7BRZ86-8fow"; + + #[tokio::test] + async fn test_client() { + let client = Client::new(0, Duration::from_secs(1), reqwest::Client::new()); + let url = Url::parse(YT_SCISHOW_FEED_URL).unwrap(); + let resp = client.get(&url).await.unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + assert_eq!( + resp.content_type().unwrap().to_string(), + "text/xml; charset=utf-8" + ); + assert!(resp.text().unwrap().contains("SciShow")); + } +} diff --git a/src/client/cache.rs b/src/client/cache.rs new file mode 100644 index 0000000..901109c --- /dev/null +++ b/src/client/cache.rs @@ -0,0 +1,151 @@ +use std::{ + num::NonZeroUsize, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; + +use lru::LruCache; +use mime::Mime; +use reqwest::header::HeaderMap; +use url::Url; + +use crate::util::{Error, Result}; + +struct Timed { + value: T, + created: Instant, +} + +pub struct ResponseCache { + map: RwLock>>, + timeout: Duration, +} + +impl ResponseCache { + pub fn new(max_entries: usize, timeout: Duration) -> Self { + let max_entries = max_entries.try_into().unwrap_or(NonZeroUsize::MIN); + Self { + map: RwLock::new(LruCache::new(max_entries)), + timeout, + } + } + + pub fn get_cached(&self, url: &Url) -> Option { + let mut map = self.map.write().ok()?; + let Some(entry) = map.get(url) else { + return None; + }; + if entry.created.elapsed() > self.timeout { + map.pop(url); + return None; + } + Some(entry.value.clone()) + } + + pub fn insert(&self, url: Url, response: Response) -> Option<()> { + let timed = Timed { + value: response, + created: Instant::now(), + }; + self.map.write().ok()?.push(url, timed); + Some(()) + } +} + +#[derive(Clone)] +pub struct Response { + inner: Arc, +} + +struct InnerResponse { + url: Url, + status: reqwest::StatusCode, + headers: HeaderMap, + body: Box<[u8]>, +} + +impl Response { + pub async fn from_reqwest_resp(resp: reqwest::Response) -> Result { + let status = resp.status(); + let headers = resp.headers().clone(); + let url = resp.url().clone(); + let body = resp.bytes().await?.to_vec().into_boxed_slice(); + let resp = InnerResponse { + url, + status, + headers, + body, + }; + + Ok(Self { + inner: Arc::new(resp), + }) + } + + #[cfg(test)] + pub fn new( + url: Url, + status: reqwest::StatusCode, + headers: HeaderMap, + body: Box<[u8]>, + ) -> Self { + Self { + inner: Arc::new(InnerResponse { + url, + status, + headers, + body, + }), + } + } + + pub fn error_for_status(self) -> Result { + let status = self.inner.status; + if status.is_client_error() || status.is_server_error() { + return Err(Error::HttpStatus(status, self.inner.url.clone())); + } + + Ok(self) + } + + pub fn header(&self, name: &str) -> Option<&str> { + self.inner.headers.get(name).and_then(|v| v.to_str().ok()) + } + + pub fn text_with_charset(&self, default_encoding: &str) -> Result { + let content_type = self.content_type(); + let encoding_name = content_type + .as_ref() + .and_then(|mime| { + mime.get_param("charset").map(|charset| charset.as_str()) + }) + .unwrap_or(default_encoding); + let encoding = encoding_rs::Encoding::for_label(encoding_name.as_bytes()) + .unwrap_or(encoding_rs::UTF_8); + + let full = &self.inner.body; + let (text, _, _) = encoding.decode(full); + Ok(text.into_owned()) + } + + pub fn text(&self) -> Result { + self.text_with_charset("utf-8") + } + + pub fn content_type(&self) -> Option { + self.header("content-type").and_then(|v| v.parse().ok()) + } + + pub fn url(&self) -> &Url { + &self.inner.url + } + pub fn status(&self) -> reqwest::StatusCode { + self.inner.status + } + pub fn headers(&self) -> &HeaderMap { + &self.inner.headers + } + pub fn body(&self) -> &[u8] { + &self.inner.body + } +} diff --git a/src/filter/full_text.rs b/src/filter/full_text.rs index 137b958..df06a7a 100644 --- a/src/filter/full_text.rs +++ b/src/filter/full_text.rs @@ -1,8 +1,11 @@ +use std::time::Duration; + use futures::{stream, StreamExt}; -use mime::Mime; + use serde::{Deserialize, Serialize}; +use url::Url; -use crate::client; +use crate::client::{self, Client}; use crate::feed::{Feed, Post}; use crate::html::convert_relative_url; use crate::util::{Error, Result}; @@ -23,7 +26,7 @@ pub struct FullTextConfig { } pub struct FullTextFilter { - client: reqwest::Client, + client: Client, parallelism: usize, append_mode: bool, keep_element: Option, @@ -36,7 +39,9 @@ impl FeedFilterConfig for FullTextConfig { type Filter = FullTextFilter; async fn build(self) -> Result { - let client = self.client.unwrap_or_default().build()?; + // default cache ttl is 12 hours + let default_cache_ttl = Duration::from_secs(12 * 60 * 60); + let client = self.client.unwrap_or_default().build(default_cache_ttl)?; let parallelism = self.parallelism.unwrap_or(DEFAULT_PARALLELISM); let append_mode = self.append_mode.unwrap_or(false); let simplify = self.simplify.unwrap_or(false); @@ -59,14 +64,9 @@ impl FeedFilterConfig for FullTextConfig { impl FullTextFilter { async fn fetch_html(&self, url: &str) -> Result { - let resp = self.client.get(url).send().await?; - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or("text/html") - .parse::() - .map_err(|_| Error::Message("invalid content_type".to_string()))?; + let url = Url::parse(url)?; + let resp = self.client.get(&url).await?; + let content_type = resp.content_type().unwrap_or(mime::TEXT_HTML); if content_type.essence_str() != "text/html" { return Err(Error::Message(format!( @@ -76,7 +76,7 @@ impl FullTextFilter { } let resp = resp.error_for_status()?; - let text = resp.text().await?; + let text = resp.text()?; Ok(text) } diff --git a/src/filter/simplify_html.rs b/src/filter/simplify_html.rs index 901cd3c..c769434 100644 --- a/src/filter/simplify_html.rs +++ b/src/filter/simplify_html.rs @@ -1,5 +1,6 @@ use readability::extractor::extract; use serde::{Deserialize, Serialize}; +use url::Url; use crate::feed::Feed; use crate::util::Result; @@ -40,7 +41,7 @@ impl FeedFilter for SimplifyHtmlFilter { } pub(super) fn simplify(text: &str, url: &str) -> Option { - let url = reqwest::Url::parse(url).ok()?; + let url = Url::parse(url).ok()?; let mut text = std::io::Cursor::new(text); let product = extract(&mut text, &url).ok()?; Some(product.content) diff --git a/src/server/endpoint.rs b/src/server/endpoint.rs index bab0454..29fc42d 100644 --- a/src/server/endpoint.rs +++ b/src/server/endpoint.rs @@ -2,6 +2,7 @@ use std::convert::Infallible; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use axum::body::Body; use axum::response::IntoResponse; @@ -11,7 +12,7 @@ use serde::{Deserialize, Serialize}; use tower::Service; use url::Url; -use crate::client::ClientConfig; +use crate::client::{Client, ClientConfig}; use crate::feed::Feed; use crate::filter::{BoxedFilter, FeedFilter, FilterConfig}; use crate::util::{Error, Result}; @@ -52,7 +53,7 @@ pub struct EndpointService { source: Option, content_type: Option, filters: Arc>, - client: Arc, + client: Arc, } #[derive(Clone, Default)] @@ -212,7 +213,8 @@ impl EndpointService { filters.push(filter); } - let client = config.client.unwrap_or_default().build()?; + let default_cache_ttl = Duration::from_secs(15 * 60); + let client = config.client.unwrap_or_default().build(default_cache_ttl)?; let source = match config.source { Some(source) => Some(Url::parse(&source)?), None => None, @@ -267,24 +269,20 @@ impl EndpointService { async fn fetch_feed(&self, source: &Url) -> Result { let resp = self .client - .get(source.to_string()) - .header("Accept", "text/html,application/xml") - .send() + .get_with(source, |builder| { + builder.header("Accept", "text/html,application/xml") + }) .await? .error_for_status()?; - let resp_content_type = resp - .headers() - .get("content-type") - .and_then(|x| x.to_str().ok()) - .and_then(|x| x.parse::().ok()) - .map(|x| x.essence_str().to_owned()); + let resp_content_type = + resp.content_type().map(|x| x.essence_str().to_owned()); let content_type = self .content_type .as_deref() .or(resp_content_type.as_deref()); - let content = resp.text().await?; + let content = resp.text()?; let feed = match content_type { Some("text/html") => Feed::from_html_content(&content, source)?, diff --git a/src/util.rs b/src/util.rs index 23bfe2e..893b866 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use url::Url; pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -47,6 +48,9 @@ pub enum Error { #[error("Reqwest client error {0:?}")] Reqwest(#[from] reqwest::Error), + #[error("HTTP status error {0:?} (url: {1})")] + HttpStatus(reqwest::StatusCode, Url), + #[error("Js execution error {0:?}")] Js(#[from] rquickjs::Error),