diff --git a/ntex-tokio/src/io.rs b/ntex-tokio/src/io.rs index a0c52f3b..6c2a1f53 100644 --- a/ntex-tokio/src/io.rs +++ b/ntex-tokio/src/io.rs @@ -49,10 +49,29 @@ impl ntex_io::AsyncRead for Read { async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { // read data from socket let result = poll_fn(|cx| { + let mut n = 0; let mut io = self.0.borrow_mut(); - poll_read_buf(Pin::new(&mut *io), cx, &mut buf) + loop { + return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? { + Poll::Pending => { + if n > 0 { + Poll::Ready(Ok(n)) + } else { + Poll::Pending + } + } + Poll::Ready(size) => { + n += size; + if n > 0 && buf.remaining_mut() > 0 { + continue; + } + Poll::Ready(Ok(n)) + } + }; + } }) .await; + (buf, result) } } @@ -85,6 +104,34 @@ impl ntex_io::AsyncWrite for Write { } } +pub fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut BytesVec, +) -> Poll> { + let n = { + let dst = + unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + if io.poll_read(cx, &mut buf)?.is_pending() { + return Poll::Pending; + } + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} + /// Flush write buffer to underlying I/O stream. pub(super) fn flush_io( io: &mut T, @@ -254,10 +301,29 @@ mod unixstream { async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { // read data from socket let result = poll_fn(|cx| { + let mut n = 0; let mut io = self.0.borrow_mut(); - poll_read_buf(Pin::new(&mut *io), cx, &mut buf) + loop { + return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? { + Poll::Pending => { + if n > 0 { + Poll::Ready(Ok(n)) + } else { + Poll::Pending + } + } + Poll::Ready(size) => { + n += size; + if n > 0 && buf.remaining_mut() > 0 { + continue; + } + Poll::Ready(Ok(n)) + } + }; + } }) .await; + (buf, result) } } @@ -290,31 +356,3 @@ mod unixstream { } } } - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut BytesVec, -) -> Poll> { - let n = { - let dst = - unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - if io.poll_read(cx, &mut buf)?.is_pending() { - return Poll::Pending; - } - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -}