Skip to content

Commit

Permalink
server: introduced reject packet sender
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaobo Liu <cppcoffee@gmail.com>
  • Loading branch information
cppcoffee committed Jan 15, 2024
1 parent 01f89d0 commit c934805
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 111 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ crypto = { workspace = true }
# data struct
#dashmap = "5.5"
hashbrown = "0.14"
arc-swap = "1.6"

# logger
log = { version = "0.4" }
Expand Down
16 changes: 10 additions & 6 deletions server/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use anyhow::Result;
use clap::Parser;
use signal_hook::consts::TERM_SIGNALS;
use signal_hook::consts::{SIGHUP, SIGINT, SIGQUIT, SIGTERM};
use signal_hook::iterator::Signals;
use tikv_jemallocator::Jemalloc;
use tracing::{debug, info};
Expand Down Expand Up @@ -44,19 +44,22 @@ fn main() -> Result<()> {
util::set_process_priority(config.setting.worker_priority);
util::set_rlimit_nofile(config.setting.worker_rlimit_nofile)?;

let mut workers = Vec::new();
let conntrack_map = Arc::new(ConntrackMap::new());

let queue_count = config.setting.queue_start + config.setting.queue_count;
for queue_num in config.setting.queue_start..=queue_count {
let worker = Worker::new(config.clone(), queue_num, conntrack_map.clone())?;
worker.start()?;
let w = Worker::new(config.clone(), queue_num, conntrack_map.clone())?;
w.start()?;

workers.push(w);
}

ConntrackReclaim::new(config.clone(), conntrack_map.clone()).start();

iptables::rules_create(&config)?;

wait_for_signal()?;
wait_for_signal(&args, &workers)?;

iptables::rules_destroy(&config)?;

Expand All @@ -65,8 +68,9 @@ fn main() -> Result<()> {
Ok(())
}

fn wait_for_signal() -> Result<()> {
let sigs = TERM_SIGNALS;
fn wait_for_signal(args: &Args, workers: &[Worker]) -> Result<()> {
let sigs = vec![SIGTERM, SIGQUIT, SIGINT];

let mut signals = Signals::new(sigs)?;

for signal in &mut signals {
Expand Down
2 changes: 1 addition & 1 deletion server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ pub mod worker;

pub use config::{Config, Protocol};
pub use conntrack::{ConntrackEntry, ConntrackMap, ConntrackReclaim};
pub use reject::Sender;
pub use reject::RejectPacketSender;
pub use worker::Worker;
194 changes: 141 additions & 53 deletions server/src/reject.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::mem::MaybeUninit;
use std::net::IpAddr;
use std::ptr;
use std::sync::mpsc;
use std::{ptr, thread};

use anyhow::{bail, Result};
use anyhow::{anyhow, bail, Result};
use log::error;
use pnet::packet::icmp::MutableIcmpPacket;
use pnet::packet::icmpv6::MutableIcmpv6Packet;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::tcp::TcpPacket;
use pnet::packet::udp::UdpPacket;
use pnet::packet::tcp::{MutableTcpPacket, TcpPacket};
use pnet::transport::{
transport_channel, TransportChannelType, TransportProtocol, TransportSender,
};
Expand All @@ -17,52 +20,25 @@ const UDP_CHECKSUM_OFFSET: usize = 6;

const BUFFER_SIZE: usize = 128;

pub struct Sender {
icmp: TransportSender,
tcp: TransportSender,
icmpv6: TransportSender,
tcp6: TransportSender,
#[derive(Clone)]
pub struct RejectPacketSender {
inner: mpsc::Sender<Message>,
}

impl Sender {
impl RejectPacketSender {
pub fn new() -> Result<Self> {
let (icmp, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Icmp)),
)?;
let tx = Sender::new()?.start();

let (tcp, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Tcp)),
)?;

let (icmpv6, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Icmpv6)),
)?;

let (tcp6, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Tcp)),
)?;

Ok(Self {
icmp,
tcp,
icmpv6,
tcp6,
})
Ok(Self { inner: tx })
}

