Skip to content

Commit

Permalink
feat(rumqttd/broker): implemented filter for publish packets
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `Router::new` now takes an list
of `PublishFilterRef` as additional parameter
  • Loading branch information
mdrssv authored and shimunn committed Sep 16, 2024
1 parent 2377e4e commit bf942e5
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rumqttd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use tracing_subscriber::{
pub use link::alerts;
pub use link::local;
pub use link::meters;
pub use router::{Alert, Forward, IncomingMeter, Meter, Notification, OutgoingMeter};
pub use router::{Alert, Forward, IncomingMeter, Meter, Notification, OutgoingMeter, PublishFilter, PublishFilterRef};
use segments::Storage;
pub use server::Broker;

Expand Down
134 changes: 134 additions & 0 deletions rumqttd/src/router/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use std::{fmt::Debug, ops::Deref, sync::Arc};

use crate::protocol::{Publish, PublishProperties};

/// Filter for [`Publish`] packets
pub trait PublishFilter {
/// Determines weather an [`Publish`] packet should be processed
/// Arguments:
/// * `packet`: to be published, may be modified if necessary
/// * `properties`: received along with the packet, may be `None` for older MQTT versions
/// Returns: [`bool`] indicating if the packet should be processed
fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool;
}

/// Container for either an owned [`PublishFilter`] or an `'static` reference
#[derive(Clone)]
pub enum PublishFilterRef {
Owned(Arc<dyn PublishFilter + Send + Sync>),
Static(&'static (dyn PublishFilter + Send + Sync)),
}

impl Debug for PublishFilterRef {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Owned(_arg0) => f.debug_tuple("Owned").finish(),
Self::Static(_arg0) => f.debug_tuple("Static").finish(),
}
}
}

impl Deref for PublishFilterRef {
type Target = dyn PublishFilter;

fn deref(&self) -> &Self::Target {
match self {
Self::Static(filter) => *filter,
Self::Owned(filter) => &**filter,
}
}
}

/// Implements [`PublishFilter`] for any ordinary function
impl<F> PublishFilter for F
where
F: Fn(&mut Publish, Option<&mut PublishProperties>) -> bool + Send + Sync,
{
fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool {
self(packet, properties)
}
}

/// Implements the conversion
/// ```rust
/// # use rumqttd::{protocol::{Publish, PublishProperties}, PublishFilterRef};
/// fn filter_static(packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool {
/// todo!()
/// }
///
/// let filter = PublishFilterRef::from(&filter_static);
/// # assert!(matches!(filter, PublishFilterRef::Static(_)));
/// ```
impl<F> From<&'static F> for PublishFilterRef
where
F: Fn(&mut Publish, Option<&mut PublishProperties>) -> bool + Send + Sync,
{
fn from(value: &'static F) -> Self {
Self::Static(value)
}
}

/// Implements the conversion
/// ```rust
/// # use std::boxed::Box;
/// # use rumqttd::{protocol::{Publish, PublishProperties}, PublishFilter, PublishFilterRef};
/// #[derive(Clone)]
/// struct MyFilter {}
///
/// impl PublishFilter for MyFilter {
/// fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool {
/// todo!()
/// }
/// }
/// let boxed: Box<MyFilter> = Box::new(MyFilter {});
///
/// let filter = PublishFilterRef::from(boxed);
/// # assert!(matches!(filter, PublishFilterRef::Owned(_)));
/// ```
impl<T> From<Arc<T>> for PublishFilterRef
where
T: PublishFilter + 'static + Send + Sync,
{
fn from(value: Arc<T>) -> Self {
Self::Owned(value)
}
}

impl<T> From<Box<T>> for PublishFilterRef
where
T: PublishFilter + 'static + Send + Sync,
{
fn from(value: Box<T>) -> Self {
Self::Owned(Arc::<T>::from(value))
}
}

#[cfg(test)]
mod tests {
use super::*;

fn filter_static(_packet: &mut Publish, _properties: Option<&mut PublishProperties>) -> bool {
true
}
struct Prejudiced(bool);

impl PublishFilter for Prejudiced {
fn filter(&self, _packet: &mut Publish,_propertiess: Option<&mut PublishProperties>) -> bool {
self.0
}
}
#[test]
fn static_filter() {
fn is_send<T: Send>(_: &T) {}
fn takes_static_filter(filter: impl Into<PublishFilterRef>) {
assert!(matches!(filter.into(), PublishFilterRef::Static(_)));
}
fn takes_owned_filter(filter: impl Into<PublishFilterRef>) {
assert!(matches!(filter.into(), PublishFilterRef::Owned(_)));
}
takes_static_filter(&filter_static);
let boxed: PublishFilterRef = Box::new(Prejudiced(false)).into();
is_send(&boxed);
takes_owned_filter(boxed);
}
}
2 changes: 2 additions & 0 deletions rumqttd/src/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ mod routing;
mod scheduler;
pub(crate) mod shared_subs;
mod waiters;
mod filter;

