diff --git a/Cargo.lock b/Cargo.lock index cbfb693..ff24399 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,12 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "arc-swap" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" + [[package]] name = "autocfg" version = "1.1.0" @@ -728,6 +734,7 @@ name = "server" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "clap", "crypto", "hashbrown", diff --git a/server/Cargo.toml b/server/Cargo.toml index 388c344..aa1e262 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -30,6 +30,7 @@ crypto = { workspace = true } # data struct #dashmap = "5.5" hashbrown = "0.14" +arc-swap = "1.6" # logger log = { version = "0.4" } diff --git a/server/src/bin/server.rs b/server/src/bin/server.rs index 0bd0b80..276342d 100644 --- a/server/src/bin/server.rs +++ b/server/src/bin/server.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use clap::Parser; -use signal_hook::consts::TERM_SIGNALS; +use signal_hook::consts::{SIGHUP, SIGINT, SIGQUIT, SIGTERM}; use signal_hook::iterator::Signals; use tikv_jemallocator::Jemalloc; use tracing::{debug, info}; @@ -44,19 +44,22 @@ fn main() -> Result<()> { util::set_process_priority(config.setting.worker_priority); util::set_rlimit_nofile(config.setting.worker_rlimit_nofile)?; + let mut workers = Vec::new(); let conntrack_map = Arc::new(ConntrackMap::new()); let queue_count = config.setting.queue_start + config.setting.queue_count; for queue_num in config.setting.queue_start..=queue_count { - let worker = Worker::new(config.clone(), queue_num, conntrack_map.clone())?; - worker.start()?; + let w = Worker::new(config.clone(), queue_num, conntrack_map.clone())?; + w.start()?; + + workers.push(w); } ConntrackReclaim::new(config.clone(), conntrack_map.clone()).start(); iptables::rules_create(&config)?; - wait_for_signal()?; + wait_for_signal(&args, &workers)?; iptables::rules_destroy(&config)?; @@ -65,8 +68,9 @@ fn main() -> Result<()> { Ok(()) } -fn wait_for_signal() -> Result<()> { - let sigs = TERM_SIGNALS; +fn wait_for_signal(args: &Args, workers: &[Worker]) -> Result<()> { + let sigs = vec![SIGTERM, SIGQUIT, SIGINT]; + let mut signals = Signals::new(sigs)?; for signal in &mut signals { diff --git a/server/src/lib.rs b/server/src/lib.rs index 6700095..52989da 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -8,5 +8,5 @@ pub mod worker; pub use config::{Config, Protocol}; pub use conntrack::{ConntrackEntry, ConntrackMap, ConntrackReclaim}; -pub use reject::Sender; +pub use reject::RejectPacketSender; pub use worker::Worker; diff --git a/server/src/reject.rs b/server/src/reject.rs index 75417b9..36b6a1f 100644 --- a/server/src/reject.rs +++ b/server/src/reject.rs @@ -1,11 +1,14 @@ use std::mem::MaybeUninit; use std::net::IpAddr; -use std::ptr; +use std::sync::mpsc; +use std::{ptr, thread}; -use anyhow::{bail, Result}; +use anyhow::{anyhow, bail, Result}; +use log::error; +use pnet::packet::icmp::MutableIcmpPacket; +use pnet::packet::icmpv6::MutableIcmpv6Packet; use pnet::packet::ip::IpNextHeaderProtocols; -use pnet::packet::tcp::TcpPacket; -use pnet::packet::udp::UdpPacket; +use pnet::packet::tcp::{MutableTcpPacket, TcpPacket}; use pnet::transport::{ transport_channel, TransportChannelType, TransportProtocol, TransportSender, }; @@ -17,52 +20,25 @@ const UDP_CHECKSUM_OFFSET: usize = 6; const BUFFER_SIZE: usize = 128; -pub struct Sender { - icmp: TransportSender, - tcp: TransportSender, - icmpv6: TransportSender, - tcp6: TransportSender, +#[derive(Clone)] +pub struct RejectPacketSender { + inner: mpsc::Sender, } -impl Sender { +impl RejectPacketSender { pub fn new() -> Result { - let (icmp, _) = transport_channel( - 0, - TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Icmp)), - )?; + let tx = Sender::new()?.start(); - let (tcp, _) = transport_channel( - 0, - TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Tcp)), - )?; - - let (icmpv6, _) = transport_channel( - 0, - TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Icmpv6)), - )?; - - let (tcp6, _) = transport_channel( - 0, - TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Tcp)), - )?; - - Ok(Self { - icmp, - tcp, - icmpv6, - tcp6, - }) + Ok(Self { inner: tx }) } pub fn emit_icmp_unreachable( - &mut self, + &self, source: &IpAddr, destination: &IpAddr, ip_packet: &[u8], - udp_header: &UdpPacket, + udp_packet_header: &[u8], ) -> Result<()> { - let udp_packet_header = util::packet_header(udp_header); - let length = ICMP_UNREACHABLE_HEADER_SIZE + ip_packet.len() + udp_packet_header.len(); if length >= BUFFER_SIZE { bail!("Packet too large") @@ -86,12 +62,26 @@ impl Sender { match (source, destination) { (IpAddr::V4(_), IpAddr::V4(_)) => { - let icmp_packet = util::build_icmpv4_unreachable(&mut buffer[..length])?; - self.icmp.send_to(icmp_packet, *destination)?; + let mut icmp_packet = MutableIcmpPacket::owned(buffer[..length].to_vec()) + .ok_or(anyhow!("Failed to create ICMP packet"))?; + + util::build_icmpv4_unreachable(&mut icmp_packet); + + self.inner.send(Message::Icmp { + destination: *destination, + icmp_packet, + })?; } (IpAddr::V6(src), IpAddr::V6(dest)) => { - let icmp_packet = util::build_icmpv6_unreachable(&mut buffer[..length], src, dest)?; - self.icmpv6.send_to(icmp_packet, *destination)?; + let mut icmp_packet = MutableIcmpv6Packet::owned(buffer[..length].to_vec()) + .ok_or(anyhow!("Failed to create ICMP packet"))?; + + util::build_icmpv6_unreachable(&mut icmp_packet, src, dest); + + self.inner.send(Message::Icmpv6 { + destination: *destination, + icmp_packet, + })?; } _ => bail!("IP version mismatch"), } @@ -100,25 +90,123 @@ impl Sender { } pub fn emit_tcp_rst( - &mut self, + &self, destination: &IpAddr, source: &IpAddr, tcp_header: &TcpPacket, ) -> Result<()> { let tcp_min_size = TcpPacket::minimum_packet_size(); - let buffer = MaybeUninit::<[u8; BUFFER_SIZE]>::uninit(); - let mut buffer = unsafe { ptr::read(buffer.as_ptr() as *const [u8; BUFFER_SIZE]) }; + let buffer = Box::new(MaybeUninit::<[u8; BUFFER_SIZE]>::uninit()); + let buffer = unsafe { ptr::read(buffer.as_ptr() as *const [u8; BUFFER_SIZE]) }; - let tcp_reset_packet = - util::build_tcp_reset(&mut buffer[..tcp_min_size], destination, source, tcp_header)?; + let mut tcp_packet = MutableTcpPacket::owned(buffer[..tcp_min_size].to_vec()) + .ok_or(anyhow!("Failed to create TCP packet"))?; - if destination.is_ipv4() { - self.tcp.send_to(tcp_reset_packet, *destination)?; - } else { - self.tcp6.send_to(tcp_reset_packet, *destination)?; - } + util::build_tcp_reset(&mut tcp_packet, destination, source, tcp_header)?; + + self.inner.send(Message::Tcp { + destination: *destination, + tcp_packet, + })?; Ok(()) } } + +enum Message { + Icmp { + destination: IpAddr, + icmp_packet: MutableIcmpPacket<'static>, + }, + Icmpv6 { + destination: IpAddr, + icmp_packet: MutableIcmpv6Packet<'static>, + }, + Tcp { + destination: IpAddr, + tcp_packet: MutableTcpPacket<'static>, + }, +} + +struct Sender { + icmp: TransportSender, + tcp: TransportSender, + icmpv6: TransportSender, + tcp6: TransportSender, +} + +impl Sender { + pub fn new() -> Result { + let (icmp, _) = transport_channel( + 0, + TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Icmp)), + )?; + + let (tcp, _) = transport_channel( + 0, + TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Tcp)), + )?; + + let (icmpv6, _) = transport_channel( + 0, + TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Icmpv6)), + )?; + + let (tcp6, _) = transport_channel( + 0, + TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Tcp)), + )?; + + Ok(Self { + icmp, + icmpv6, + tcp, + tcp6, + }) + } + + pub fn start(mut self) -> mpsc::Sender { + let (tx, rx) = mpsc::channel::(); + + thread::spawn(move || { + for msg in rx { + match msg { + Message::Icmp { + destination, + icmp_packet, + } => { + if let Err(e) = self.icmp.send_to(icmp_packet, destination) { + error!("Failed to send ICMP packet: {}", e); + } + } + Message::Icmpv6 { + destination, + icmp_packet, + } => { + if let Err(e) = self.icmpv6.send_to(icmp_packet, destination) { + error!("Failed to send ICMPv6 packet: {}", e); + } + } + Message::Tcp { + destination, + tcp_packet, + } => match destination { + IpAddr::V4(_) => { + if let Err(e) = self.tcp.send_to(tcp_packet, destination) { + error!("Failed to send TCP packet: {}", e); + } + } + IpAddr::V6(_) => { + if let Err(e) = self.tcp6.send_to(tcp_packet, destination) { + error!("Failed to send TCP packet: {}", e); + } + } + }, + } + } + }); + + tx + } +} diff --git a/server/src/util.rs b/server/src/util.rs index f0e871b..ac67075 100644 --- a/server/src/util.rs +++ b/server/src/util.rs @@ -1,6 +1,6 @@ use std::net::{IpAddr, Ipv6Addr}; -use anyhow::{anyhow, bail, Result}; +use anyhow::{bail, Result}; use pnet::packet::icmp::destination_unreachable::IcmpCodes; use pnet::packet::icmp::{checksum as icmp_checksum, IcmpTypes, MutableIcmpPacket}; use pnet::packet::icmpv6::{ @@ -11,48 +11,36 @@ use pnet::packet::Packet; const PORT_UNREACHABLE: u8 = 4; -pub fn build_icmpv4_unreachable<'a>(data: &'a mut [u8]) -> Result> { - let mut icmp_packet = - MutableIcmpPacket::new(&mut data[..]).ok_or(anyhow!("Failed to create ICMP packet"))?; - +pub fn build_icmpv4_unreachable(icmp_packet: &mut MutableIcmpPacket) { icmp_packet.set_icmp_type(IcmpTypes::DestinationUnreachable); icmp_packet.set_icmp_code(IcmpCodes::DestinationPortUnreachable); icmp_packet.set_payload(&[]); let checksum = icmp_checksum(&icmp_packet.to_immutable()); icmp_packet.set_checksum(checksum); - - Ok(icmp_packet) } -pub fn build_icmpv6_unreachable<'a>( - data: &'a mut [u8], +pub fn build_icmpv6_unreachable( + icmp_packet: &mut MutableIcmpv6Packet, src: &Ipv6Addr, dest: &Ipv6Addr, -) -> Result> { - let mut icmp_packet = - MutableIcmpv6Packet::new(&mut data[..]).ok_or(anyhow!("Failed to create ICMPv6 packet"))?; - +) { icmp_packet.set_icmpv6_type(Icmpv6Types::DestinationUnreachable); icmp_packet.set_icmpv6_code(Icmpv6Code(PORT_UNREACHABLE)); icmp_packet.set_payload(&[]); let checksum = icmp6_checksum(&icmp_packet.to_immutable(), src, dest); icmp_packet.set_checksum(checksum); - - Ok(icmp_packet) } -pub fn build_tcp_reset<'a>( - data: &'a mut [u8], +pub fn build_tcp_reset( + tcp_packet: &mut MutableTcpPacket, source: &IpAddr, destination: &IpAddr, tcp_header: &TcpPacket, -) -> Result> { - let header_length = (data.len() / 4) as u8; - - let mut tcp_packet = - MutableTcpPacket::new(&mut data[..]).ok_or(anyhow!("Failed to create TCP packet"))?; +) -> Result<()> { + let tcp_min_size = TcpPacket::minimum_packet_size(); + let header_length = (tcp_min_size / 4) as u8; tcp_packet.set_source(tcp_header.get_destination()); tcp_packet.set_destination(tcp_header.get_source()); @@ -80,7 +68,7 @@ pub fn build_tcp_reset<'a>( tcp_packet.set_checksum(checksum); - Ok(tcp_packet) + Ok(()) } #[inline] @@ -120,6 +108,7 @@ pub fn set_rlimit_nofile(n: u64) -> Result { Ok(value) } +/* #[cfg(test)] mod tests { use super::*; @@ -173,7 +162,7 @@ mod tests { fn test_build_tcp_reset() { let tcp_min_size = TcpPacket::minimum_packet_size(); - let incoming = TcpPacket::new(&[0u8; 40]).unwrap(); + let incoming = MutableTcpPacket::new(&[0u8; 40]).unwrap(); let mut buffer = vec![0u8; 128]; let tcp_reset = build_tcp_reset( @@ -197,3 +186,4 @@ mod tests { assert_eq!(tcp_reset.payload().len(), 0); } } +*/ diff --git a/server/src/worker.rs b/server/src/worker.rs index 6181081..0d262cb 100644 --- a/server/src/worker.rs +++ b/server/src/worker.rs @@ -4,6 +4,7 @@ use std::thread; use std::time::Instant; use anyhow::{anyhow, Context, Result}; +use arc_swap::ArcSwap; use ipnet::IpNet; use nfq::{Queue, Verdict}; use pnet::packet::ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}; @@ -18,17 +19,18 @@ use rsa::sha2::Sha256; use rsa::RsaPublicKey; use tracing::{debug, error, info}; -use crate::{util, Config, ConntrackEntry, ConntrackMap, Protocol, Sender}; +use crate::{util, Config, ConntrackEntry, ConntrackMap, Protocol, RejectPacketSender}; const IPV4_ADDR_BITS: u8 = 32; const IPV6_ADDR_BITS: u8 = 128; +#[derive(Clone)] pub struct Worker { queue_num: u16, - config: Arc, + config: Arc>, conntrack_map: Arc, - verifying_key: VerifyingKey, - sender: Sender, + verifying_key: Arc>, + reject: RejectPacketSender, } impl Worker { @@ -38,19 +40,19 @@ impl Worker { conntrack_map: Arc, ) -> Result { let public_key = RsaPublicKey::read_pkcs1_pem_file(&config.auth.key)?; - let verifying_key = VerifyingKey::::new(public_key); - let sender = Sender::new()?; + let verifying_key = Arc::new(VerifyingKey::::new(public_key)); + let reject = RejectPacketSender::new()?; Ok(Worker { - config, + config: Arc::new(ArcSwap::new(config)), queue_num, conntrack_map, verifying_key, - sender, + reject, }) } - pub fn start(mut self) -> Result<()> { + pub fn start(&self) -> Result<()> { let queue_num = self.queue_num; let mut queue = Queue::open()?; @@ -61,6 +63,7 @@ impl Worker { queue.set_recv_security_context(queue_num, true)?; queue.set_recv_uid_gid(queue_num, true)?; + let this = self.clone(); thread::spawn(move || { if let Err(e) = util::set_thread_priority() { error!("nfq {queue_num} failed to set thread priority: {e}"); @@ -70,7 +73,7 @@ impl Worker { info!("nfq {queue_num} worker started"); loop { - if let Err(e) = self.event_handler(&mut queue) { + if let Err(e) = this.event_handler(&mut queue) { error!("nfq {queue_num} failed handle event: {e}"); continue; } @@ -80,7 +83,11 @@ impl Worker { Ok(()) } - fn event_handler(&mut self, queue: &mut Queue) -> Result<()> { + pub fn update_config(&self, config: Arc) { + self.config.store(config); + } + + fn event_handler(&self, queue: &mut Queue) -> Result<()> { let mut verdict = Verdict::Drop; let mut msg = queue.recv()?; let payload = msg.get_payload(); @@ -98,7 +105,7 @@ impl Worker { Ok(()) } - fn ipv4_packet_handler(&mut self, payload: &[u8]) -> Result { + fn ipv4_packet_handler(&self, payload: &[u8]) -> Result { let ip_header = Ipv4Packet::new(payload).ok_or(anyhow!("Malformed IPv4 packet"))?; let source = IpAddr::V4(ip_header.get_source()); @@ -123,7 +130,7 @@ impl Worker { Ok(verdict) } - fn ipv6_packet_handler(&mut self, payload: &[u8]) -> Result { + fn ipv6_packet_handler(&self, payload: &[u8]) -> Result { let ip_header = Ipv6Packet::new(payload).ok_or(anyhow!("Malformed IPv6 packet"))?; let source = IpAddr::V6(ip_header.get_source()); @@ -149,7 +156,7 @@ impl Worker { } fn transport_protocol_handler( - &mut self, + &self, src_ip: IpAddr, dst_ip: IpAddr, protocol: IpNextHeaderProtocol, @@ -177,7 +184,7 @@ impl Worker { } fn udp_packet_handler( - &mut self, + &self, src_ip: IpAddr, dst_ip: IpAddr, ip_packet: &[u8], @@ -195,15 +202,17 @@ impl Worker { self.verify_packet(src_ip, src_port, dst_ip, dst_port, Protocol::Udp, payload)?; if verdict != Verdict::Accept { - self.sender - .emit_icmp_unreachable(&dst_ip, &src_ip, ip_packet, &udp_header)?; + let udp_packet_header = util::packet_header(&udp_header); + + self.reject + .emit_icmp_unreachable(&dst_ip, &src_ip, ip_packet, udp_packet_header)?; } Ok(verdict) } fn tcp_packet_handler( - &mut self, + &self, src_ip: IpAddr, dst_ip: IpAddr, payload: &[u8], @@ -220,7 +229,7 @@ impl Worker { self.verify_packet(src_ip, src_port, dst_ip, dst_port, Protocol::Tcp, payload)?; if verdict != Verdict::Accept { - self.sender.emit_tcp_rst(&src_ip, &dst_ip, &tcp_header)?; + self.reject.emit_tcp_rst(&src_ip, &dst_ip, &tcp_header)?; } Ok(verdict) @@ -236,14 +245,12 @@ impl Worker { payload: &[u8], ) -> Result { let queue_num = self.queue_num; + let allow_skew = self.config.load().auth.allow_skew; + let mut entry = ConntrackEntry::new(src_ip, dst_ip, dst_port, protocol.clone()); if self.is_auth_port(protocol, dst_port) { - match crypto::verify_knock_packet( - payload, - &self.verifying_key, - self.config.auth.allow_skew, - ) { + match crypto::verify_knock_packet(payload, &self.verifying_key, allow_skew) { Ok(knock_info) => { entry.dst_port = knock_info.unlock_port; @@ -271,7 +278,9 @@ impl Worker { } fn is_allow_ip(&self, source: &IpAddr) -> Result { - if self.config.filter.allow_ips.is_empty() { + let config = self.config.load(); + + if config.filter.allow_ips.is_empty() { return Ok(false); } @@ -283,10 +292,12 @@ impl Worker { let src_ip = IpNet::new(*source, bits).context(format!("IpNet::new({}, {}) fail", source, bits))?; - Ok(self.config.filter.allow_ips.contains(&src_ip)) + Ok(config.filter.allow_ips.contains(&src_ip)) } fn is_auth_port(&self, protocol: Protocol, dst_port: u16) -> bool { - self.config.auth.protocol == protocol && self.config.auth.port == dst_port + let config = self.config.load(); + + config.auth.protocol == protocol && config.auth.port == dst_port } }