Skip to content

Commit

Permalink
Refactor tls impl
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Nov 3, 2023
1 parent d460d9c commit aa05f0d
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 55 deletions.
4 changes: 2 additions & 2 deletions ntex-io/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Changes

## [0.3.4] - 2023-11-xx

## [0.3.4] - 2023-11-03

* Add Io::force_ready_ready() and Io::poll_force_ready_ready() methods

## [0.3.3] - 2023-09-11

Expand Down
50 changes: 47 additions & 3 deletions ntex-io/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ bitflags::bitflags! {
const RD_READY = 0b0000_0000_0010_0000;
/// read buffer is full
const RD_BUF_FULL = 0b0000_0000_0100_0000;
/// any new data is available
const RD_FORCE_READY = 0b0000_0000_1000_0000;

Check warning on line 35 in ntex-io/src/io.rs

View check run for this annotation

Codecov / codecov/patch

ntex-io/src/io.rs#L34-L35

Added lines #L34 - L35 were not covered by tests

/// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000;
Expand Down Expand Up @@ -78,10 +80,15 @@ impl IoState {
self.flags.set(flags);
}

pub(super) fn remove_flags(&self, f: Flags) {
pub(super) fn remove_flags(&self, f: Flags) -> bool {
let mut flags = self.flags.get();
flags.remove(f);
self.flags.set(flags);
if flags.intersects(f) {
flags.remove(f);
self.flags.set(flags);
true
} else {
false
}
}

pub(super) fn notify_keepalive(&self) {
Expand Down Expand Up @@ -365,6 +372,13 @@ impl<F> Io<F> {
poll_fn(|cx| self.poll_read_ready(cx)).await
}

#[doc(hidden)]
#[inline]
/// Wait until read becomes ready.
pub async fn force_read_ready(&self) -> io::Result<Option<()>> {
poll_fn(|cx| self.poll_force_read_ready(cx)).await
}

#[inline]
/// Pause read task
pub fn pause(&self) {
Expand Down Expand Up @@ -455,6 +469,36 @@ impl<F> Io<F> {
}
}

#[doc(hidden)]
#[inline]
/// Polls for read readiness.
///
/// If the io stream is not currently ready for reading,
/// this method will store a clone of the Waker from the provided Context.
/// When the io stream becomes ready for reading, Waker::wake will be called on the waker.
///
/// Return value
/// The function returns:
///
/// `Poll::Pending` if the io stream is not ready for reading.
/// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
/// `Poll::Ready(Ok(None))` if io stream is disconnected
/// `Some(Poll::Ready(Err(e)))` if an error is encountered.
pub fn poll_force_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
let ready = self.poll_read_ready(cx);

if ready.is_pending() {
if self.0.0.remove_flags(Flags::RD_FORCE_READY) {
Poll::Ready(Ok(Some(())))
} else {
self.0.0.insert_flags(Flags::RD_FORCE_READY);
Poll::Pending
}
} else {
ready
}
}

#[inline]
/// Decode codec item from incoming bytes stream.
///
Expand Down
4 changes: 4 additions & 0 deletions ntex-io/src/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl ReadContext {
// so we need to wake up read task to read more data
// otherwise read task would sleep forever
inner.read_task.wake();
} else if inner.flags.get().contains(Flags::RD_FORCE_READY) {
// in case of "force read" we must wake up dispatch task
// if we read any data from source
inner.dispatch_task.wake();
}

// while reading, filter wrote some data
Expand Down
4 changes: 4 additions & 0 deletions ntex-tls/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.3.2] - 2023-11-03

* Improve implementation

## [0.3.1] - 2023-09-11

* Add missing fmt::Debug impls
Expand Down
8 changes: 4 additions & 4 deletions ntex-tls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.3.1"
version = "0.3.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
Expand All @@ -26,9 +26,9 @@ rustls = ["tls_rust"]

[dependencies]
ntex-bytes = "0.1.19"
ntex-io = "0.3.3"
ntex-util = "0.3.2"
ntex-service = "1.2.6"
ntex-io = "0.3.4"
ntex-util = "0.3.3"
ntex-service = "1.2.7"
log = "0.4"
pin-project-lite = "0.2"

Expand Down
70 changes: 24 additions & 46 deletions ntex-tls/src/openssl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::{Cell, RefCell};
use std::{any, cmp, error::Error, fmt, io, task::Context, task::Poll};
use std::cell::RefCell;
use std::{any, cmp, error::Error, fmt, io, task::Poll};

use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
use ntex_util::{future::BoxFuture, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509;

Expand All @@ -25,7 +25,6 @@ pub struct PeerCertChain(pub Vec<X509>);
#[derive(Debug)]
pub struct SslFilter {
inner: RefCell<SslStream<IoInner>>,
handshake: Cell<bool>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -147,7 +146,7 @@ impl FilterLayer for SslFilter {
buf.with_write_buf(|b| {
self.with_buffers(b, || {
buf.with_dst(|dst| {
let mut new_bytes = usize::from(self.handshake.get());
let mut new_bytes = 0;
loop {
buf.resize_buf(dst);

Expand Down Expand Up @@ -270,27 +269,22 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
destination: None,
};
let filter = SslFilter {
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);

poll_fn(|cx| {
log::debug!("Accepting tls connection");
loop {
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
if handle_result(&io, result).await?.is_some() {
break
}
}

io.filter().handshake.set(false);
Ok(io)
})
.await
Expand Down Expand Up @@ -327,55 +321,39 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
destination: None,
};
let filter = SslFilter {
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);

poll_fn(|cx| {
loop {
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
if handle_result(&io, result).await?.is_some() {
break
}
}

io.filter().handshake.set(false);
Ok(io)
})
}
}

fn handle_result<T, F>(
result: Result<T, ssl::Error>,
io: &Io<F>,
cx: &mut Context<'_>,
) -> Poll<Result<T, Box<dyn Error>>> {
async fn handle_result<T, F>(io: &Io<F>, result: Result<T, ssl::Error>) -> io::Result<Option<T>> {
match result {
Ok(v) => Poll::Ready(Ok(v)),
Ok(v) => Ok(Some(v)),
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ => {
match ready!(io.poll_read_ready(cx)) {
Ok(None) => Err::<_, Box<dyn Error>>(
io::Error::new(io::ErrorKind::Other, "disconnected").into(),
),
Err(err) => Err(err.into()),
_ => Ok(()),
}?;
Poll::Pending
}
ssl::ErrorCode::WANT_WRITE => {
let _ = io.poll_flush(cx, true)?;
Poll::Pending
let res = io.force_read_ready().await;
match res? {
None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")),

Check warning on line 351 in ntex-tls/src/openssl/mod.rs

View check run for this annotation

Codecov / codecov/patch

ntex-tls/src/openssl/mod.rs#L351

Added line #L351 was not covered by tests
_ => Ok(None),
}
}
_ => Poll::Ready(Err(Box::new(e))),
ssl::ErrorCode::WANT_WRITE => Ok(None),
_ => Err(io::Error::new(io::ErrorKind::Other, e)),

Check warning on line 356 in ntex-tls/src/openssl/mod.rs

View check run for this annotation

Codecov / codecov/patch

ntex-tls/src/openssl/mod.rs#L355-L356

Added lines #L355 - L356 were not covered by tests
},
}
}
Expand Down

0 comments on commit aa05f0d

Please sign in to comment.