pub use filter::{PublishFilter, PublishFilterRef};
pub use alertlog::Alert;
pub use connection::Connection;
pub use routing::Router;
Expand Down
23 changes: 16 additions & 7 deletions rumqttd/src/router/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub struct Router {
connections: Slab<Connection>,
/// Connection map from device id to connection id
connection_map: HashMap<String, ConnectionId>,
/// Filters to be applied to an [`Publish`] packets payload
publish_filters: Vec<PublishFilterRef>,
/// Subscription map to interested connection ids
subscription_map: HashMap<Filter, HashSet<ConnectionId>>,
/// Incoming data grouped by connection
Expand Down Expand Up @@ -105,7 +107,7 @@ pub struct Router {
}

impl Router {
pub fn new(router_id: RouterId, config: RouterConfig) -> Router {
pub fn new(router_id: RouterId, publish_filters: Vec<PublishFilterRef>, config: RouterConfig) -> Router {
let (router_tx, router_rx) = bounded(1000);

let meters = Slab::with_capacity(10);
Expand All @@ -129,6 +131,7 @@ impl Router {
alerts,
connections,
connection_map: Default::default(),
publish_filters,
subscription_map: Default::default(),
ibufs,
obufs,
Expand Down Expand Up @@ -557,13 +560,18 @@ impl Router {

for packet in packets.drain(0..) {
match packet {
Packet::Publish(publish, properties) => {
Packet::Publish(mut publish, mut properties) => {
println!("publish: {publish:?} payload: {:?}", publish.payload.to_vec());
let span = tracing::error_span!("publish", topic = ?publish.topic, pkid = publish.pkid);
let _guard = span.enter();

let qos = publish.qos;
let pkid = publish.pkid;


// Decide weather to keep or discard this packet
// Packet will be discard if *at least one* filter returns *false*
let keep = self.publish_filters.iter().fold(true,|keep,f| keep && f.filter(&mut publish, properties.as_mut())) ;

// Prepare acks for the above publish
// If any of the publish in the batch results in force flush,
// set global force flush flag. Force flush is triggered when the
Expand All @@ -577,12 +585,11 @@ impl Router {
// coordinate using multiple offsets, and we don't have any idea how to do so right now.
// Currently as we don't have replication, we just use a single offset, even when appending to
// multiple commit logs.

match qos {
QoS::AtLeastOnce => {
let puback = PubAck {
pkid,
reason: PubAckReason::Success,
reason: if keep { PubAckReason::Success } else { PubAckReason::PayloadFormatInvalid },
};

let ackslog = self.ackslog.get_mut(id).unwrap();
Expand All @@ -592,7 +599,7 @@ impl Router {
QoS::ExactlyOnce => {
let pubrec = PubRec {
pkid,
reason: PubRecReason::Success,
reason: if keep { PubRecReason::Success } else { PubRecReason::PayloadFormatInvalid },
};

let ackslog = self.ackslog.get_mut(id).unwrap();
Expand All @@ -604,7 +611,9 @@ impl Router {
// Do nothing
}
};

if !keep {
break;
}
self.router_meters.total_publishes += 1;

// Try to append publish to commitlog
Expand Down
10 changes: 8 additions & 2 deletions rumqttd/src/server/broker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::{io, thread};

use crate::link::console;
use crate::link::local::{self, LinkRx, LinkTx};
use crate::router::{Event, Router};
use crate::router::{Event, PublishFilterRef, Router};
use crate::{Config, ConnectionId, ServerSettings};

use tokio::net::{TcpListener, TcpStream};
Expand Down Expand Up @@ -71,9 +71,13 @@ pub struct Broker {

impl Broker {
pub fn new(config: Config) -> Broker {
Self::with_filter(config, Vec::new())
}

pub fn with_filter(config: Config, publish_filters: Vec<PublishFilterRef>) -> Broker {
let config = Arc::new(config);
let router_config = config.router.clone();
let router: Router = Router::new(config.id, router_config);
let router: Router = Router::new(config.id, publish_filters, router_config);

// Setup cluster if cluster settings are configured.
match config.cluster.clone() {
Expand All @@ -96,6 +100,8 @@ impl Broker {
}
}



// pub fn new_local_cluster(
// config: Config,
// node_id: NodeId,
Expand Down

0 comments on commit bf942e5

Please sign in to comment.