diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs
index 5a4c34b79..776d5abb0 100644
--- a/ntex-io/src/dispatcher.rs
+++ b/ntex-io/src/dispatcher.rs
@@ -537,19 +537,25 @@ where
         // update read timer
         if let Some((_, max, rate)) = self.cfg.frame_read_rate() {
             let bytes = decoded.remains as u32;
+            let delta = if bytes > self.read_bytes {
+                (bytes - self.read_bytes).try_into().unwrap_or(u16::MAX)
+            } else {
+                bytes.try_into().unwrap_or(u16::MAX)
+            };
-            let delta = (bytes - self.read_bytes).try_into().unwrap_or(u16::MAX);
-
+            // read rate higher than min rate
             if delta >= rate {
                 let n = now();
                 let next = self.shared.io.timer_deadline() + ONE_SEC;
                 let new_timeout = if n >= next { ONE_SEC } else { next - n };
 
-                // max timeout
+                // extend timeout
                 if max.is_zero() || (n + new_timeout) <= self.read_max_timeout {
-                    self.read_bytes = bytes;
                     self.shared.io.stop_timer();
                     self.shared.io.start_timer(new_timeout);
+
+                    // store current buf size for future rate calculation
+                    self.read_bytes = bytes;
                 }
             }
         }
diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs
index cd95820fb..c1be89cac 100644
--- a/ntex-io/src/io.rs
+++ b/ntex-io/src/io.rs
@@ -552,7 +552,9 @@ impl<F> Io<F> {
         } else {
             match self.poll_read_ready(cx) {
                 Poll::Pending | Poll::Ready(Ok(Some(()))) => {
-                    log::debug!("not enough data to decode next frame");
+                    if log::log_enabled!(log::Level::Debug) && decoded.remains != 0 {
+                        log::debug!("not enough data to decode next frame");
+                    }
                     Ok(decoded)
                 }
                 Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md
index f7bc9d49d..712fa3cf3 100644
--- a/ntex/CHANGES.md
+++ b/ntex/CHANGES.md
@@ -1,9 +1,11 @@
 # Changes
 
-## [0.7.11] - 2023-11-xx
+## [0.7.11] - 2023-11-20
 
 * Refactor http/1 timeouts
 
+* Add http/1 payload read timeout
+
 ## [0.7.10] - 2023-11-12
 
 * Start http client timeout after sending body
diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs
index af45e4a92..b056d8c47 100644
--- a/ntex/src/http/config.rs
+++ b/ntex/src/http/config.rs
@@ -275,10 +275,6 @@ impl<S, X, U> DispatcherConfig<S, X, U> {
     pub(super) fn headers_read_rate(&self) -> Option<&ReadRate> {
         self.headers_read_rate.as_ref()
     }
-
-    pub(super) fn payload_read_rate(&self) -> Option<&ReadRate> {
-        self.payload_read_rate.as_ref()
-    }
 }
 
 const DATE_VALUE_LENGTH_HDR: usize = 39;
diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs
index 8e309fd31..bba183dd4 100644
--- a/ntex/src/http/h1/dispatcher.rs
+++ b/ntex/src/http/h1/dispatcher.rs
@@ -4,7 +4,7 @@ use std::{
     cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc, time,
 };
 
-use crate::io::{Filter, Io, IoBoxed, IoRef, IoStatusUpdate, RecvError};
+use crate::io::{Decoded, Filter, Io, IoBoxed, IoRef, IoStatusUpdate, RecvError};
 use crate::service::{Pipeline, PipelineCall, Service};
 use crate::time::now;
 use crate::util::{ready, Bytes};
@@ -95,7 +95,8 @@ struct DispatcherInner<F, S, B, X, U> {
     config: Rc<DispatcherConfig<S, X, U>>,
     error: Option<DispatchError>,
     payload: Option<(PayloadDecoder, PayloadSender)>,
-    read_bytes: u32,
+    read_remains: u32,
+    read_consumed: u32,
     read_max_timeout: time::Instant,
     _t: marker::PhantomData<(S, B)>,
 }
@@ -134,7 +135,8 @@ where
                 config,
                 error: None,
                 payload: None,
-                read_bytes: 0,
+                read_remains: 0,
+                read_consumed: 0,
                 read_max_timeout: now(),
                 _t: marker::PhantomData,
             },
@@ -346,7 +348,7 @@ where
                     *this.st = State::Stop;
                 } else {
                     if let Poll::Ready(Err(err)) =
-                        _poll_request_payload(&io.0, &mut this.inner.payload, cx)
+                        this.inner._poll_request_payload(Some(&io.0), cx)
                     {
                         this.inner.error = Some(err);
                     }
@@ -501,7 +503,7 @@ where
         let result = match self.io.poll_recv_decode(&self.codec, cx) {
             Ok(decoded) => {
                 if let Some(st) =
-                    self.update_timer(decoded.item.is_some(), decoded.remains)
+                    self.update_request_timer(decoded.item.is_some(), decoded.remains)
                 {
                     return Poll::Ready(st);
                 }
@@ -683,7 +685,130 @@ where
         &mut self,
         cx: &mut Context<'_>,
     ) -> Poll<Result<(), DispatchError>> {
-        _poll_request_payload(&self.io, &mut self.payload, cx)
+        self._poll_request_payload::<F>(None, cx)
+    }
+
+    fn set_payload_error(&mut self, err: PayloadError) {
+        if let Some(ref mut payload) = self.payload {
+            payload.1.set_error(err);
+            self.payload = None;
+        }
+    }
+
+    /// Process request's payload
+    fn _poll_request_payload<Flt: Filter>(
+        &mut self,
+        io: Option<&Io<Flt>>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), DispatchError>> {
+        // check if payload data is required
+        if self.payload.is_none() {
+            return Poll::Ready(Ok(()));
+        };
+
+        match self.payload.as_mut().unwrap().1.poll_data_required(cx) {
+            PayloadStatus::Read => {
+                // read request payload
+                let mut updated = false;
+                loop {
+                    let recv_result = io
+                        .map(|io| {
+                            io.poll_recv_decode(&self.payload.as_ref().unwrap().0, cx)
+                        })
+                        .unwrap_or_else(|| {
+                            self.io
+                                .poll_recv_decode(&self.payload.as_ref().unwrap().0, cx)
+                        });
+
+                    let res = match recv_result {
+                        Ok(decoded) => {
+                            self.update_payload_timer(&decoded);
+                            if let Some(item) = decoded.item {
+                                updated = true;
+                                Ok(item)
+                            } else {
+                                break;
+                            }
+                        }
+                        Err(err) => Err(err),
+                    };
+
+                    match res {
+                        Ok(PayloadItem::Chunk(chunk)) => {
+                            self.payload.as_mut().unwrap().1.feed_data(chunk);
+                        }
+                        Ok(PayloadItem::Eof) => {
+                            self.payload.as_mut().unwrap().1.feed_eof();
+                            self.payload = None;
+                            break;
+                        }
+                        Err(err) => {
+                            let err = match err {
+                                RecvError::WriteBackpressure => {
+                                    let flush_result = io
+                                        .map(|io| io.poll_flush(cx, false))
+                                        .unwrap_or_else(|| self.io.poll_flush(cx, false));
+
+                                    if flush_result?.is_pending() {
+                                        break;
+                                    } else {
+                                        continue;
+                                    }
+                                }
+                                RecvError::KeepAlive => {
+                                    if let Err(err) = self.handle_payload_timeout() {
+                                        DispatchError::from(err)
+                                    } else {
+                                        continue;
+                                    }
+                                }
+                                RecvError::Stop => {
+                                    self.set_payload_error(PayloadError::EncodingCorrupted);
+                                    io::Error::new(
+                                        io::ErrorKind::Other,
+                                        "Dispatcher stopped",
+                                    )
+                                    .into()
+                                }
+                                RecvError::PeerGone(err) => {
+                                    self.set_payload_error(PayloadError::EncodingCorrupted);
+                                    if let Some(err) = err {
+                                        DispatchError::PeerGone(Some(err))
+                                    } else {
+                                        ParseError::Incomplete.into()
+                                    }
+                                }
+                                RecvError::Decoder(e) => {
+                                    self.set_payload_error(PayloadError::EncodingCorrupted);
+                                    DispatchError::Parse(e)
+                                }
+                            };
+                            return Poll::Ready(Err(err));
+                        }
+                    }
+                }
+                if updated {
+                    Poll::Ready(Ok(()))
+                } else {
+                    Poll::Pending
+                }
+            }
+            PayloadStatus::Pause => {
+                // stop payload timer
+                if self.flags.contains(Flags::READ_PL_TIMEOUT) {
+                    self.flags.remove(Flags::READ_PL_TIMEOUT);
+                    self.io.stop_timer();
+                }
+                Poll::Pending
+            }
+            PayloadStatus::Dropped => {
+                // service call is not interested in payload
+                // wait until future completes and then close
+                // connection
+                self.payload = None;
+                Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
+            }
+        }
     }
 
     /// check for io changes, could close while waiting for service call
@@ -699,7 +824,55 @@ where
         }
     }
 
-    fn update_timer(&mut self, received: bool, remains: usize) -> Option<State<B>> {
+    fn handle_payload_timeout(&mut self) -> Result<(), io::Error> {
+        // check payload read rate
+        if self.flags.contains(Flags::READ_PL_TIMEOUT) {
+            if let Some(ref cfg) = self.config.payload_read_rate {
+                let total = (self.read_remains + self.read_consumed)
+                    .try_into()
+                    .unwrap_or(u16::MAX);
+                if total > cfg.rate {
+                    self.read_consumed = 0;
+
+                    // start timer for next period
+                    if cfg.max_timeout.is_zero()
+                        || (!cfg.max_timeout.is_zero() && now() < self.read_max_timeout)
+                    {
+                        log::trace!("Payload read rate {:?}, extend timer", total);
+                        self.io.start_timer(cfg.timeout);
+                        return Ok(());
+                    }
+                    log::trace!("Max payload timeout has been reached");
+                }
+            }
+        }
+
+        log::trace!("Timeout during payload reading");
+        self.set_payload_error(PayloadError::Io(io::Error::new(
+            io::ErrorKind::TimedOut,
+            "Keep-alive",
+        )));
+        Err(io::Error::new(io::ErrorKind::TimedOut, "Keep-alive"))
+    }
+
+    fn update_payload_timer(&mut self, decoded: &Decoded<PayloadItem>) {
+        if self.flags.contains(Flags::READ_PL_TIMEOUT) {
+            self.read_remains = decoded.remains as u32;
+            self.read_consumed += decoded.consumed as u32;
+        } else if let Some(ref cfg) = self.config.payload_read_rate {
+            // start payload timer
+            self.flags.insert(Flags::READ_PL_TIMEOUT);
+
+            self.read_remains = decoded.remains as u32;
+            self.read_consumed = decoded.consumed as u32;
+            self.io.start_timer(cfg.timeout);
+            if !cfg.max_timeout.is_zero() {
+                self.read_max_timeout = now() + cfg.max_timeout;
+            }
+        }
+    }
+
+    fn update_request_timer(&mut self, received: bool, remains: usize) -> Option<State<B>> {
         // we got parsed frame
         if received {
             // remove all timers
@@ -710,8 +883,13 @@ where
         // update read timer
         if let Some(ref cfg) = self.config.headers_read_rate {
             let bytes = remains as u32;
-            let delta = (bytes - self.read_bytes).try_into().unwrap_or(u16::MAX);
+            let delta = if bytes > self.read_remains {
+                (bytes - self.read_remains).try_into().unwrap_or(u16::MAX)
+            } else {
+                bytes.try_into().unwrap_or(u16::MAX)
+            };
 
+            // read rate higher than min rate
             if delta >= cfg.rate {
                 let n = now();
                 let next = self.io.timer_deadline() + ONE_SEC;
@@ -721,9 +899,11 @@ where
                 if cfg.max_timeout.is_zero()
                     || (n + new_timeout) <= self.read_max_timeout
                 {
-                    self.read_bytes = bytes;
                     self.io.stop_timer();
                     self.io.start_timer(new_timeout);
+
+                    // store current buf size for future rate calculation
+                    self.read_remains = bytes;
                 }
             }
         }
@@ -743,7 +923,7 @@ where
             // start read timer
             self.flags.insert(Flags::READ_HDRS_TIMEOUT);
-            self.read_bytes = remains as u32;
+            self.read_remains = 0;
             self.io.start_timer(cfg.timeout);
             if !cfg.max_timeout.is_zero() {
                 self.read_max_timeout = now() + cfg.max_timeout;
@@ -755,91 +935,6 @@ where
     }
 }
 
-/// Process request's payload
-fn _poll_request_payload<F: Filter>(
-    io: &Io<F>,
-    slf_payload: &mut Option<(PayloadDecoder, PayloadSender)>,
-    cx: &mut Context<'_>,
-) -> Poll<Result<(), DispatchError>> {
-    // check if payload data is required
-    let payload = if let Some(ref mut payload) = slf_payload {
-        payload
-    } else {
-        return Poll::Ready(Ok(()));
-    };
-    match payload.1.poll_data_required(cx) {
-        PayloadStatus::Read => {
-            // read request payload
-            let mut updated = false;
-            loop {
-                match io.poll_recv(&payload.0, cx) {
-                    Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => {
-                        updated = true;
-                        payload.1.feed_data(chunk);
-                    }
-                    Poll::Ready(Ok(PayloadItem::Eof)) => {
-                        updated = true;
-                        payload.1.feed_eof();
-                        *slf_payload = None;
-                        break;
-                    }
-                    Poll::Ready(Err(err)) => {
-                        let err = match err {
-                            RecvError::WriteBackpressure => {
-                                if io.poll_flush(cx, false)?.is_pending() {
-                                    break;
-                                } else {
-                                    continue;
-                                }
-                            }
-                            RecvError::KeepAlive => {
-                                payload.1.set_error(PayloadError::EncodingCorrupted);
-                                *slf_payload = None;
-                                io::Error::new(io::ErrorKind::Other, "Keep-alive").into()
-                            }
-                            RecvError::Stop => {
-                                payload.1.set_error(PayloadError::EncodingCorrupted);
-                                *slf_payload = None;
-                                io::Error::new(io::ErrorKind::Other, "Dispatcher stopped")
-                                    .into()
-                            }
-                            RecvError::PeerGone(err) => {
-                                payload.1.set_error(PayloadError::EncodingCorrupted);
-                                *slf_payload = None;
-                                if let Some(err) = err {
-                                    DispatchError::PeerGone(Some(err))
-                                } else {
-                                    ParseError::Incomplete.into()
-                                }
-                            }
-                            RecvError::Decoder(e) => {
-                                payload.1.set_error(PayloadError::EncodingCorrupted);
-                                *slf_payload = None;
-                                DispatchError::Parse(e)
-                            }
-                        };
-                        return Poll::Ready(Err(err));
-                    }
-                    Poll::Pending => break,
-                }
-            }
-            if updated {
-                Poll::Ready(Ok(()))
-            } else {
-                Poll::Pending
-            }
-        }
-        PayloadStatus::Pause => Poll::Pending,
-        PayloadStatus::Dropped => {
-            // service call is not interested in payload
-            // wait until future completes and then close
-            // connection
-            *slf_payload = None;
-            Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -1308,4 +1403,62 @@ mod tests {
         assert_eq!(&buf[..28], b"HTTP/1.1 500 Internal Server");
         assert_eq!(&buf[buf.len() - 5..], b"error");
     }
+
+    #[crate::rt_test]
+    async fn test_payload_timeout() {
+        env_logger::init();
+        let mark = Arc::new(AtomicUsize::new(0));
+        let mark2 = mark.clone();
+
+        let (client, server) = Io::create();
+        client.remote_buffer_cap(4096);
+
+        let svc = move |mut req: Request| {
+            let m = mark2.clone();
+            async move {
+                // read one chunk
+                let mut pl = req.take_payload();
+                while let Some(item) = stream_recv(&mut pl).await {
+                    let size = m.load(Ordering::Relaxed);
+                    if let Ok(buf) = item {
+                        m.store(size + buf.len(), Ordering::Relaxed);
+                    } else {
+                        return Ok::<_, io::Error>(Response::Ok().finish());
+                    }
+                }
+                Ok::<_, io::Error>(Response::Ok().finish())
+            }
+        };
+
+        let mut config = ServiceConfig::new(
+            Seconds(5).into(),
+            Millis(1_000),
+            Seconds::ZERO,
+            Millis(5_000),
+            Config::server(),
+        );
+        config.payload_read_rate(Seconds(1), Seconds(2), 512);
+        let disp: Dispatcher<Base, _, _, _, UpgradeHandler<Base>> = Dispatcher::new(
+            nio::Io::new(server),
+            Rc::new(DispatcherConfig::new(
+                config,
+                svc.into_service(),
+                ExpectHandler,
+                None,
+                None,
+            )),
+        );
+        crate::rt::spawn(disp);
+
+        client.write("GET /test HTTP/1.1\r\nContent-Length: 1048576\r\n\r\n");
+        sleep(Millis(50)).await;
+
+        // send partial data to server
+        for _ in 1..8 {
+            let random_bytes: Vec<u8> = (0..256).map(|_| rand::random::<u8>()).collect();
+            client.write(random_bytes);
+            sleep(Millis(350)).await;
+        }
+        assert!(mark.load(Ordering::Relaxed) == 1536);
+    }
 }
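
A note on the `delta` computation this patch fixes in both `ntex-io/src/dispatcher.rs` and `ntex/src/http/h1/dispatcher.rs`: `bytes - self.read_bytes` is unsigned arithmetic, so when the buffered byte count shrinks between polls (because a previously counted frame was consumed), the subtraction panics in debug builds and wraps around in release builds, where the huge result saturates to `u16::MAX` and a slow reader is mistaken for a fast one. A minimal standalone sketch of the corrected logic; the free function and its name are illustrative, not part of ntex:

    /// Bytes newly received since the previous poll. When the buffer
    /// shrank (a frame was consumed), fall back to counting the whole
    /// buffer instead of underflowing the u32 subtraction.
    fn read_delta(remains: u32, prev_remains: u32) -> u16 {
        if remains > prev_remains {
            (remains - prev_remains).try_into().unwrap_or(u16::MAX)
        } else {
            remains.try_into().unwrap_or(u16::MAX)
        }
    }

    fn main() {
        assert_eq!(read_delta(400, 100), 300); // buffer grew: count new bytes
        assert_eq!(read_delta(100, 400), 100); // buffer shrank: no underflow
    }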
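Taken together, `update_payload_timer` and `handle_payload_timeout` implement the payload read-rate policy: each decode pass records how many payload bytes are buffered and consumed, and each timer tick checks whether more than `cfg.rate` bytes arrived during the period, re-arming the timer until `cfg.max_timeout` (when configured) is exhausted. The sketch below models that state machine outside of ntex, under stated assumptions: `RateTimer`, `on_decoded`, and `on_timer` are hypothetical names, and the io-timer side effects are reduced to returned `Duration`s.

    use std::time::{Duration, Instant};

    /// Hypothetical, self-contained model of the payload read-rate
    /// tracking introduced in the diff above (not the ntex API).
    struct RateTimer {
        rate: u16,                 // cfg.rate: minimum bytes expected per period
        period: Duration,          // cfg.timeout: length of one read period
        max_timeout: Duration,     // cfg.max_timeout: overall cap, zero = uncapped
        started: bool,             // models the READ_PL_TIMEOUT flag
        deadline: Option<Instant>, // models read_max_timeout
        remains: u32,              // models read_remains
        consumed: u32,             // models read_consumed
    }

    impl RateTimer {
        /// After each decode pass (cf. `update_payload_timer`): record the
        /// buffer state and, on first payload activity, start the period
        /// timer. Returns the period to arm an io timer with, if any.
        fn on_decoded(&mut self, remains: usize, consumed: usize) -> Option<Duration> {
            self.remains = remains as u32;
            if self.started {
                self.consumed += consumed as u32;
                None
            } else {
                self.started = true;
                self.consumed = consumed as u32;
                if !self.max_timeout.is_zero() {
                    self.deadline = Some(Instant::now() + self.max_timeout);
                }
                Some(self.period)
            }
        }

        /// On a timer tick (cf. `handle_payload_timeout`): if more than
        /// `rate` bytes showed up during the period, reset the counter and
        /// extend the timer, unless the overall deadline has passed.
        fn on_timer(&mut self) -> Result<Duration, &'static str> {
            let total = (self.remains + self.consumed).min(u32::from(u16::MAX)) as u16;
            if total > self.rate {
                self.consumed = 0;
                if self.deadline.map_or(true, |d| Instant::now() < d) {
                    return Ok(self.period); // rate was met: wait one more period
                }
            }
            Err("payload read rate too low: timing out")
        }
    }

    fn main() {
        let mut t = RateTimer {
            rate: 256,
            period: Duration::from_secs(1),
            max_timeout: Duration::from_secs(2),
            started: false,
            deadline: None,
            remains: 0,
            consumed: 0,
        };
        assert_eq!(t.on_decoded(100, 412), Some(Duration::from_secs(1)));
        assert!(t.on_timer().is_ok()); // 512 bytes > 256: timer extended
    }

Counting `remains` in the total means bytes that are buffered but not yet decodable still count toward the rate, matching the sum `read_remains + read_consumed` in the diff: the peer is timed out for sending too slowly, not for frames that have not finished decoding.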