pub fn emit_icmp_unreachable(
&mut self,
&self,
source: &IpAddr,
destination: &IpAddr,
ip_packet: &[u8],
udp_header: &UdpPacket,
udp_packet_header: &[u8],
) -> Result<()> {
let udp_packet_header = util::packet_header(udp_header);

let length = ICMP_UNREACHABLE_HEADER_SIZE + ip_packet.len() + udp_packet_header.len();
if length >= BUFFER_SIZE {
bail!("Packet too large")
Expand All @@ -86,12 +62,26 @@ impl Sender {

match (source, destination) {
(IpAddr::V4(_), IpAddr::V4(_)) => {
let icmp_packet = util::build_icmpv4_unreachable(&mut buffer[..length])?;
self.icmp.send_to(icmp_packet, *destination)?;
let mut icmp_packet = MutableIcmpPacket::owned(buffer[..length].to_vec())
.ok_or(anyhow!("Failed to create ICMP packet"))?;

util::build_icmpv4_unreachable(&mut icmp_packet);

self.inner.send(Message::Icmp {
destination: *destination,
icmp_packet,
})?;
}
(IpAddr::V6(src), IpAddr::V6(dest)) => {
let icmp_packet = util::build_icmpv6_unreachable(&mut buffer[..length], src, dest)?;
self.icmpv6.send_to(icmp_packet, *destination)?;
let mut icmp_packet = MutableIcmpv6Packet::owned(buffer[..length].to_vec())
.ok_or(anyhow!("Failed to create ICMP packet"))?;

util::build_icmpv6_unreachable(&mut icmp_packet, src, dest);

self.inner.send(Message::Icmpv6 {
destination: *destination,
icmp_packet,
})?;
}
_ => bail!("IP version mismatch"),
}
Expand All @@ -100,25 +90,123 @@ impl Sender {
}

pub fn emit_tcp_rst(
&mut self,
&self,
destination: &IpAddr,
source: &IpAddr,
tcp_header: &TcpPacket,
) -> Result<()> {
let tcp_min_size = TcpPacket::minimum_packet_size();

let buffer = MaybeUninit::<[u8; BUFFER_SIZE]>::uninit();
let mut buffer = unsafe { ptr::read(buffer.as_ptr() as *const [u8; BUFFER_SIZE]) };
let buffer = Box::new(MaybeUninit::<[u8; BUFFER_SIZE]>::uninit());
let buffer = unsafe { ptr::read(buffer.as_ptr() as *const [u8; BUFFER_SIZE]) };

let tcp_reset_packet =
util::build_tcp_reset(&mut buffer[..tcp_min_size], destination, source, tcp_header)?;
let mut tcp_packet = MutableTcpPacket::owned(buffer[..tcp_min_size].to_vec())
.ok_or(anyhow!("Failed to create TCP packet"))?;

if destination.is_ipv4() {
self.tcp.send_to(tcp_reset_packet, *destination)?;
} else {
self.tcp6.send_to(tcp_reset_packet, *destination)?;
}
util::build_tcp_reset(&mut tcp_packet, destination, source, tcp_header)?;

self.inner.send(Message::Tcp {
destination: *destination,
tcp_packet,
})?;

Ok(())
}
}

enum Message {
Icmp {
destination: IpAddr,
icmp_packet: MutableIcmpPacket<'static>,
},
Icmpv6 {
destination: IpAddr,
icmp_packet: MutableIcmpv6Packet<'static>,
},
Tcp {
destination: IpAddr,
tcp_packet: MutableTcpPacket<'static>,
},
}

struct Sender {
icmp: TransportSender,
tcp: TransportSender,
icmpv6: TransportSender,
tcp6: TransportSender,
}

impl Sender {
pub fn new() -> Result<Self> {
let (icmp, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Icmp)),
)?;

let (tcp, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Tcp)),
)?;

let (icmpv6, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Icmpv6)),
)?;

let (tcp6, _) = transport_channel(
0,
TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Tcp)),
)?;

Ok(Self {
icmp,
icmpv6,
tcp,
tcp6,
})
}

pub fn start(mut self) -> mpsc::Sender<Message> {
let (tx, rx) = mpsc::channel::<Message>();

thread::spawn(move || {
for msg in rx {
match msg {
Message::Icmp {
destination,
icmp_packet,
} => {
if let Err(e) = self.icmp.send_to(icmp_packet, destination) {
error!("Failed to send ICMP packet: {}", e);
}
}
Message::Icmpv6 {
destination,
icmp_packet,
} => {
if let Err(e) = self.icmpv6.send_to(icmp_packet, destination) {
error!("Failed to send ICMPv6 packet: {}", e);
}
}
Message::Tcp {
destination,
tcp_packet,
} => match destination {
IpAddr::V4(_) => {
if let Err(e) = self.tcp.send_to(tcp_packet, destination) {
error!("Failed to send TCP packet: {}", e);
}
}
IpAddr::V6(_) => {
if let Err(e) = self.tcp6.send_to(tcp_packet, destination) {
error!("Failed to send TCP packet: {}", e);
}
}
},
}
}
});

tx
}
}
Loading

0 comments on commit c934805

Please sign in to comment.