diff --git a/src/rust/inetstack/protocols/layer4/tcp/active_open.rs b/src/rust/inetstack/protocols/layer4/tcp/active_open.rs index 6c7a74e16..3313fb4dc 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/active_open.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/active_open.rs @@ -14,7 +14,7 @@ use crate::{ constants::{FALLBACK_MSS, MAX_WINDOW_SCALE}, established::{ congestion_control::{self, CongestionControl}, - EstablishedSocket, + SharedEstablishedSocket, }, header::{TcpHeader, TcpOptions2}, SeqNumber, @@ -91,7 +91,7 @@ impl SharedActiveOpenSocket { }))) } - fn process_ack(&mut self, header: TcpHeader) -> Result { + fn process_ack(&mut self, header: TcpHeader) -> Result { let expected_seq: SeqNumber = self.local_isn + SeqNumber::from(1); // Bail if we didn't receive a ACK packet with the right sequence number. @@ -191,7 +191,7 @@ impl SharedActiveOpenSocket { "Window scale: local {}, remote {}", local_window_scale, remote_window_scale ); - Ok(EstablishedSocket::new( + Ok(SharedEstablishedSocket::new( self.local, self.remote, self.runtime.clone(), @@ -212,7 +212,7 @@ impl SharedActiveOpenSocket { )?) } - pub async fn connect(mut self) -> Result { + pub async fn connect(mut self) -> Result { // Start connection handshake. let handshake_retries: usize = self.tcp_config.get_handshake_retries(); let handshake_timeout = self.tcp_config.get_handshake_timeout(); diff --git a/src/rust/inetstack/protocols/layer4/tcp/established/ctrlblk.rs b/src/rust/inetstack/protocols/layer4/tcp/established/ctrlblk.rs index 42ad2f29a..c097e7df2 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/established/ctrlblk.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/established/ctrlblk.rs @@ -6,43 +6,18 @@ //====================================================================================================================== use crate::{ - async_timer, - collections::{async_queue::SharedAsyncQueue, async_value::SharedAsyncValue}, - inetstack::protocols::{ - layer3::SharedLayer3Endpoint, - layer4::tcp::{ - constants::MSL, - established::{ - congestion_control::{self, CongestionControlConstructor}, - receiver::Receiver, - sender::Sender, - }, - header::TcpHeader, - SeqNumber, - }, - MAX_HEADER_SIZE, - }, - runtime::{ - fail::Fail, - memory::DemiBuffer, - network::{config::TcpConfig, socket::option::TcpSocketOptions}, - yield_with_timeout, SharedDemiRuntime, SharedObject, - }, -}; -use ::futures::{never::Never, pin_mut, FutureExt}; -use ::std::{ - net::{Ipv4Addr, SocketAddrV4}, - ops::{Deref, DerefMut}, - time::{Duration, Instant}, + inetstack::protocols::layer4::tcp::{established::congestion_control, established::Receiver, established::Sender}, + runtime::network::{config::TcpConfig, socket::option::TcpSocketOptions}, }; +use ::std::net::SocketAddrV4; //====================================================================================================================== // Structures //====================================================================================================================== -// TCP Connection State. -// Note: This ControlBlock structure is only used after we've reached the ESTABLISHED state, so states LISTEN, -// SYN_RCVD, and SYN_SENT aren't included here. +/// TCP Connection State. +/// Note: This ControlBlock structure is only used after we've reached the ESTABLISHED state, so states LISTEN, +/// SYN_RCVD, and SYN_SENT aren't included here. 
#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum State { Established, @@ -55,388 +30,46 @@ pub enum State { Closed, } -//====================================================================================================================== -// Control Block -//====================================================================================================================== - /// Transmission control block for representing our TCP connection. +/// This struct has only public members because includes state for both the send and receive path and is accessed by +/// both. pub struct ControlBlock { - local: SocketAddrV4, - remote: SocketAddrV4, - - layer3_endpoint: SharedLayer3Endpoint, - runtime: SharedDemiRuntime, - tcp_config: TcpConfig, - socket_options: TcpSocketOptions, - - // TCP Connection State. - state: State, - - // Send Sequence Variables from RFC 793. - - // SND.UNA - send unacknowledged - // SND.NXT - send next - // SND.WND - send window - // SND.UP - send urgent pointer - not implemented - // SND.WL1 - segment sequence number used for last window update - // SND.WL2 - segment acknowledgment number used for last window - // update - // ISS - initial send sequence number - - // Send queues - // SND.retrasmission_queue - queue of unacknowledged sent data. - // SND.unsent - queue of unsent data that we do not have the windows for. - // Previous send variables and queues. - // TODO: Consider incorporating this directly into ControlBlock. - sender: Sender, - // Receive Sequence Variables from RFC 793. - - // RCV.NXT - receive next - // RCV.WND - receive window - // RCV.UP - receive urgent pointer - not implemented - // IRS - initial receive sequence number - // Receive-side state information. TODO: Consider incorporating this directly into ControlBlock. - receiver: Receiver, + pub local: SocketAddrV4, + pub remote: SocketAddrV4, + pub tcp_config: TcpConfig, + pub socket_options: TcpSocketOptions, + pub state: State, + pub sender: Sender, + pub receiver: Receiver, // Congestion control trait implementation we're currently using. // TODO: Consider switching this to a static implementation to avoid V-table call overhead. - congestion_control_algorithm: Box, + pub congestion_control_algorithm: Box, } -#[derive(Clone)] -pub struct SharedControlBlock(SharedObject); +//====================================================================================================================== +// Associated Functions //====================================================================================================================== -impl SharedControlBlock { +impl ControlBlock { pub fn new( local: SocketAddrV4, remote: SocketAddrV4, - layer3_endpoint: SharedLayer3Endpoint, - runtime: SharedDemiRuntime, tcp_config: TcpConfig, - default_socket_options: TcpSocketOptions, - // In RFC 793, this is IRS. - receive_initial_seq_no: SeqNumber, - receive_ack_delay_timeout_secs: Duration, - receive_window_size_frames: u32, - receive_window_scale_shift_bits: u8, - // In RFC 793, this ISS. 
- sender_initial_seq_no: SeqNumber, - send_window_size_frames: u32, - send_window_scale_shift_bits: u8, - sender_mss: usize, - congestion_control_algorithm_constructor: CongestionControlConstructor, - congestion_control_options: Option, - mut recv_queue: SharedAsyncQueue<(Ipv4Addr, TcpHeader, DemiBuffer)>, + socket_options: TcpSocketOptions, + sender: Sender, + receiver: Receiver, + congestion_control_algorithm: Box, ) -> Self { - let sender: Sender = Sender::new( - sender_initial_seq_no, - send_window_size_frames, - send_window_scale_shift_bits, - sender_mss, - ); - let receiver: Receiver = Receiver::new( - receive_initial_seq_no, - receive_initial_seq_no, - receive_ack_delay_timeout_secs, - receive_window_size_frames, - receive_window_scale_shift_bits, - ); - let congestion_control_algorithm = - congestion_control_algorithm_constructor(sender_mss, sender_initial_seq_no, congestion_control_options); - let mut self_: Self = Self(SharedObject::::new(ControlBlock { + Self { local, remote, - layer3_endpoint, - runtime, tcp_config, - socket_options: default_socket_options, - sender, + socket_options, state: State::Established, + sender, receiver, congestion_control_algorithm, - })); - trace!("receive_queue size {:?}", recv_queue.len()); - // Process all pending received packets while setting up the connection. - while let Some((_ipv4_addr, header, data)) = recv_queue.try_pop() { - self_.receive(header, data); - } - - self_ - } - - pub fn get_local(&self) -> SocketAddrV4 { - self.local - } - - pub fn get_remote(&self) -> SocketAddrV4 { - self.remote - } - - pub fn get_now(&self) -> Instant { - self.runtime.get_now() - } - - pub fn receive(&mut self, tcp_hdr: TcpHeader, buf: DemiBuffer) { - debug!( - "{:?} Connection Receiving {} bytes + {:?}", - self.state, - buf.len(), - tcp_hdr, - ); - - let cb: Self = self.clone(); - let now: Instant = self.runtime.get_now(); - self.receiver.receive(tcp_hdr, buf, cb, now); - } - - pub fn congestion_control_watch_retransmit_now_flag(&self) -> SharedAsyncValue { - self.congestion_control_algorithm.get_retransmit_now_flag() - } - - pub fn congestion_control_on_fast_retransmit(&mut self) { - self.congestion_control_algorithm.on_fast_retransmit() - } - - pub fn congestion_control_on_rto(&mut self, send_unacknowledged: SeqNumber) { - self.congestion_control_algorithm.on_rto(send_unacknowledged) - } - - pub fn congestion_control_on_send(&mut self, rto: Duration, num_sent_bytes: u32) { - self.congestion_control_algorithm.on_send(rto, num_sent_bytes) - } - - pub fn congestion_control_on_cwnd_check_before_send(&mut self) { - self.congestion_control_algorithm.on_cwnd_check_before_send() - } - - pub fn congestion_control_get_cwnd(&self) -> SharedAsyncValue { - self.congestion_control_algorithm.get_cwnd() - } - - pub fn congestion_control_get_limited_transmit_cwnd_increase(&self) -> SharedAsyncValue { - self.congestion_control_algorithm.get_limited_transmit_cwnd_increase() - } - - pub fn process_ack(&mut self, header: &TcpHeader, now: Instant) -> Result<(), Fail> { - let send_unacknowledged: SeqNumber = self.sender.get_unacked_seq_no(); - let send_next: SeqNumber = self.sender.get_next_seq_no(); - - // TODO: Restructure this call into congestion control to either integrate it directly or make it more fine- - // grained. It currently duplicates the new/duplicate ack check itself internally, which is inefficient. - // We should either make separate calls for each case or integrate those cases directly. 
- let rto: Duration = self.sender.get_rto(); - - self.congestion_control_algorithm - .on_ack_received(rto, send_unacknowledged, send_next, header.ack_num); - - // Check whether this is an ack for data that we have sent. - if header.ack_num <= send_next { - // Does not matter when we get this since the clock will not move between the beginning of packet - // processing and now without a call to advance_clock. - self.sender.process_ack(header, now); - } else { - // This segment acknowledges data we have yet to send!? Send an ACK and drop the segment. - // TODO: See RFC 5961, this could be a Blind Data Injection Attack. - let cause: String = format!("Received segment acknowledging data we have yet to send!"); - warn!("process_ack(): {}", cause); - self.send_ack(); - return Err(Fail::new(libc::EBADMSG, &cause)); } - Ok(()) - } - - pub fn get_unacked_seq_no(&self) -> SeqNumber { - self.sender.get_unacked_seq_no() - } - - /// Fetch a TCP header filling out various values based on our current state. - /// TODO: Fix the "filling out various values based on our current state" part to actually do that correctly. - pub fn tcp_header(&self) -> TcpHeader { - let mut header: TcpHeader = TcpHeader::new(self.local.port(), self.remote.port()); - header.window_size = self.receiver.hdr_window_size(); - - // Note that once we reach a synchronized state we always include a valid acknowledgement number. - header.ack = true; - header.ack_num = self.receiver.receive_next_seq_no(); - - // Return this header. - header - } - - /// Send an ACK to our peer, reflecting our current state. - pub fn send_ack(&mut self) { - trace!("sending ack"); - let mut header: TcpHeader = self.tcp_header(); - - // TODO: Think about moving this to tcp_header() as well. - let seq_num: SeqNumber = self.sender.get_next_seq_no(); - header.seq_num = seq_num; - self.emit(header, None); - } - - /// Transmit this message to our connected peer. - pub fn emit(&mut self, header: TcpHeader, body: Option) { - // Only perform this debug print in debug builds. debug_assertions is compiler set in non-optimized builds. - let mut pkt = match body { - Some(body) => { - debug!("Sending {} bytes + {:?}", body.len(), header); - body - }, - _ => { - debug!("Sending 0 bytes + {:?}", header); - DemiBuffer::new_with_headroom(0, MAX_HEADER_SIZE as u16) - }, - }; - - // This routine should only ever be called to send TCP segments that contain a valid ACK value. - debug_assert!(header.ack); - - let remote_ipv4_addr: Ipv4Addr = self.remote.ip().clone(); - header.serialize_and_attach( - &mut pkt, - self.local.ip(), - self.remote.ip(), - self.tcp_config.get_tx_checksum_offload(), - ); - - // Call lower L3 layer to send the segment. - if let Err(e) = self - .layer3_endpoint - .transmit_tcp_packet_nonblocking(remote_ipv4_addr, pkt) - { - warn!("could not emit packet: {:?}", e); - return; - } - - // Post-send operations follow. - // Review: We perform these after the send, in order to keep send latency as low as possible. - - // Since we sent an ACK, cancel any outstanding delayed ACK request. 
- self.receiver.set_receive_ack_deadline(None); - } - pub async fn push(&mut self, buf: DemiBuffer) -> Result<(), Fail> { - let cb: Self = self.clone(); - self.sender.push(buf, cb).await - } - - pub async fn pop(&mut self, size: Option) -> Result { - self.receiver.pop(size).await - } - - pub fn process_fin(&mut self) { - let state = match self.state { - State::Established => State::CloseWait, - State::FinWait1 => State::Closing, - State::FinWait2 => State::TimeWait, - state => unreachable!("Cannot be in any other state at this point: {:?}", state), - }; - self.state = state; - } - - pub fn get_state(&self) -> State { - self.state - } - // This coroutine runs the close protocol. - pub async fn close(&mut self) -> Result<(), Fail> { - // Assert we are in a valid state and move to new state. - match self.state { - State::Established => self.local_close().await, - State::CloseWait => self.remote_already_closed().await, - _ => { - let cause: String = format!("socket is already closing"); - error!("close(): {}", cause); - Err(Fail::new(libc::EBADF, &cause)) - }, - } - } - - async fn local_close(&mut self) -> Result<(), Fail> { - // 1. Start close protocol by setting state and sending FIN. - self.state = State::FinWait1; - self.sender.push_fin_and_wait_for_ack().await?; - - // 2. Got ACK to our FIN. Check if we also received a FIN from remote in the meantime. - let state: State = self.state; - match state { - State::FinWait1 => { - self.state = State::FinWait2; - // Haven't received a FIN yet from remote, so wait. - self.receiver.wait_for_fin().await?; - }, - State::Closing => self.state = State::TimeWait, - state => unreachable!("Cannot be in any other state at this point: {:?}", state), - }; - // 3. TIMED_WAIT - debug_assert_eq!(self.state, State::TimeWait); - trace!("socket options: {:?}", self.socket_options.get_linger()); - let timeout: Duration = self.socket_options.get_linger().unwrap_or(MSL * 2); - yield_with_timeout(timeout).await; - self.state = State::Closed; - Ok(()) - } - - async fn remote_already_closed(&mut self) -> Result<(), Fail> { - // 0. Move state forward - self.state = State::LastAck; - // 1. Send FIN and wait for ack before closing. 
- self.sender.push_fin_and_wait_for_ack().await?; - self.state = State::Closed; - Ok(()) - } - - pub async fn background(&self) { - let acknowledger = async_timer!( - "tcp::established::background::acknowledger", - self.clone().background_acknowledger() - ) - .fuse(); - pin_mut!(acknowledger); - - let retransmitter = async_timer!( - "tcp::established::background::retransmitter", - self.clone().background_retransmitter() - ) - .fuse(); - pin_mut!(retransmitter); - - let sender = async_timer!("tcp::established::background::sender", self.clone().background_sender()).fuse(); - pin_mut!(sender); - - let r = futures::join!(acknowledger, retransmitter, sender); - error!("Connection terminated: {:?}", r); - } - - pub async fn background_retransmitter(mut self) -> Result { - let cb: Self = self.clone(); - self.sender.background_retransmitter(cb).await - } - - pub async fn background_sender(mut self) -> Result { - let cb: Self = self.clone(); - self.sender.background_sender(cb).await - } - - pub async fn background_acknowledger(mut self) -> Result { - let cb: Self = self.clone(); - self.receiver.acknowledger(cb).await - } -} - -//====================================================================================================================== -// Trait Implementations -//====================================================================================================================== - -impl Deref for SharedControlBlock { - type Target = ControlBlock; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -impl DerefMut for SharedControlBlock { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.deref_mut() } } diff --git a/src/rust/inetstack/protocols/layer4/tcp/established/mod.rs b/src/rust/inetstack/protocols/layer4/tcp/established/mod.rs index 5adab7ffb..06fc7df8a 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/established/mod.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/established/mod.rs @@ -8,110 +8,227 @@ mod rto; mod sender; use crate::{ + async_timer, collections::async_queue::SharedAsyncQueue, inetstack::protocols::{ layer3::SharedLayer3Endpoint, layer4::tcp::{ - congestion_control::CongestionControlConstructor, established::ctrlblk::SharedControlBlock, - header::TcpHeader, SeqNumber, + congestion_control::CongestionControlConstructor, + established::{ctrlblk::ControlBlock, ctrlblk::State, receiver::Receiver, sender::Sender}, + header::TcpHeader, + SeqNumber, }, }, runtime::{ fail::Fail, memory::DemiBuffer, - network::{config::TcpConfig, socket::option::TcpSocketOptions}, - SharedDemiRuntime, + network::{config::TcpConfig, consts::MSL, socket::option::TcpSocketOptions}, + yield_with_timeout, SharedDemiRuntime, SharedObject, }, - QToken, }; +use ::futures::pin_mut; use ::futures::FutureExt; use ::std::{ net::{Ipv4Addr, SocketAddrV4}, + ops::{Deref, DerefMut}, time::Duration, + time::Instant, }; -#[derive(Clone)] pub struct EstablishedSocket { - pub cb: SharedControlBlock, - // We need this to eventually stop the background task on close. - #[allow(unused)] + // All shared state for this established TCP connection. + cb: ControlBlock, runtime: SharedDemiRuntime, - /// The background co-routines handles various tasks, such as retransmission and acknowledging. - /// We annotate it as unused because the compiler believes that it is never called which is not the case. 
- #[allow(unused)] - background_task_qt: QToken, + layer3_endpoint: SharedLayer3Endpoint, } -impl EstablishedSocket { +#[derive(Clone)] +pub struct SharedEstablishedSocket(SharedObject); + +impl SharedEstablishedSocket { pub fn new( local: SocketAddrV4, remote: SocketAddrV4, mut runtime: SharedDemiRuntime, layer3_endpoint: SharedLayer3Endpoint, - recv_queue: SharedAsyncQueue<(Ipv4Addr, TcpHeader, DemiBuffer)>, + mut recv_queue: SharedAsyncQueue<(Ipv4Addr, TcpHeader, DemiBuffer)>, tcp_config: TcpConfig, default_socket_options: TcpSocketOptions, receiver_seq_no: SeqNumber, - ack_delay_timeout: Duration, - receiver_window_size: u32, - receiver_window_scale: u8, + ack_delay_timeout_secs: Duration, + receiver_window_size_frames: u32, + receiver_window_scale_bits: u8, sender_seq_no: SeqNumber, - sender_window_size: u32, - sender_window_scale: u8, + sender_window_size_frames: u32, + sender_window_scale_bits: u8, sender_mss: usize, cc_constructor: CongestionControlConstructor, congestion_control_options: Option, ) -> Result { - // TODO: Maybe add the queue descriptor here. - let cb = SharedControlBlock::new( + let sender: Sender = Sender::new( + sender_seq_no, + sender_window_size_frames, + sender_window_scale_bits, + sender_mss, + ); + let receiver: Receiver = Receiver::new( + receiver_seq_no, + receiver_seq_no, + ack_delay_timeout_secs, + receiver_window_size_frames, + receiver_window_scale_bits, + ); + + let congestion_control_algorithm = cc_constructor(sender_mss, sender_seq_no, congestion_control_options); + let cb = ControlBlock::new( local, remote, - layer3_endpoint, - runtime.clone(), tcp_config, default_socket_options, - receiver_seq_no, - ack_delay_timeout, - receiver_window_size, - receiver_window_scale, - sender_seq_no, - sender_window_size, - sender_window_scale, - sender_mss, - cc_constructor, - congestion_control_options, - recv_queue.clone(), + sender, + receiver, + congestion_control_algorithm, ); + let mut me: Self = Self(SharedObject::new(EstablishedSocket { + cb, + runtime: runtime.clone(), + layer3_endpoint, + })); - let cb2: SharedControlBlock = cb.clone(); - let qt: QToken = runtime.insert_background_coroutine( + trace!("inital receive_queue size {:?}", recv_queue.len()); + // Process all pending received packets while setting up the connection. + while let Some((_ipv4_addr, header, data)) = recv_queue.try_pop() { + me.receive(header, data); + } + let me2: Self = me.clone(); + runtime.insert_background_coroutine( "bgc::inetstack::tcp::established::background", - Box::pin(async move { cb2.background().await }.fuse()), + Box::pin(async move { me2.background().await }.fuse()), )?; - Ok(Self { - cb, - background_task_qt: qt.clone(), - runtime: runtime.clone(), - }) + Ok(me) + } + + pub fn receive(&mut self, tcp_hdr: TcpHeader, buf: DemiBuffer) { + debug!( + "{:?} Connection Receiving {} bytes + {:?}", + self.cb.state, + buf.len(), + tcp_hdr, + ); + + let now: Instant = self.runtime.get_now(); + let mut layer3_endpoint: SharedLayer3Endpoint = self.layer3_endpoint.clone(); + Receiver::receive(&mut self.cb, &mut layer3_endpoint, tcp_hdr, buf, now); + } + + // This coroutine runs the close protocol. + pub async fn close(&mut self) -> Result<(), Fail> { + // Assert we are in a valid state and move to new state. 
+ match self.cb.state { + State::Established => self.local_close().await, + State::CloseWait => self.remote_already_closed().await, + _ => { + let cause: String = format!("socket is already closing"); + error!("close(): {}", cause); + Err(Fail::new(libc::EBADF, &cause)) + }, + } + } + + async fn local_close(&mut self) -> Result<(), Fail> { + // 1. Start close protocol by setting state and sending FIN. + self.cb.state = State::FinWait1; + Sender::push_fin_and_wait_for_ack(&mut self.cb).await?; + + // 2. Got ACK to our FIN. Check if we also received a FIN from remote in the meantime. + let state: State = self.cb.state; + match state { + State::FinWait1 => { + self.cb.state = State::FinWait2; + // Haven't received a FIN yet from remote, so wait. + self.cb.receiver.wait_for_fin().await?; + }, + State::Closing => self.cb.state = State::TimeWait, + state => unreachable!("Cannot be in any other state at this point: {:?}", state), + }; + // 3. TIMED_WAIT + debug_assert_eq!(self.cb.state, State::TimeWait); + trace!("socket options: {:?}", self.cb.socket_options.get_linger()); + let timeout: Duration = self.cb.socket_options.get_linger().unwrap_or(MSL * 2); + yield_with_timeout(timeout).await; + self.cb.state = State::Closed; + Ok(()) + } + + async fn remote_already_closed(&mut self) -> Result<(), Fail> { + // 0. Move state forward + self.cb.state = State::LastAck; + // 1. Send FIN and wait for ack before closing. + Sender::push_fin_and_wait_for_ack(&mut self.cb).await?; + self.cb.state = State::Closed; + Ok(()) } pub async fn push(&mut self, buf: DemiBuffer) -> Result<(), Fail> { - self.cb.push(buf).await + let mut runtime: SharedDemiRuntime = self.runtime.clone(); + let mut layer3_endpoint: SharedLayer3Endpoint = self.layer3_endpoint.clone(); + Sender::push(&mut self.cb, &mut layer3_endpoint, &mut runtime, buf).await } pub async fn pop(&mut self, size: Option) -> Result { - self.cb.pop(size).await + self.cb.receiver.pop(size).await } - pub async fn close(&mut self) -> Result<(), Fail> { - self.cb.close().await + pub fn endpoints(&self) -> (SocketAddrV4, SocketAddrV4) { + (self.cb.local, self.cb.remote) } - pub fn endpoints(&self) -> (SocketAddrV4, SocketAddrV4) { - (self.cb.get_local(), self.cb.get_remote()) + async fn background(self) { + let mut me: Self = self.clone(); + let acknowledger = async_timer!("tcp::established::background::acknowledger", async { + let mut layer3_endpoint: SharedLayer3Endpoint = me.layer3_endpoint.clone(); + Receiver::acknowledger(&mut me.cb, &mut layer3_endpoint).await + }) + .fuse(); + pin_mut!(acknowledger); + + let mut me2: Self = self.clone(); + let retransmitter = async_timer!("tcp::established::background::retransmitter", async { + let mut layer3_endpoint: SharedLayer3Endpoint = me2.layer3_endpoint.clone(); + let mut runtime: SharedDemiRuntime = me2.runtime.clone(); + Sender::background_retransmitter(&mut me2.cb, &mut layer3_endpoint, &mut runtime).await + }) + .fuse(); + pin_mut!(retransmitter); + + let mut me3: Self = self.clone(); + let sender = async_timer!("tcp::established::background::sender", async { + let mut layer3_endpoint: SharedLayer3Endpoint = me3.layer3_endpoint.clone(); + let mut runtime: SharedDemiRuntime = me3.runtime.clone(); + Sender::background_sender(&mut me3.cb, &mut layer3_endpoint, &mut runtime).await + }) + .fuse(); + pin_mut!(sender); + + let result = futures::join!(acknowledger, retransmitter, sender); + debug!("{:?}", result); } +} + 
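The `Deref`/`DerefMut` implementations that follow are what let `SharedEstablishedSocket` act as a cloneable handle onto a single `EstablishedSocket`: each clone handed to the acknowledger, retransmitter, and sender coroutines in `background()` dereferences to the same control block. A minimal sketch of that single-threaded shared-handle pattern, assuming `SharedObject` is essentially a reference-counted cell (the real type may differ in detail):

    use std::{cell::UnsafeCell, ops::{Deref, DerefMut}, rc::Rc};

    /// A cloneable, single-threaded handle: every clone dereferences to the same inner value.
    pub struct Shared<T>(Rc<UnsafeCell<T>>);

    impl<T> Shared<T> {
        pub fn new(value: T) -> Self {
            Self(Rc::new(UnsafeCell::new(value)))
        }
    }

    impl<T> Clone for Shared<T> {
        fn clone(&self) -> Self {
            // Cloning bumps the reference count; it does not copy the inner value.
            Self(self.0.clone())
        }
    }

    impl<T> Deref for Shared<T> {
        type Target = T;
        fn deref(&self) -> &T {
            // Sound only under a single-threaded runtime where coroutines do not
            // hold references across yield points.
            unsafe { &*self.0.get() }
        }
    }

    impl<T> DerefMut for Shared<T> {
        fn deref_mut(&mut self) -> &mut T {
            unsafe { &mut *self.0.get() }
        }
    }

With a handle like this, `me.clone()` is cheap, and mutation through any clone is visible to all of them, which is what the three background coroutines rely on.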
+//====================================================================================================================== +// Trait Implementations +//====================================================================================================================== + +impl Deref for SharedEstablishedSocket { + type Target = EstablishedSocket; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} - pub fn get_cb(&self) -> SharedControlBlock { - self.cb.clone() +impl DerefMut for SharedEstablishedSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() } } diff --git a/src/rust/inetstack/protocols/layer4/tcp/established/receiver.rs b/src/rust/inetstack/protocols/layer4/tcp/established/receiver.rs index 9a56901ab..7e98ff302 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/established/receiver.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/established/receiver.rs @@ -13,8 +13,13 @@ use ::std::{ use crate::{ collections::{async_queue::AsyncQueue, async_value::SharedAsyncValue}, expect_ok, - inetstack::protocols::layer4::tcp::{ - established::ctrlblk::State, established::SharedControlBlock, header::TcpHeader, SeqNumber, + inetstack::protocols::{ + layer3::SharedLayer3Endpoint, + layer4::tcp::{ + established::{ctrlblk::State, ControlBlock, Sender}, + header::TcpHeader, + SeqNumber, + }, }, runtime::{fail::Fail, memory::DemiBuffer}, }; @@ -31,10 +36,9 @@ use ::futures::never::Never; const MAX_OUT_OF_ORDER_SIZE_FRAMES: usize = 16; //====================================================================================================================== -// Data Structures +// Structures //====================================================================================================================== -// TODO: Consider incorporating this directly into ControlBlock. pub struct Receiver { // // Receive Sequence Space: @@ -54,7 +58,7 @@ pub struct Receiver { reader_next_seq_no: SeqNumber, // Sequence number of the next byte of data (or FIN) that we expect to receive. In RFC 793 terms, this is RCV.NXT. - receive_next_seq_no: SeqNumber, + pub receive_next_seq_no: SeqNumber, // Sequnce number of the last byte of data (FIN). fin_seq_no: SharedAsyncValue>, @@ -62,24 +66,21 @@ pub struct Receiver { // Pop queue. Contains in-order received (and acknowledged) data ready for the application to read. pop_queue: AsyncQueue, + // The amount of time before we will send a bare ACK. ack_delay_timeout_secs: Duration, - - ack_deadline_time_secs: SharedAsyncValue>, + // The deadline when we will send a bare ACK if there are no outgoing packets by then. + pub ack_deadline_time_secs: SharedAsyncValue>, // This is our receive buffer size, which is also the maximum size of our receive window. // Note: The maximum possible advertised window is 1 GiB with window scaling and 64 KiB without. buffer_size_frames: u32, - // TODO: Review how this is used. We could have separate window scale factors, so there should be one for the - // receiver and one for the sender. - // This is the receive-side window scale factor. // This is the number of bits to shift to convert to/from the scaled value, and has a maximum value of 14. window_scale_shift_bits: u8, // Queue of out-of-order segments. This is where we hold onto data that we've received (because it was within our // receive window) but can't yet present to the user because we're missing some other data that comes between this // and what we've already presented to the user. 
- // out_of_order_frames: VecDeque<(SeqNumber, DemiBuffer)>, } @@ -108,6 +109,86 @@ impl Receiver { } } + // This function causes a EOF to be returned to the user. We also know that there will be no more incoming + // data after this sequence number. + fn push_fin(&mut self) { + self.pop_queue.push(DemiBuffer::new(0)); + debug_assert_eq!(self.receive_next_seq_no, self.fin_seq_no.get().unwrap()); + // Reset it to wake up any close coroutines waiting for FIN to arrive. + self.fin_seq_no.set(Some(self.receive_next_seq_no)); + // Move RECV_NXT over the FIN. + self.receive_next_seq_no = self.receive_next_seq_no + 1.into(); + } + + pub fn get_receive_window_size(&self) -> u32 { + let bytes_unread: u32 = (self.receive_next_seq_no - self.reader_next_seq_no).into(); + self.buffer_size_frames - bytes_unread + } + + pub fn hdr_window_size(&self) -> u16 { + let window_size: u32 = self.get_receive_window_size(); + let hdr_window_size: u16 = expect_ok!( + (window_size >> self.window_scale_shift_bits).try_into(), + "Window size overflow" + ); + debug!( + "Window size -> {} (hdr {}, scale {})", + (hdr_window_size as u32) << self.window_scale_shift_bits, + hdr_window_size, + self.window_scale_shift_bits, + ); + hdr_window_size + } + + // This routine takes an incoming in-order TCP segment and adds the data to the user's receive queue. If the new + // segment fills a "hole" in the receive sequence number space allowing previously stored out-of-order data to now + // be received, it receives that too. + // + // This routine also updates receive_next to reflect any data now considered "received". + fn receive_data(&mut self, seg_start: SeqNumber, buf: DemiBuffer) { + // This routine should only be called with in-order segment data. + debug_assert_eq!(seg_start, self.receive_next_seq_no); + + // Push the new segment data onto the end of the receive queue. + self.receive_next_seq_no = self.receive_next_seq_no + SeqNumber::from(buf.len() as u32); + // This inserts the segment and wakes a waiting pop coroutine. + self.pop_queue.push(buf); + + // Okay, we've successfully received some new data. Check if any of the formerly out-of-order data waiting in + // the out-of-order queue is now in-order. If so, we can move it to the receive queue. + while !self.out_of_order_frames.is_empty() { + if let Some(stored_entry) = self.out_of_order_frames.front() { + if stored_entry.0 == self.receive_next_seq_no { + // Move this entry's buffer from the out-of-order store to the receive queue. + // This data is now considered to be "received" by TCP, and included in our RCV.NXT calculation. + debug!("Recovering out-of-order packet at {}", self.receive_next_seq_no); + if let Some(temp) = self.out_of_order_frames.pop_front() { + self.receive_next_seq_no = self.receive_next_seq_no + SeqNumber::from(temp.1.len() as u32); + // This inserts the segment and wakes a waiting pop coroutine. + self.pop_queue.push(temp.1); + } + } else { + // Since our out-of-order list is sorted, we can stop when the next segment is not in sequence. + break; + } + } + } + } + + // Block until the remote sends a FIN (plus all previous data has arrived). + pub async fn wait_for_fin(&mut self) -> Result<(), Fail> { + let mut fin_seq_no: Option = self.fin_seq_no.get(); + loop { + match fin_seq_no { + Some(fin_seq_no) if self.receive_next_seq_no >= fin_seq_no => return Ok(()), + _ => { + fin_seq_no = self.fin_seq_no.wait_for_change(None).await?; + }, + } + } + } + + // Block until some data is received, up to an optional size. 
pub async fn pop(&mut self, size: Option) -> Result { debug!("waiting on pop {:?}", size); let buf: DemiBuffer = if let Some(size) = size { @@ -135,44 +216,28 @@ impl Receiver { Ok(buf) } - pub fn receive(&mut self, tcp_hdr: TcpHeader, buf: DemiBuffer, cb: SharedControlBlock, now: Instant) { - match self.process_packet(tcp_hdr, buf, cb, now) { + // Receive a single incoming packet from layer3. + pub fn receive( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + tcp_hdr: TcpHeader, + buf: DemiBuffer, + now: Instant, + ) { + match Self::process_packet(cb, layer3_endpoint, tcp_hdr, buf, now) { Ok(()) => (), Err(e) => debug!("Dropped packet: {:?}", e), } } - fn push_fin(&mut self) { - debug!("notifying FIN"); - self.pop_queue.push(DemiBuffer::new(0)); - debug_assert_eq!(self.receive_next_seq_no, self.fin_seq_no.get().unwrap()); - // Reset it to wake up any close coroutines waiting for FIN to arrive. - self.fin_seq_no.set(Some(self.receive_next_seq_no)); - // Move RECV_NXT over the FIN. - self.receive_next_seq_no = self.receive_next_seq_no + 1.into(); - } - - // Return Ok after FIN arrives (plus all previous data). - pub async fn wait_for_fin(&mut self) -> Result<(), Fail> { - let mut fin_seq_no: Option = self.fin_seq_no.get(); - loop { - match fin_seq_no { - Some(fin_seq_no) if self.receive_next_seq_no >= fin_seq_no => return Ok(()), - _ => { - fin_seq_no = self.fin_seq_no.wait_for_change(None).await?; - }, - } - } - } - /// This is the main function for processing an incoming packet during the Established state when the connection is /// active. Each step in this function return Ok if there is further processing to be done and EBADMSG if the /// packet should be dropped after the step. fn process_packet( - &mut self, + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, mut header: TcpHeader, mut data: DemiBuffer, - mut cb: SharedControlBlock, now: Instant, ) -> Result<(), Fail> { let mut seg_start: SeqNumber = header.seq_num; @@ -180,17 +245,18 @@ impl Receiver { let mut seg_len: u32 = data.len() as u32; // Check if the segment is in the receive window and trim off everything else. - self.check_segment_in_window( + Self::check_segment_in_window( + cb, + layer3_endpoint, &mut header, &mut data, &mut seg_start, &mut seg_end, &mut seg_len, - &mut cb, )?; - self.check_rst(&header)?; - self.check_syn(&header)?; - self.process_ack(&header, &mut cb, now)?; + Self::check_rst(&header)?; + Self::check_syn(&header)?; + Self::check_and_process_ack(cb, &header, now)?; // TODO: Check the URG bit. If we decide to support this, how should we do it? if header.urg { @@ -198,12 +264,12 @@ impl Receiver { } if data.len() > 0 { - self.process_data(data, seg_start, seg_end, seg_len, &mut cb)?; + Self::process_data(cb, layer3_endpoint, data, seg_start, seg_end, seg_len)?; } - // Process FIN flag. + // Deal with FIN flag, saving the FIN for later if it is out of order. if header.fin { - match self.fin_seq_no.get() { + match cb.receiver.fin_seq_no.get() { // We've already received this FIN, so ignore. Some(seq_no) if seq_no != seg_end => warn!( "Received a FIN with a different sequence number, ignoring. previous={:?} new={:?}", @@ -212,35 +278,40 @@ impl Receiver { Some(_) => trace!("Received duplicate FIN"), None => { trace!("Received FIN"); - self.fin_seq_no.set(seg_end.into()) + cb.receiver.fin_seq_no.set(seg_end.into()); }, } - } - // Check whether we've received the last packet. - if self + }; + + // Check whether we've received the last packet in this TCP stream. 
+ if cb + .receiver .fin_seq_no .get() - .is_some_and(|seq_no| seq_no == self.receive_next_seq_no) + .is_some_and(|seq_no| seq_no == cb.receiver.receive_next_seq_no) { - self.process_fin(&mut cb); + // Once we know there is no more data coming, begin closing down the connection. + Self::process_fin(cb); } + + // Send an ack on every FIN. We do this separately here because if the FIN is in order, we ack it after the + // previous line, otherwise we do not ack the FIN. if header.fin { - // Send ack for out of order FIN. trace!("Acking FIN"); - cb.send_ack() + Sender::send_ack(cb, layer3_endpoint) } + // We should ACK this segment, preferably via piggybacking on a response. - // TODO: Consider replacing the delayed ACK timer with a simple flag. - if self.ack_deadline_time_secs.get().is_none() { + if cb.receiver.ack_deadline_time_secs.get().is_none() { // Start the delayed ACK timer to ensure an ACK gets sent soon even if no piggyback opportunity occurs. - let timeout: Duration = self.ack_delay_timeout_secs; + let timeout: Duration = cb.receiver.ack_delay_timeout_secs; // Getting the current time is extremely cheap as it is just a variable lookup. - self.ack_deadline_time_secs.set(Some(now + timeout)); + cb.receiver.ack_deadline_time_secs.set(Some(now + timeout)); } else { // We already owe our peer an ACK (the timer was already running), so cancel the timer and ACK now. - self.ack_deadline_time_secs.set(None); + cb.receiver.ack_deadline_time_secs.set(None); trace!("process_packet(): sending ack on deadline expiration"); - cb.send_ack(); + Sender::send_ack(cb, layer3_endpoint); } Ok(()) @@ -250,15 +321,14 @@ impl Receiver { // window, or is a non-data segment with a sequence number that falls within the window). Unacceptable segments // should be ACK'd (unless they are RSTs), and then dropped. // Returns Ok if further processing is needed and EBADMSG if the packet is not within the receive window. - fn check_segment_in_window( - &mut self, + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, header: &mut TcpHeader, data: &mut DemiBuffer, seg_start: &mut SeqNumber, seg_end: &mut SeqNumber, seg_len: &mut u32, - cb: &mut SharedControlBlock, ) -> Result<(), Fail> { // [From RFC 793] // There are four cases for the acceptability test for an incoming segment: @@ -289,9 +359,9 @@ impl Receiver { *seg_end = *seg_start + SeqNumber::from(*seg_len - 1); } - let receive_next: SeqNumber = self.receive_next_seq_no; + let receive_next: SeqNumber = cb.receiver.receive_next_seq_no; - let after_receive_window: SeqNumber = receive_next + SeqNumber::from(self.get_receive_window_size()); + let after_receive_window: SeqNumber = receive_next + SeqNumber::from(cb.receiver.get_receive_window_size()); // Check if this segment fits in our receive window. // In the optimal case it starts at RCV.NXT, so we check for that first. @@ -302,10 +372,9 @@ impl Receiver { // See if it is a complete duplicate, or if some of the data is new. if *seg_end < receive_next { // This is an entirely duplicate (i.e. old) segment. ACK (if not RST) and drop. - // if !header.rst { trace!("check_segment_in_window(): send ack on duplicate segment"); - cb.send_ack(); + Sender::send_ack(cb, layer3_endpoint); } let cause: String = format!("duplicate packet"); error!("check_segment_in_window(): {}", cause); @@ -313,7 +382,6 @@ impl Receiver { } else { // Some of this segment's data is new. Cut the duplicate data off of the front. // If there is a SYN at the start of this segment, remove it too. 
- // let mut duplicate: u32 = u32::from(receive_next - *seg_start); *seg_start = *seg_start + SeqNumber::from(duplicate); *seg_len -= duplicate; @@ -329,13 +397,11 @@ impl Receiver { } else { // This segment contains entirely new data, but is later in the sequence than what we're expecting. // See if any part of the data fits within our receive window. - // if *seg_start >= after_receive_window { // This segment is completely outside of our window. ACK (if not RST) and drop. - // if !header.rst { trace!("check_segment_in_window(): send ack on out-of-window segment"); - cb.send_ack(); + Sender::send_ack(cb, layer3_endpoint); } let cause: String = format!("packet outside of receive window"); error!("check_segment_in_window(): {}", cause); @@ -373,9 +439,8 @@ impl Receiver { debug_assert!(receive_next <= *seg_start && *seg_end < after_receive_window); Ok(()) } - // Check the RST bit. - fn check_rst(&mut self, header: &TcpHeader) -> Result<(), Fail> { + fn check_rst(header: &TcpHeader) -> Result<(), Fail> { if header.rst { // TODO: RFC 5961 "Blind Reset Attack Using the RST Bit" prevention would have us ACK and drop if the new // segment doesn't start precisely on RCV.NXT. @@ -391,7 +456,7 @@ impl Receiver { } // Check the SYN bit. - fn check_syn(&mut self, header: &TcpHeader) -> Result<(), Fail> { + fn check_syn(header: &TcpHeader) -> Result<(), Fail> { // Note: RFC 793 says to check security/compartment and precedence next, but those are largely deprecated. // Check the SYN bit. @@ -412,7 +477,7 @@ impl Receiver { } // Check the ACK bit. - fn process_ack(&mut self, header: &TcpHeader, cb: &mut SharedControlBlock, now: Instant) -> Result<(), Fail> { + fn check_and_process_ack(cb: &mut ControlBlock, header: &TcpHeader, now: Instant) -> Result<(), Fail> { if !header.ack { // All segments on established connections should be ACKs. Drop this segment. let cause: String = format!("Received non-ACK segment on established connection"); @@ -423,34 +488,32 @@ impl Receiver { // TODO: RFC 5961 "Blind Data Injection Attack" prevention would have us perform additional ACK validation // checks here. - // Process the ACK. - // Start by checking that the ACK acknowledges something new. - // TODO: Look into removing Watched types. - // - cb.process_ack(header, now) + Sender::process_ack(cb, header, now); + + Ok(()) } fn process_data( - &mut self, + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, data: DemiBuffer, seg_start: SeqNumber, seg_end: SeqNumber, seg_len: u32, - cb: &mut SharedControlBlock, ) -> Result<(), Fail> { // We can only process in-order data. Check for out-of-order segment. - if seg_start != self.receive_next_seq_no { + if seg_start != cb.receiver.receive_next_seq_no { debug!("Received out-of-order segment"); debug_assert_ne!(seg_len, 0); // This segment is out-of-order. If it carries data, we should store it for later processing // after the "hole" in the sequence number space has been filled. - match cb.get_state() { + match cb.state { State::Established | State::FinWait1 | State::FinWait2 => { debug_assert_eq!(seg_len, data.len() as u32); - self.store_out_of_order_segment(seg_start, seg_end, data); + cb.receiver.store_out_of_order_segment(seg_start, seg_end, data); // Sending an ACK here is only a "MAY" according to the RFCs, but helpful for fast retransmit. 
trace!("process_data(): send ack on out-of-order segment"); - cb.send_ack(); + Sender::send_ack(cb, layer3_endpoint); }, state => warn!("Ignoring data received after FIN (in state {:?}).", state), } @@ -460,7 +523,7 @@ impl Receiver { } // We can only legitimately receive data in ESTABLISHED, FIN-WAIT-1, and FIN-WAIT-2. - self.receive_data(seg_start, data); + cb.receiver.receive_data(seg_start, data); Ok(()) } @@ -567,80 +630,22 @@ impl Receiver { } } - // This routine takes an incoming in-order TCP segment and adds the data to the user's receive queue. If the new - // segment fills a "hole" in the receive sequence number space allowing previously stored out-of-order data to now - // be received, it receives that too. - // - // This routine also updates receive_next to reflect any data now considered "received". - // - // Returns true if a previously out-of-order segment containing a FIN has now been received. - // - fn receive_data(&mut self, seg_start: SeqNumber, buf: DemiBuffer) { - // This routine should only be called with in-order segment data. - debug_assert_eq!(seg_start, self.receive_next_seq_no); - - // Push the new segment data onto the end of the receive queue. - self.receive_next_seq_no = self.receive_next_seq_no + SeqNumber::from(buf.len() as u32); - // This inserts the segment and wakes a waiting pop coroutine. - debug!("pushing buffer"); - self.pop_queue.push(buf); - - // Okay, we've successfully received some new data. Check if any of the formerly out-of-order data waiting in - // the out-of-order queue is now in-order. If so, we can move it to the receive queue. - while !self.out_of_order_frames.is_empty() { - if let Some(stored_entry) = self.out_of_order_frames.front() { - if stored_entry.0 == self.receive_next_seq_no { - // Move this entry's buffer from the out-of-order store to the receive queue. - // This data is now considered to be "received" by TCP, and included in our RCV.NXT calculation. - debug!("Recovering out-of-order packet at {}", self.receive_next_seq_no); - if let Some(temp) = self.out_of_order_frames.pop_front() { - self.receive_next_seq_no = self.receive_next_seq_no + SeqNumber::from(temp.1.len() as u32); - // This inserts the segment and wakes a waiting pop coroutine. - self.pop_queue.push(temp.1); - } - } else { - // Since our out-of-order list is sorted, we can stop when the next segment is not in sequence. 
- break; - } - } - } - } - - pub fn set_receive_ack_deadline(&mut self, ack_deadline_timeout_secs: Option) { - self.ack_deadline_time_secs.set(ack_deadline_timeout_secs) - } - - fn process_fin(&mut self, cb: &mut SharedControlBlock) { - cb.process_fin(); - self.push_fin(); - } - - pub fn receive_next_seq_no(&self) -> SeqNumber { - self.receive_next_seq_no - } - - pub fn get_receive_window_size(&self) -> u32 { - let bytes_unread: u32 = (self.receive_next_seq_no - self.reader_next_seq_no).into(); - self.buffer_size_frames - bytes_unread - } - - pub fn hdr_window_size(&self) -> u16 { - let window_size: u32 = self.get_receive_window_size(); - let hdr_window_size: u16 = expect_ok!( - (window_size >> self.window_scale_shift_bits).try_into(), - "Window size overflow" - ); - debug!( - "Window size -> {} (hdr {}, scale {})", - (hdr_window_size as u32) << self.window_scale_shift_bits, - hdr_window_size, - self.window_scale_shift_bits, - ); - hdr_window_size + fn process_fin(cb: &mut ControlBlock) { + let state = match cb.state { + State::Established => State::CloseWait, + State::FinWait1 => State::Closing, + State::FinWait2 => State::TimeWait, + state => unreachable!("Cannot be in any other state at this point: {:?}", state), + }; + cb.state = state; + cb.receiver.push_fin(); } - pub async fn acknowledger(&mut self, mut cb: SharedControlBlock) -> Result { - let mut ack_deadline: SharedAsyncValue> = self.ack_deadline_time_secs.clone(); + pub async fn acknowledger( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + ) -> Result { + let mut ack_deadline: SharedAsyncValue> = cb.receiver.ack_deadline_time_secs.clone(); let mut deadline: Option = ack_deadline.get(); loop { // TODO: Implement TCP delayed ACKs, subject to restrictions from RFC 1122 @@ -654,7 +659,7 @@ impl Receiver { continue; }, Err(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => { - cb.send_ack(); + Sender::send_ack(cb, layer3_endpoint); deadline = ack_deadline.get(); }, Err(_) => { diff --git a/src/rust/inetstack/protocols/layer4/tcp/established/sender.rs b/src/rust/inetstack/protocols/layer4/tcp/established/sender.rs index 730f4f879..d88787b7c 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/established/sender.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/established/sender.rs @@ -7,21 +7,24 @@ use crate::{ collections::{async_queue::SharedAsyncQueue, async_value::SharedAsyncValue}, - inetstack::protocols::layer4::tcp::{ - established::{rto::RtoCalculator, SharedControlBlock}, - header::TcpHeader, - SeqNumber, + inetstack::protocols::{ + layer3::SharedLayer3Endpoint, + layer4::tcp::{ + established::{rto::RtoCalculator, ControlBlock}, + header::TcpHeader, + SeqNumber, + }, + MAX_HEADER_SIZE, }, - runtime::{conditional_yield_until, fail::Fail, memory::DemiBuffer}, + runtime::{conditional_yield_until, fail::Fail, memory::DemiBuffer, SharedDemiRuntime}, }; -use ::futures::{pin_mut, select_biased, FutureExt}; +use ::futures::{never::Never, pin_mut, select_biased, FutureExt}; use ::libc::{EBUSY, EINVAL}; use ::std::{ - fmt, + cmp, fmt, + net::Ipv4Addr, time::{Duration, Instant}, }; -use futures::never::Never; -use std::cmp; //====================================================================================================================== // Data Structures @@ -50,8 +53,6 @@ const MIN_UNACKED_QUEUE_SIZE_FRAMES: usize = 64; // of the unacked queue, below which memory allocation is not required. 
const MIN_UNSENT_QUEUE_SIZE_FRAMES: usize = 64; -// TODO: Consider moving retransmit timer and congestion control fields out of this structure. -// TODO: Make all public fields in this structure private. pub struct Sender { // // Send Sequence Space: @@ -81,7 +82,7 @@ pub struct Sender { rto_calculator: RtoCalculator, // In RFC 793 terms, this is SND.NXT. - send_next_seq_no: SharedAsyncValue, + pub send_next_seq_no: SharedAsyncValue, // Sequence number of next data to be pushed but not sent. When there is an open window, this is equivalent to // send_next_seq_no. @@ -130,10 +131,93 @@ impl Sender { } } + fn process_acked_fin(&mut self, bytes_remaining: usize, ack_num: SeqNumber) -> usize { + // This buffer is the end-of-send marker. So we should only have one byte of acknowledged + // sequence space remaining (corresponding to our FIN). + debug_assert_eq!(bytes_remaining, 1); + + // Double check that the ack is for the FIN sequence number. + debug_assert_eq!( + ack_num, + self.fin_seq_no + .map(|s| { s + 1.into() }) + .expect("should have a FIN set") + ); + 0 + } + + fn process_acked_segment(&mut self, bytes_remaining: usize, mut segment: UnackedSegment, now: Instant) -> usize { + // Add sample for RTO if we have an initial transmit time. + // Note that in the case of repacketization, an ack for the first byte is enough for the time sample because it still represents the RTO for that single byte. + // TODO: TCP timestamp support. + if let Some(initial_tx) = segment.initial_tx { + self.rto_calculator.add_sample(now - initial_tx); + } + + let mut data: DemiBuffer = segment + .bytes + .take() + .expect("there should be data because this is not a FIN."); + if data.len() > bytes_remaining { + // Put this segment on the unacknowledged list. + let unacked_segment = UnackedSegment { + bytes: Some( + data.split_back(bytes_remaining) + .expect("Should be able to split back because we just checked the length"), + ), + initial_tx: None, + }; + // Leave this segment on the unacknowledged queue. + self.unacked_queue.push_front(unacked_segment); + 0 + } else { + bytes_remaining - data.len() + } + } + + fn update_retransmit_deadline(&mut self, now: Instant) -> Option { + match self.unacked_queue.get_front() { + Some(UnackedSegment { + bytes: _, + initial_tx: Some(initial_tx), + }) => Some(*initial_tx + self.rto_calculator.rto()), + Some(UnackedSegment { + bytes: _, + initial_tx: None, + }) => Some(now + self.rto_calculator.rto()), + None => None, + } + } + + fn update_send_window(&mut self, header: &TcpHeader) { + // Make sure the ack num is bigger than the last one that we used to update the send window. + if self.send_window_last_update_seq < header.seq_num + || (self.send_window_last_update_seq == header.seq_num + && self.send_window_last_update_ack <= header.ack_num) + { + self.send_window + .set((header.window_size as u32) << self.send_window_scale_shift_bits); + self.send_window_last_update_seq = header.seq_num; + self.send_window_last_update_ack = header.ack_num; + + debug!( + "Updating window size -> {} (hdr {}, scale {})", + self.send_window.get(), + header.window_size, + self.send_window_scale_shift_bits, + ); + } + } + // This function sends a packet and waits for it to be acked. - pub async fn push(&mut self, mut buf: DemiBuffer, mut cb: SharedControlBlock) -> Result<(), Fail> { + pub async fn push( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + runtime: &mut SharedDemiRuntime, + mut buf: DemiBuffer, + ) -> Result<(), Fail> { // If the user is done sending (i.e. 
has called close on this connection), then they shouldn't be sending. - debug_assert!(self.fin_seq_no.is_none()); + debug_assert!(cb.sender.fin_seq_no.is_none()); // Our API supports send buffers up to usize (variable, depends upon architecture) in size. While we could // allow for larger send buffers, it is simpler and more practical to limit a single send to 1 GiB, which is // also the maximum value a TCP can advertise as its receive window (with maximum window scaling). @@ -147,22 +231,22 @@ impl Sender { .map_err(|_| Fail::new(EINVAL, "buffer too large"))?; // TODO: We need to fix this the correct way: limit our send buffer size to the amount we're willing to buffer. - if self.unsent_queue.len() > UNSENT_QUEUE_CUTOFF { + if cb.sender.unsent_queue.len() > UNSENT_QUEUE_CUTOFF { return Err(Fail::new(EBUSY, "too many packets to send")); } // Place the buffer in the unsent queue. - self.unsent_next_seq_no = self.unsent_next_seq_no + (buf.len() as u32).into(); - if self.send_window.get() > 0 { - self.send_segment(&mut buf, &mut cb); + cb.sender.unsent_next_seq_no = cb.sender.unsent_next_seq_no + (buf.len() as u32).into(); + if cb.sender.send_window.get() > 0 { + Self::send_segment(cb, layer3_endpoint, runtime.get_now(), &mut buf); } if buf.len() > 0 { - self.unsent_queue.push(Some(buf)); + cb.sender.unsent_queue.push(Some(buf)); } // Wait until the sequnce number of the pushed buffer is acknowledged. - let mut send_unacked_watched: SharedAsyncValue = self.send_unacked.clone(); - let ack_seq_no: SeqNumber = self.unsent_next_seq_no; + let mut send_unacked_watched: SharedAsyncValue = cb.sender.send_unacked.clone(); + let ack_seq_no: SeqNumber = cb.sender.unsent_next_seq_no; debug_assert!(send_unacked_watched.get() < ack_seq_no); while send_unacked_watched.get() < ack_seq_no { send_unacked_watched.wait_for_change(None).await?; @@ -171,70 +255,77 @@ impl Sender { } // Places a FIN marker in the outgoing data stream. No data can be pushed after this. - pub async fn push_fin_and_wait_for_ack(&mut self) -> Result<(), Fail> { - debug_assert!(self.fin_seq_no.is_none()); + pub async fn push_fin_and_wait_for_ack(cb: &mut ControlBlock) -> Result<(), Fail> { + debug_assert!(cb.sender.fin_seq_no.is_none()); // TODO: We need to fix this the correct way: limit our send buffer size to the amount we're willing to buffer. - if self.unsent_queue.len() > UNSENT_QUEUE_CUTOFF { + if cb.sender.unsent_queue.len() > UNSENT_QUEUE_CUTOFF { return Err(Fail::new(EBUSY, "too many packets to send")); } - self.fin_seq_no = Some(self.unsent_next_seq_no); - self.unsent_next_seq_no = self.unsent_next_seq_no + 1.into(); - self.unsent_queue.push(None); + cb.sender.fin_seq_no = Some(cb.sender.unsent_next_seq_no); + cb.sender.unsent_next_seq_no = cb.sender.unsent_next_seq_no + 1.into(); + cb.sender.unsent_queue.push(None); - let mut send_unacked_watched: SharedAsyncValue = self.send_unacked.clone(); - let fin_ack_num: SeqNumber = self.unsent_next_seq_no; - while self.send_unacked.get() < fin_ack_num { + let mut send_unacked_watched: SharedAsyncValue = cb.sender.send_unacked.clone(); + let fin_ack_num: SeqNumber = cb.sender.unsent_next_seq_no; + while cb.sender.send_unacked.get() < fin_ack_num { send_unacked_watched.wait_for_change(None).await?; } Ok(()) } - pub async fn background_sender(&mut self, mut cb: SharedControlBlock) -> Result { + pub async fn background_sender( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + runtime: &mut SharedDemiRuntime, + ) -> Result { loop { // Get next bit of unsent data. 
- if let Some(buf) = self.unsent_queue.pop(None).await? { - self.send_buffer(buf, &mut cb).await?; + if let Some(buf) = cb.sender.unsent_queue.pop(None).await? { + Self::send_buffer(cb, layer3_endpoint, runtime.get_now(), buf).await?; } else { - let now: Instant = cb.get_now(); - self.send_fin(&mut cb, now)?; + Self::send_fin(cb, layer3_endpoint, runtime.get_now())?; // Exit the loop because we no longer have anything to process return Err(Fail::new(libc::ECONNRESET, "Processed and sent FIN")); } } } - fn send_fin(&mut self, cb: &mut SharedControlBlock, now: Instant) -> Result<(), Fail> { - let mut header: TcpHeader = cb.tcp_header(); - header.seq_num = self.send_next_seq_no.get(); - debug_assert!(self.fin_seq_no.is_some_and(|s| { s == header.seq_num })); + fn send_fin(cb: &mut ControlBlock, layer3_endpoint: &mut SharedLayer3Endpoint, now: Instant) -> Result<(), Fail> { + let mut header: TcpHeader = Self::tcp_header(cb, None); + debug_assert!(cb.sender.fin_seq_no.is_some_and(|s| { s == header.seq_num })); header.fin = true; - cb.emit(header, None); + Self::emit(cb, layer3_endpoint, header, None); // Update SND.NXT. - self.send_next_seq_no.modify(|s| s + 1.into()); + cb.sender.send_next_seq_no.modify(|s| s + 1.into()); // Add the FIN to our unacknowledged queue. let unacked_segment = UnackedSegment { bytes: None, initial_tx: Some(now), }; - self.unacked_queue.push(unacked_segment); + cb.sender.unacked_queue.push(unacked_segment); // Set the retransmit timer. - if self.retransmit_deadline_time_secs.get().is_none() { - let rto: Duration = self.rto_calculator.rto(); - trace!("set retransmit: {:?}", rto); - self.retransmit_deadline_time_secs.set(Some(now + rto)); + if cb.sender.retransmit_deadline_time_secs.get().is_none() { + let rto: Duration = cb.sender.rto_calculator.rto(); + cb.sender.retransmit_deadline_time_secs.set(Some(now + rto)); } Ok(()) } - async fn send_buffer(&mut self, mut buffer: DemiBuffer, cb: &mut SharedControlBlock) -> Result<(), Fail> { - let mut send_unacked_watched: SharedAsyncValue = self.send_unacked.clone(); - let mut cwnd_watched: SharedAsyncValue = cb.congestion_control_get_cwnd(); + async fn send_buffer( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + now: Instant, + mut buffer: DemiBuffer, + ) -> Result<(), Fail> { + let mut send_unacked_watched: SharedAsyncValue = cb.sender.send_unacked.clone(); + let mut cwnd_watched: SharedAsyncValue = cb.congestion_control_algorithm.get_cwnd(); // The limited transmit algorithm may increase the effective size of cwnd by up to 2 * mss. - let mut ltci_watched: SharedAsyncValue = cb.congestion_control_get_limited_transmit_cwnd_increase(); - let mut win_sz_watched: SharedAsyncValue = self.send_window.clone(); + let mut ltci_watched: SharedAsyncValue = + cb.congestion_control_algorithm.get_limited_transmit_cwnd_increase(); + let mut win_sz_watched: SharedAsyncValue = cb.sender.send_window.clone(); // Try in a loop until we send this segment. loop { @@ -242,13 +333,13 @@ impl Sender { // repeatedly send window probes until window opens up. if win_sz_watched.get() == 0 { // Send a window probe (this is a one-byte packet designed to elicit a window update from our peer). - self.send_window_probe(buffer.split_front(1)?, cb).await?; + Self::send_window_probe(cb, layer3_endpoint, now, buffer.split_front(1)?).await?; } else { // TODO: Nagle's algorithm - We need to coalese small buffers together to send MSS sized packets. // TODO: Silly window syndrome - See RFC 1122's discussion of the SWS avoidance algorithm. 
// We have some window, try to send some or all of the segment. - let _: usize = self.send_segment(&mut buffer, cb); + let _: usize = Self::send_segment(cb, layer3_endpoint, now, &mut buffer); // If the buffer is now empty, then we sent all of it. if buffer.len() == 0 { return Ok(()); @@ -257,7 +348,7 @@ impl Sender { // the segment. futures::select_biased! { _ = send_unacked_watched.wait_for_change(None).fuse() => (), - _ = self.send_next_seq_no.wait_for_change(None).fuse() => (), + _ = cb.sender.send_next_seq_no.wait_for_change(None).fuse() => (), _ = win_sz_watched.wait_for_change(None).fuse() => (), _ = cwnd_watched.wait_for_change(None).fuse() => (), _ = ltci_watched.wait_for_change(None).fuse() => (), @@ -266,44 +357,55 @@ impl Sender { } } - async fn send_window_probe(&mut self, probe: DemiBuffer, cb: &mut SharedControlBlock) -> Result<(), Fail> { + async fn send_window_probe( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + now: Instant, + probe: DemiBuffer, + ) -> Result<(), Fail> { // Update SND.NXT. - self.send_next_seq_no.modify(|s| s + SeqNumber::from(1)); + cb.sender.send_next_seq_no.modify(|s| s + SeqNumber::from(1)); // Add the probe byte (as a new separate buffer) to our unacknowledged queue. let unacked_segment = UnackedSegment { bytes: Some(probe.clone()), - initial_tx: Some(cb.get_now()), + initial_tx: Some(now), }; - self.unacked_queue.push(unacked_segment); + cb.sender.unacked_queue.push(unacked_segment); // Note that we loop here *forever*, exponentially backing off. // TODO: Use the correct PERSIST mode timer here. let mut timeout: Duration = Duration::from_secs(1); - let mut win_sz_watched: SharedAsyncValue = self.send_window.clone(); + let mut win_sz_watched: SharedAsyncValue = cb.sender.send_window.clone(); loop { // Create packet. - let mut header: TcpHeader = cb.tcp_header(); - header.seq_num = self.send_next_seq_no.get(); - cb.emit(header, Some(probe.clone())); + let header: TcpHeader = Self::tcp_header(cb, None); + Self::emit(cb, layer3_endpoint, header, Some(probe.clone())); match win_sz_watched.wait_for_change(Some(timeout)).await { Ok(_) => return Ok(()), Err(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => timeout *= 2, - Err(_) => unreachable!( - "either the ack deadline changed or the deadline passed, no other errors are possible!" - ), + Err(_) => { + unreachable!( + "either the ack deadline changed or the deadline passed, no other errors are possible!" + ) + }, } } } // Takes a segment and attempts to send it. The buffer must be non-zero length and the function returns the number // of bytes sent. - fn send_segment(&mut self, segment: &mut DemiBuffer, cb: &mut SharedControlBlock) -> usize { + fn send_segment( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + now: Instant, + segment: &mut DemiBuffer, + ) -> usize { let buf_len: usize = segment.len(); debug_assert_ne!(buf_len, 0); // Check window size. 
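// The check below asks how many bytes the peer's advertised window and our congestion window
// still allow us to put on the wire. In sketch form (illustrative helper, not the crate's
// calculate_open_window_bytes, which also takes the MSS as an argument):
fn open_window_bytes(send_window: u32, cwnd: u32, ltci: u32, bytes_in_flight: u32) -> u32 {
    // Limited transmit (RFC 3042) may inflate the effective congestion window by up to 2 * MSS.
    let effective_cwnd: u32 = cwnd.saturating_add(ltci);
    // At most min(SND.WND, effective cwnd) bytes may be outstanding at once.
    send_window.min(effective_cwnd).saturating_sub(bytes_in_flight)
}
// Example: a 65535-byte peer window, a cwnd of 14600 (10 * 1460) and 8192 bytes already in
// flight leaves min(65535, 14600) - 8192 = 6408 bytes that can be sent before waiting for an
// ACK, a window update, or cwnd growth.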
- let max_frame_size_bytes: usize = match self.get_open_window_size_bytes(cb) { + let max_frame_size_bytes: usize = match Self::get_open_window_size_bytes(cb) { 0 => return 0, size => size, }; @@ -325,52 +427,75 @@ impl Sender { let segment_data_len: u32 = segment_data.len() as u32; - let rto: Duration = self.rto_calculator.rto(); - cb.congestion_control_on_send(rto, (self.send_next_seq_no.get() - self.send_unacked.get()).into()); + let rto: Duration = cb.sender.rto_calculator.rto(); + cb.congestion_control_algorithm.on_send( + rto, + (cb.sender.send_next_seq_no.get() - cb.sender.send_unacked.get()).into(), + ); // Prepare the segment and send it. - let mut header: TcpHeader = cb.tcp_header(); - header.seq_num = self.send_next_seq_no.get(); + let mut header: TcpHeader = Self::tcp_header(cb, None); if do_push { header.psh = true; } - cb.emit(header, Some(segment_data.clone())); + Self::emit(cb, layer3_endpoint, header, Some(segment_data.clone())); // Update SND.NXT. - self.send_next_seq_no.modify(|s| s + SeqNumber::from(segment_data_len)); + cb.sender + .send_next_seq_no + .modify(|s| s + SeqNumber::from(segment_data_len)); // Put this segment on the unacknowledged list. let unacked_segment = UnackedSegment { bytes: Some(segment_data), - initial_tx: Some(cb.get_now()), + initial_tx: Some(now), }; - self.unacked_queue.push(unacked_segment); + cb.sender.unacked_queue.push(unacked_segment); // Set the retransmit timer. - if self.retransmit_deadline_time_secs.get().is_none() { - let rto: Duration = self.rto_calculator.rto(); - self.retransmit_deadline_time_secs.set(Some(cb.get_now() + rto)); + if cb.sender.retransmit_deadline_time_secs.get().is_none() { + let rto: Duration = cb.sender.rto_calculator.rto(); + cb.sender.retransmit_deadline_time_secs.set(Some(now + rto)); } segment_data_len as usize } - fn get_open_window_size_bytes(&mut self, cb: &mut SharedControlBlock) -> usize { + /// Fetch a TCP header filling out various values based on our current state. + /// If a sequence number is provided, use it otherwise, use the current unsent sequence number. + /// The only time that the unsent sequence number is not used is when we are retransmitting. + pub fn tcp_header(cb: &mut ControlBlock, seq_num: Option) -> TcpHeader { + let mut header: TcpHeader = TcpHeader::new(cb.local.port(), cb.remote.port()); + header.window_size = cb.receiver.hdr_window_size(); + + // Note that once we reach a synchronized state we always include a valid acknowledgement number. + header.ack = true; + header.ack_num = cb.receiver.receive_next_seq_no; + header.seq_num = seq_num.unwrap_or(cb.sender.send_next_seq_no.get()); + + // Return this header. + header + } + + fn get_open_window_size_bytes(cb: &mut ControlBlock) -> usize { // Calculate amount of data in flight (SND.NXT - SND.UNA). - let send_unacknowledged: SeqNumber = self.send_unacked.get(); - let send_next: SeqNumber = self.send_next_seq_no.get(); + let send_unacknowledged: SeqNumber = cb.sender.send_unacked.get(); + let send_next: SeqNumber = cb.sender.send_next_seq_no.get(); let sent_data: u32 = (send_next - send_unacknowledged).into(); // Before we get cwnd for the check, we prompt it to shrink it if the connection has been idle. 
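// The "prompt it to shrink" comment above refers to restarting the congestion window after an
// idle period (RFC 5681, Section 4.1: if nothing was sent for more than one RTO, cwnd should be
// reduced to no more than the restart window). A sketch under that assumption, with illustrative
// names; the crate hides this decision behind the congestion-control trait's
// on_cwnd_check_before_send():
use std::time::{Duration, Instant};

fn cwnd_before_send(cwnd: u32, initial_window: u32, last_send: Instant, now: Instant, rto: Duration) -> u32 {
    if now.duration_since(last_send) > rto {
        // Restart window: no more than min(initial window, current cwnd).
        cwnd.min(initial_window)
    } else {
        cwnd
    }
}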
- cb.congestion_control_on_cwnd_check_before_send(); - let cwnd: SharedAsyncValue = cb.congestion_control_get_cwnd(); + cb.congestion_control_algorithm.on_cwnd_check_before_send(); + let cwnd: SharedAsyncValue = cb.congestion_control_algorithm.get_cwnd(); // The limited transmit algorithm can increase the effective size of cwnd by up to 2MSS. - let effective_cwnd: u32 = cwnd.get() + cb.congestion_control_get_limited_transmit_cwnd_increase().get(); + let effective_cwnd: u32 = cwnd.get() + + cb.congestion_control_algorithm + .get_limited_transmit_cwnd_increase() + .get(); - let win_sz: u32 = self.send_window.get(); + let win_sz: u32 = cb.sender.send_window.get(); if Self::has_open_window(win_sz, sent_data, effective_cwnd) { - Self::calculate_open_window_bytes(win_sz, sent_data, self.mss, effective_cwnd) + Self::calculate_open_window_bytes(win_sz, sent_data, cb.sender.mss, effective_cwnd) } else { 0 } @@ -387,20 +512,26 @@ impl Sender { ) } - pub async fn background_retransmitter(&mut self, mut cb: SharedControlBlock) -> Result { + pub async fn background_retransmitter( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + runtime: &mut SharedDemiRuntime, + ) -> Result { // Watch the retransmission deadline. - let mut rtx_deadline_watched: SharedAsyncValue> = self.retransmit_deadline_time_secs.clone(); + let mut rtx_deadline_watched: SharedAsyncValue> = + cb.sender.retransmit_deadline_time_secs.clone(); // Watch the fast retransmit flag. - let mut rtx_fast_retransmit_watched: SharedAsyncValue = cb.congestion_control_watch_retransmit_now_flag(); + let mut rtx_fast_retransmit_watched: SharedAsyncValue = + cb.congestion_control_algorithm.get_retransmit_now_flag(); loop { let rtx_deadline: Option = rtx_deadline_watched.get(); let rtx_fast_retransmit: bool = rtx_fast_retransmit_watched.get(); if rtx_fast_retransmit { // Notify congestion control about fast retransmit. - cb.congestion_control_on_fast_retransmit(); + cb.congestion_control_algorithm.on_fast_retransmit(); // Retransmit earliest unacknowledged segment. - self.retransmit(&mut cb); + Self::retransmit(cb, layer3_endpoint); continue; } @@ -413,29 +544,26 @@ impl Sender { }; pin_mut!(something_changed); match conditional_yield_until(something_changed, rtx_deadline).await { - Ok(()) => match self.fin_seq_no { - Some(fin_seq_no) if self.send_unacked.get() > fin_seq_no => { + Ok(()) => match cb.sender.fin_seq_no { + Some(fin_seq_no) if cb.sender.send_unacked.get() > fin_seq_no => { return Err(Fail::new(libc::ECONNRESET, "connection closed")); }, _ => continue, }, Err(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => { // Retransmit timeout. - trace!("retransmit wake"); // Notify congestion control about RTO. - // TODO: Is this the best place for this? - // TODO: Why call into ControlBlock to get SND.UNA when congestion_control_on_rto() has access to it? - cb.congestion_control_on_rto(self.send_unacked.get()); + cb.congestion_control_algorithm.on_rto(cb.sender.send_unacked.get()); // RFC 6298 Section 5.4: Retransmit earliest unacknowledged segment. - self.retransmit(&mut cb); + Self::retransmit(cb, layer3_endpoint); // RFC 6298 Section 5.5: Back off the retransmission timer. - self.rto_calculator.back_off(); + cb.sender.rto_calculator.back_off(); // RFC 6298 Section 5.6: Restart the retransmission timer with the new RTO. 
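// RFC 6298, Sections 5.5-5.6: on a retransmission timeout the RTO is doubled (an upper bound may
// be applied, provided it is at least 60 seconds) and the timer is restarted with the new value.
// A minimal sketch of that state machine; the crate's rto_calculator also maintains the
// SRTT/RTTVAR estimators from Section 2, which are omitted here:
use std::time::{Duration, Instant};

struct Rto {
    current: Duration,
}

impl Rto {
    const UPPER_BOUND: Duration = Duration::from_secs(60);

    // Section 5.5: back off the timer by doubling the RTO.
    fn back_off(&mut self) {
        self.current = (self.current * 2).min(Self::UPPER_BOUND);
    }

    // Section 5.6: restart the retransmission timer with the (possibly backed-off) RTO.
    fn next_deadline(&self, now: Instant) -> Instant {
        now + self.current
    }
}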
- let deadline: Instant = cb.get_now() + self.rto_calculator.rto(); - self.retransmit_deadline_time_secs.set(Some(deadline)); + let deadline: Instant = runtime.get_now() + cb.sender.rto_calculator.rto(); + cb.sender.retransmit_deadline_time_secs.set(Some(deadline)); }, Err(_) => { unreachable!( @@ -447,8 +575,8 @@ impl Sender { } /// Retransmits the earliest segment that has not (yet) been acknowledged by our peer. - pub fn retransmit(&mut self, cb: &mut SharedControlBlock) { - match self.unacked_queue.get_front_mut() { + pub fn retransmit(cb: &mut ControlBlock, layer3_endpoint: &mut SharedLayer3Endpoint) { + match cb.sender.unacked_queue.get_front_mut() { Some(segment) => { // We're retransmitting this, so we can no longer use an ACK for it as an RTT measurement (as we can't // tell if the ACK is for the original or the retransmission). Remove the transmission timestamp from @@ -461,39 +589,39 @@ impl Sender { // TODO: Issue #198 Repacketization - we should send a full MSS (and set the FIN flag if applicable). // Prepare and send the segment. - let mut header: TcpHeader = cb.tcp_header(); - header.seq_num = self.send_unacked.get(); + let mut header: TcpHeader = Self::tcp_header(cb, Some(cb.sender.send_unacked.get())); // If data exists, then this is a regular packet, otherwise, its a FIN. if data.is_some() { header.psh = true; } else { header.fin = true; } - cb.emit(header, data); + Self::emit(cb, layer3_endpoint, header, data); }, None => (), } } // Process an ack. - pub fn process_ack(&mut self, header: &TcpHeader, now: Instant) { + pub fn process_ack(cb: &mut ControlBlock, header: &TcpHeader, now: Instant) { // Start by checking that the ACK acknowledges something new. - // TODO: Look into removing Watched types. - let send_unacknowledged: SeqNumber = self.send_unacked.get(); + let send_unacknowledged: SeqNumber = cb.sender.send_unacked.get(); if send_unacknowledged < header.ack_num { // Remove the now acknowledged data from the unacknowledged queue, update the acked sequence number // and update the sender window. // Convert the difference in sequence numbers into a u32. - let bytes_acknowledged: u32 = (header.ack_num - self.send_unacked.get()).into(); + let bytes_acknowledged: u32 = (header.ack_num - cb.sender.send_unacked.get()).into(); // Convert that into a usize for counting bytes to remove from the unacked queue. let mut bytes_remaining: usize = bytes_acknowledged as usize; // Remove bytes from the unacked queue. while bytes_remaining != 0 { - bytes_remaining = match self.unacked_queue.try_pop() { - Some(segment) if segment.bytes.is_none() => self.process_acked_fin(bytes_remaining, header.ack_num), - Some(segment) => self.process_acked_segment(bytes_remaining, segment, now), + bytes_remaining = match cb.sender.unacked_queue.try_pop() { + Some(segment) if segment.bytes.is_none() => { + cb.sender.process_acked_fin(bytes_remaining, header.ack_num) + }, + Some(segment) => cb.sender.process_acked_segment(bytes_remaining, segment, now), None => { unreachable!("There should be enough data in the unacked_queue for the number of bytes acked") }, // Shouldn't have bytes_remaining with no segments remaining in unacked_queue. @@ -501,19 +629,21 @@ impl Sender { } // Update SND.UNA to SEG.ACK. - self.send_unacked.set(header.ack_num); + cb.sender.send_unacked.set(header.ack_num); // Check and update send window if necessary. - self.update_send_window(header); + cb.sender.update_send_window(header); // Reset the retransmit timer if necessary. 
If there is more data that hasn't been acked, then set to the // next segment deadline, otherwise, do not set. - let retransmit_deadline_time_secs: Option = self.update_retransmit_deadline(now); + let retransmit_deadline_time_secs: Option = cb.sender.update_retransmit_deadline(now); #[cfg(debug_assertions)] if retransmit_deadline_time_secs.is_none() { - debug_assert_eq!(self.send_next_seq_no.get(), header.ack_num); + debug_assert_eq!(cb.sender.send_next_seq_no.get(), header.ack_num); } - self.retransmit_deadline_time_secs.set(retransmit_deadline_time_secs); + cb.sender + .retransmit_deadline_time_secs + .set(retransmit_deadline_time_secs); } else { // Duplicate ACK (doesn't acknowledge anything new). We can mostly ignore this, except for fast-retransmit. // TODO: Implement fast-retransmit. In which case, we'd increment our dup-ack counter here. @@ -521,100 +651,56 @@ impl Sender { } } - fn process_acked_fin(&mut self, bytes_remaining: usize, ack_num: SeqNumber) -> usize { - // This buffer is the end-of-send marker. So we should only have one byte of acknowledged - // sequence space remaining (corresponding to our FIN). - debug_assert_eq!(bytes_remaining, 1); - - // Double check that the ack is for the FIN sequence number. - debug_assert_eq!( - ack_num, - self.fin_seq_no - .map(|s| { s + 1.into() }) - .expect("should have a FIN set") - ); - 0 + /// Send an ACK to our peer, reflecting our current state. + pub fn send_ack(cb: &mut ControlBlock, layer3_endpoint: &mut SharedLayer3Endpoint) { + trace!("sending ack"); + let header: TcpHeader = Self::tcp_header(cb, None); + Self::emit(cb, layer3_endpoint, header, None); } - fn process_acked_segment(&mut self, bytes_remaining: usize, mut segment: UnackedSegment, now: Instant) -> usize { - // Add sample for RTO if we have an initial transmit time. - // Note that in the case of repacketization, an ack for the first byte is enough for the time sample because it still represents the RTO for that single byte. - // TODO: TCP timestamp support. - if let Some(initial_tx) = segment.initial_tx { - self.rto_calculator.add_sample(now - initial_tx); - } - - let mut data: DemiBuffer = segment - .bytes - .take() - .expect("there should be data because this is not a FIN."); - if data.len() > bytes_remaining { - // Put this segment on the unacknowledged list. - let unacked_segment = UnackedSegment { - bytes: Some( - data.split_back(bytes_remaining) - .expect("Should be able to split back because we just checked the length"), - ), - initial_tx: None, - }; - // Leave this segment on the unacknowledged queue. - self.unacked_queue.push_front(unacked_segment); - 0 - } else { - bytes_remaining - data.len() - } - } + /// Transmit this message to our connected peer. + pub fn emit( + cb: &mut ControlBlock, + layer3_endpoint: &mut SharedLayer3Endpoint, + header: TcpHeader, + body: Option, + ) { + // Only perform this debug print in debug builds. debug_assertions is compiler set in non-optimized builds. 
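// Regarding the comment above: the `debug_assertions` cfg is enabled by default in unoptimized
// (dev) builds and disabled in release builds. If the `debug!` call below is the standard `log`
// macro, it is compiled into every build and filtered at runtime unless the log crate's
// `release_max_level_*` features are set; whether this crate relies on that or on its own gated
// macro is not something the diff shows. A sketch of an explicit debug-build gate, should one be
// wanted (illustrative only):
fn log_outgoing_segment(body_len: usize) {
    // `cfg!(debug_assertions)` folds to `false` in release builds, so the optimizer drops the
    // branch entirely.
    if cfg!(debug_assertions) {
        println!("Sending {} bytes", body_len);
    }
}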
+ let mut pkt = match body { + Some(body) => { + debug!("Sending {} bytes + {:?}", body.len(), header); + body + }, + _ => { + debug!("Sending 0 bytes + {:?}", header); + DemiBuffer::new_with_headroom(0, MAX_HEADER_SIZE as u16) + }, + }; - fn update_retransmit_deadline(&self, now: Instant) -> Option { - match self.unacked_queue.get_front() { - Some(UnackedSegment { - bytes: _, - initial_tx: Some(initial_tx), - }) => Some(*initial_tx + self.rto_calculator.rto()), - Some(UnackedSegment { - bytes: _, - initial_tx: None, - }) => Some(now + self.rto_calculator.rto()), - None => None, - } - } + // This routine should only ever be called to send TCP segments that contain a valid ACK value. + debug_assert!(header.ack); - fn update_send_window(&mut self, header: &TcpHeader) { - // Make sure the ack num is bigger than the last one that we used to update the send window. - if self.send_window_last_update_seq < header.seq_num - || (self.send_window_last_update_seq == header.seq_num - && self.send_window_last_update_ack <= header.ack_num) - { - self.send_window - .set((header.window_size as u32) << self.send_window_scale_shift_bits); - self.send_window_last_update_seq = header.seq_num; - self.send_window_last_update_ack = header.ack_num; + let remote_ipv4_addr: Ipv4Addr = cb.remote.ip().clone(); + header.serialize_and_attach( + &mut pkt, + cb.local.ip(), + cb.remote.ip(), + cb.tcp_config.get_tx_checksum_offload(), + ); - debug!( - "Updating window size -> {} (hdr {}, scale {})", - self.send_window.get(), - header.window_size, - self.send_window_scale_shift_bits, - ); + // Call lower L3 layer to send the segment. + if let Err(e) = layer3_endpoint.transmit_tcp_packet_nonblocking(remote_ipv4_addr, pkt) { + warn!("could not emit packet: {:?}", e); + return; } - } - // Get SD.UNA. - pub fn get_unacked_seq_no(&self) -> SeqNumber { - self.send_unacked.get() - } + // Post-send operations follow. + // Review: We perform these after the send, in order to keep send latency as low as possible. - // Get SND.NXT. - pub fn get_next_seq_no(&self) -> SeqNumber { - self.send_next_seq_no.get() - } - - // Get the current estimate of RTO. - pub fn get_rto(&self) -> Duration { - self.rto_calculator.rto() + // Since we sent an ACK, cancel any outstanding delayed ACK request. + cb.receiver.ack_deadline_time_secs.set(None); } } - //====================================================================================================================== // Trait Implementations //====================================================================================================================== diff --git a/src/rust/inetstack/protocols/layer4/tcp/passive_open.rs b/src/rust/inetstack/protocols/layer4/tcp/passive_open.rs index 31cf886cb..f63fdcdcf 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/passive_open.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/passive_open.rs @@ -17,7 +17,7 @@ use crate::{ constants::FALLBACK_MSS, established::{ congestion_control::{self, CongestionControl}, - EstablishedSocket, + SharedEstablishedSocket, }, header::{TcpHeader, TcpOptions2}, isn_generator::IsnGenerator, @@ -59,7 +59,7 @@ pub struct PassiveSocket { // TCP Connection State. 
state: SharedAsyncValue, connections: HashMap>, - ready: AsyncQueue<(SocketAddrV4, Result)>, + ready: AsyncQueue<(SocketAddrV4, Result)>, max_backlog: usize, isn_generator: IsnGenerator, local: SocketAddrV4, @@ -90,7 +90,7 @@ impl SharedPassiveSocket { Ok(Self(SharedObject::::new(PassiveSocket { state: SharedAsyncValue::new(State::Listening), connections: HashMap::>::new(), - ready: AsyncQueue::<(SocketAddrV4, Result)>::default(), + ready: AsyncQueue::<(SocketAddrV4, Result)>::default(), max_backlog, isn_generator: IsnGenerator::new(nonce), local, @@ -107,7 +107,7 @@ impl SharedPassiveSocket { } /// Accept a new connection by fetching one from the queue of requests, blocking if there are no new requests. - pub async fn do_accept(&mut self) -> Result { + pub async fn do_accept(&mut self) -> Result { let (_, new_socket) = self.ready.pop(None).await?; new_socket } @@ -131,7 +131,7 @@ impl SharedPassiveSocket { // See if this packet is for an already established but not accepted socket. if let Some((_, socket)) = self.ready.get_values().find(|(addr, _)| *addr == remote) { if let Ok(socket) = socket { - socket.get_cb().receive(tcp_hdr, buf); + socket.clone().receive(tcp_hdr, buf); } return; } @@ -366,7 +366,7 @@ impl SharedPassiveSocket { header_window_size: u16, remote_window_scale: Option, mss: usize, - ) -> Result { + ) -> Result { let (ipv4_hdr, tcp_hdr, buf) = recv_queue.pop(None).await?; debug!("Received ACK: {:?}", tcp_hdr); @@ -418,7 +418,7 @@ impl SharedPassiveSocket { recv_queue.push((ipv4_hdr, tcp_hdr, buf)); } - let new_socket: EstablishedSocket = EstablishedSocket::new( + let new_socket: SharedEstablishedSocket = SharedEstablishedSocket::new( self.local, remote, self.runtime.clone(), @@ -441,7 +441,7 @@ impl SharedPassiveSocket { Ok(new_socket) } - fn complete_handshake(&mut self, remote: SocketAddrV4, result: Result) { + fn complete_handshake(&mut self, remote: SocketAddrV4, result: Result) { self.connections.remove(&remote); self.ready.push((remote, result)); } diff --git a/src/rust/inetstack/protocols/layer4/tcp/socket.rs b/src/rust/inetstack/protocols/layer4/tcp/socket.rs index bff256db2..a4c495265 100644 --- a/src/rust/inetstack/protocols/layer4/tcp/socket.rs +++ b/src/rust/inetstack/protocols/layer4/tcp/socket.rs @@ -10,7 +10,7 @@ use crate::{ inetstack::protocols::{ layer3::SharedLayer3Endpoint, layer4::tcp::{ - active_open::SharedActiveOpenSocket, established::EstablishedSocket, header::TcpHeader, + active_open::SharedActiveOpenSocket, established::SharedEstablishedSocket, header::TcpHeader, passive_open::SharedPassiveSocket, SeqNumber, }, }, @@ -42,8 +42,8 @@ pub enum SocketState { Bound(SocketAddrV4), Listening(SharedPassiveSocket), Connecting(SharedActiveOpenSocket), - Established(EstablishedSocket), - Closing(EstablishedSocket), + Established(SharedEstablishedSocket), + Closing(SharedEstablishedSocket), } //====================================================================================================================== @@ -83,7 +83,7 @@ impl SharedTcpSocket { } pub fn new_established( - socket: EstablishedSocket, + socket: SharedEstablishedSocket, runtime: SharedDemiRuntime, layer3_endpoint: SharedLayer3Endpoint, tcp_config: TcpConfig, @@ -163,7 +163,7 @@ impl SharedTcpSocket { SocketState::Listening(ref listening_socket) => listening_socket.clone(), _ => unreachable!("State machine check should ensure that this socket is listening"), }; - let new_socket: EstablishedSocket = listening_socket.do_accept().await?; + let new_socket: SharedEstablishedSocket = 
listening_socket.do_accept().await?; // Insert queue into queue table and get new queue descriptor. let new_queue = Self::new_established( new_socket, @@ -289,8 +289,8 @@ impl SharedTcpSocket { }, SocketState::Listening(ref mut socket) => socket.receive(ip_hdr, tcp_hdr, buf), SocketState::Connecting(ref mut socket) => socket.receive(ip_hdr, tcp_hdr, buf), - SocketState::Established(ref socket) => socket.get_cb().receive(tcp_hdr, buf), - SocketState::Closing(ref socket) => socket.get_cb().receive(tcp_hdr, buf), + SocketState::Established(ref mut socket) => socket.receive(tcp_hdr, buf), + SocketState::Closing(ref mut socket) => socket.receive(tcp_hdr, buf), } }
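// End-to-end shape of the accept path touched above: the passive socket finishes handshakes in
// the background and pushes each outcome onto a ready queue, and accept() simply pops the next
// one. A self-contained sketch with std::sync::mpsc standing in for the crate's AsyncQueue and a
// placeholder type standing in for SharedEstablishedSocket (illustrative only):
use std::net::SocketAddrV4;
use std::sync::mpsc::{Receiver, Sender};

type Established = u32; // placeholder for SharedEstablishedSocket
type Ready = (SocketAddrV4, Result<Established, String>);

// Handshake side: record the outcome for a given remote, successful or not.
fn complete_handshake(ready: &Sender<Ready>, remote: SocketAddrV4, result: Result<Established, String>) {
    let _ = ready.send((remote, result));
}

// accept() side: block for the next completed connection and surface any handshake error.
fn do_accept(ready: &Receiver<Ready>) -> Result<Established, String> {
    let (_remote, result) = ready.recv().map_err(|_| "listening socket closed".to_string())?;
    result
}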