From 2a37b52b542e2403c2676fd7bf7203f6a5c9b029 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Wed, 28 Feb 2024 11:14:13 +0530 Subject: [PATCH 01/16] fix(mqttbytes): add function to validate topic name and filter according to MQTTv3 spec --- rumqttc/src/mqttbytes/topic.rs | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index a66c35f8..322a9917 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -1,12 +1,32 @@ +/// Maximum length of a topic or topic filter according to +/// [MQTT-4.7.3-3](https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718109) +pub const MAX_TOPIC_LEN: usize = 65535; + /// Checks if a topic or topic filter has wildcards pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') + s.contains(['+', '#']) } -/// Checks if a topic is valid +/// Check if a topic is valid for PUBLISH packet. pub fn valid_topic(topic: &str) -> bool { - // topic can't contain wildcards - if topic.contains('+') || topic.contains('#') { + is_valid_topic_or_filter(topic) && !has_wildcards(topic) +} + +/// Check if a topic is valid to qualify as a topic name or topic filter. +/// +/// According to MQTT v3 Spec, it has to follow the following rules: +/// 1. All Topic Names and Topic Filters MUST be at least one character long [MQTT-4.7.3-1] +/// 2. Topic Names and Topic Filters are case sensitive +/// 3. Topic Names and Topic Filters can include the space character +/// 4. A leading or trailing `/` creates a distinct Topic Name or Topic Filter +/// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid +/// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] +/// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. +fn is_valid_topic_or_filter(topic_or_filter: &str) -> bool { + if topic_or_filter.is_empty() + || topic_or_filter.len() > MAX_TOPIC_LEN + || topic_or_filter.contains('\0') + { return false; } @@ -17,7 +37,7 @@ pub fn valid_topic(topic: &str) -> bool { /// /// pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { + if !is_valid_topic_or_filter(filter) { return false; } From 33b438c9bf3994b43a0a66c8e272878104d2240f Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Sun, 3 Mar 2024 20:47:11 +0530 Subject: [PATCH 02/16] fix(v5/client): validate topic name and alias --- rumqttc/src/v5/client.rs | 5 ++- rumqttc/src/v5/mqttbytes/mod.rs | 79 ++++++++++----------------------- 2 files changed, 27 insertions(+), 57 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c..978a725d 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -6,7 +6,7 @@ use super::mqttbytes::v5::{ Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; -use super::mqttbytes::{valid_filter, QoS}; +use super::mqttbytes::{valid_filter, validate_topic_name_and_alias, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; @@ -84,10 +84,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); + let is_ok = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_ok { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index c205aaa9..fc76db54 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -1,4 +1,8 @@ +// reexport valid_filter and valid_topic since they are identical in nature for +// both v3 and v5 +pub use crate::mqttbytes::{valid_filter, valid_topic}; use std::{str::Utf8Error, vec}; +use super::PublishProperties; /// This module is the place where all the protocol specifics gets abstracted /// out and creates a structures which are common across protocols. Since, @@ -32,61 +36,26 @@ pub fn qos(num: u8) -> Option { } } -/// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') -} - -/// Checks if a topic is valid -pub fn valid_topic(topic: &str) -> bool { - // topic can't contain wildcards - if topic.contains('+') || topic.contains('#') { - return false; - } - - true -} - -/// Checks if the filter is valid -/// -/// -pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { - return false; - } - - // rev() is used so we can easily get the last entry - let mut hirerarchy = filter.split('/').rev(); - - // split will never return an empty iterator - // even if the pattern isn't matched, the original string will be there - // so it is safe to just unwrap here! - let last = hirerarchy.next().unwrap(); - - // only single '#" or '+' is allowed in last entry - // invalid: sport/tennis# - // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { - return false; - } - - // remaining entries - for entry in hirerarchy { - // # is not allowed in filter except as a last entry - // invalid: sport/tennis#/player - // invalid: sport/tennis/#/ranking - if entry.contains('#') { - return false; - } - - // + must occupy an entire level of the filter - // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { - return false; - } - } - - true +#[inline] +pub(crate) fn validate_topic_name_and_alias( + topic: &str, + properties: &Option, +) -> bool { + let is_topic_empty = topic.is_empty(); + // The topic alias is considered valid only if it is greater than zero. + // If it is not supplied, it is still considered valid because in that + // case, the validity depends on the topic name itself. + let is_topic_valid = valid_topic(topic); + let is_alias_given = properties + .as_ref() + .map_or(false, |p| p.topic_alias.is_some()); + let is_alias_valid = properties + .as_ref() + .map_or(false, |p| p.topic_alias.map_or(false, |a| a > 0)); + + (is_topic_valid || is_topic_empty) + && (!is_topic_empty || is_alias_given) + && (!is_alias_given || is_alias_valid) } /// Checks if topic matches a filter. topic and filter validation isn't done here. From 5b1fcdfe9ca304f9f6e12aca58d58f2f617f89eb Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Sun, 10 Mar 2024 16:37:10 +0530 Subject: [PATCH 03/16] refactor(mqttbytes): use `impl AsRef` for function parameters --- rumqttc/src/mqttbytes/topic.rs | 19 ++++++++++++------- rumqttc/src/v5/mqttbytes/mod.rs | 6 +++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index 322a9917..da895ec8 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -3,13 +3,13 @@ pub const MAX_TOPIC_LEN: usize = 65535; /// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains(['+', '#']) +pub fn has_wildcards(s: impl AsRef) -> bool { + s.as_ref().contains(['+', '#']) } /// Check if a topic is valid for PUBLISH packet. -pub fn valid_topic(topic: &str) -> bool { - is_valid_topic_or_filter(topic) && !has_wildcards(topic) +pub fn valid_topic(topic: impl AsRef) -> bool { + is_valid_topic_or_filter(&topic) && !has_wildcards(topic) } /// Check if a topic is valid to qualify as a topic name or topic filter. @@ -22,7 +22,8 @@ pub fn valid_topic(topic: &str) -> bool { /// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid /// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] /// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. -fn is_valid_topic_or_filter(topic_or_filter: &str) -> bool { +fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { + let topic_or_filter = topic_or_filter.as_ref(); if topic_or_filter.is_empty() || topic_or_filter.len() > MAX_TOPIC_LEN || topic_or_filter.contains('\0') @@ -36,7 +37,8 @@ fn is_valid_topic_or_filter(topic_or_filter: &str) -> bool { /// Checks if the filter is valid /// /// -pub fn valid_filter(filter: &str) -> bool { +pub fn valid_filter(filter: impl AsRef) -> bool { + let filter = filter.as_ref(); if !is_valid_topic_or_filter(filter) { return false; } @@ -80,7 +82,10 @@ pub fn valid_filter(filter: &str) -> bool { /// **NOTE**: 'topic' is a misnomer in the arg. this can also be used to match 2 wild subscriptions /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe -pub fn matches(topic: &str, filter: &str) -> bool { +pub fn matches(topic: impl AsRef, filter: impl AsRef) -> bool { + let topic = topic.as_ref(); + let filter = filter.as_ref(); + if !topic.is_empty() && topic[..1].contains('$') { return false; } diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index fc76db54..b2b8054d 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -1,8 +1,8 @@ // reexport valid_filter and valid_topic since they are identical in nature for // both v3 and v5 +use super::PublishProperties; pub use crate::mqttbytes::{valid_filter, valid_topic}; use std::{str::Utf8Error, vec}; -use super::PublishProperties; /// This module is the place where all the protocol specifics gets abstracted /// out and creates a structures which are common across protocols. Since, @@ -36,11 +36,11 @@ pub fn qos(num: u8) -> Option { } } -#[inline] pub(crate) fn validate_topic_name_and_alias( - topic: &str, + topic: impl AsRef, properties: &Option, ) -> bool { + let topic = topic.as_ref(); let is_topic_empty = topic.is_empty(); // The topic alias is considered valid only if it is greater than zero. // If it is not supplied, it is still considered valid because in that From df194cc2be90e65b194daef31cea541d7db6b728 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Sun, 10 Mar 2024 16:38:43 +0530 Subject: [PATCH 04/16] fix: implement topic validation in v3 and v5 client --- rumqttc/src/client.rs | 66 ++++++++++++++++----------- rumqttc/src/v5/client.rs | 98 ++++++++++++++++++++++------------------ 2 files changed, 93 insertions(+), 71 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 15cd5f5a..f3cdd5d7 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -141,9 +141,13 @@ impl AsyncClient { where S: Into, { - let mut publish = Publish::from_bytes(topic, qos, payload); + let topic = topic.into(); + let mut publish = Publish::from_bytes(&topic, qos, payload); publish.retain = retain; let publish = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::Request(publish)); + } self.request_tx.send_async(publish).await?; Ok(()) } @@ -153,7 +157,7 @@ impl AsyncClient { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -165,8 +169,8 @@ impl AsyncClient { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::TryRequest(request)); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); } self.request_tx.try_send(request)?; Ok(()) @@ -175,13 +179,12 @@ impl AsyncClient { /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_err { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -191,14 +194,13 @@ impl AsyncClient { /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + if is_err { + return Err(ClientError::Request(request)); } self.request_tx.try_send(request)?; Ok(()) @@ -206,16 +208,24 @@ impl AsyncClient { /// Sends a MQTT Unsubscribe to the `EventLoop` pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(&topic) { + return Err(ClientError::Request(request)); + } self.request_tx.send_async(request).await?; Ok(()) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(&topic) { + return Err(ClientError::Request(request)); + } self.request_tx.try_send(request)?; Ok(()) } @@ -319,8 +329,7 @@ impl Client { S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -344,7 +353,7 @@ impl Client { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -360,13 +369,12 @@ impl Client { /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_err { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -375,15 +383,19 @@ impl Client { pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.client.try_subscribe_many(topics) } /// Sends a MQTT Unsubscribe to the `EventLoop` pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.client.request_tx.send(request)?; Ok(()) } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 978a725d..f5dbb18b 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,6 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::{valid_filter, validate_topic_name_and_alias, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::valid_topic; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -85,7 +84,7 @@ impl AsyncClient { { let topic = topic.into(); let is_ok = validate_topic_name_and_alias(&topic, &properties); - let mut publish = Publish::new(&topic, qos, payload, properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); if !is_ok { @@ -139,10 +138,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); + let is_ok = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_ok { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; @@ -210,10 +210,11 @@ impl AsyncClient { S: Into, { let topic = topic.into(); - let mut publish = Publish::new(&topic, qos, payload, properties); + let is_ok = validate_topic_name_and_alias(&topic, &properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_ok { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; @@ -256,11 +257,11 @@ impl AsyncClient { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); - let request: Request = Request::Subscribe(subscribe); - if !is_filter_valid { + let request = Request::Subscribe(subscribe); + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -287,11 +288,11 @@ impl AsyncClient { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); let request = Request::Subscribe(subscribe); - if !is_filter_valid { + if !valid_filter(topic) { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; @@ -318,16 +319,14 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_err { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; Ok(()) } @@ -338,14 +337,14 @@ impl AsyncClient { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_subscribe_many(topics, Some(properties)).await } pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_subscribe_many(topics, None).await } @@ -357,13 +356,12 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_err { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; @@ -376,14 +374,14 @@ impl AsyncClient { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_try_subscribe_many(topics, Some(properties)) } pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_try_subscribe_many(topics, None) } @@ -394,8 +392,12 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.request_tx.send_async(request).await?; Ok(()) } @@ -418,8 +420,12 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::TryRequest(request)); + } self.request_tx.try_send(request)?; Ok(()) } @@ -516,10 +522,11 @@ impl Client { P: Into, { let topic = topic.into(); - let mut publish = Publish::new(&topic, qos, payload, properties); + let is_ok = validate_topic_name_and_alias(&topic, &properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_ok { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; @@ -608,11 +615,11 @@ impl Client { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); let request = Request::Subscribe(subscribe); - if !is_filter_valid { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -654,13 +661,12 @@ impl Client { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_err { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -673,14 +679,14 @@ impl Client { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_subscribe_many(topics, Some(properties)) } pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.handle_subscribe_many(topics, None) } @@ -691,7 +697,7 @@ impl Client { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.client .try_subscribe_many_with_properties(topics, properties) @@ -699,7 +705,7 @@ impl Client { pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator, + T: IntoIterator + Clone, { self.client.try_subscribe_many(topics) } @@ -710,8 +716,12 @@ impl Client { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.client.request_tx.send(request)?; Ok(()) } From 4bae451369cfd6ce16bad5fc1f5ca7a739f3ec17 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Tue, 12 Mar 2024 16:53:43 +0530 Subject: [PATCH 05/16] feat(rumqttd/protocol): add helper functions to validate topic and topic filter --- rumqttd/src/protocol/mod.rs | 81 ++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/rumqttd/src/protocol/mod.rs b/rumqttd/src/protocol/mod.rs index a27543d1..640f33a5 100644 --- a/rumqttd/src/protocol/mod.rs +++ b/rumqttd/src/protocol/mod.rs @@ -596,17 +596,33 @@ pub fn qos(num: u8) -> Option { } /// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') +pub fn has_wildcards(s: impl AsRef) -> bool { + s.as_ref().contains(['+', '#']) } -/// Checks if a topic is valid -pub fn valid_topic(topic: &str) -> bool { - if topic.contains('+') { - return false; - } +/// Check if a topic is valid for PUBLISH packet. +pub fn valid_topic(topic: impl AsRef) -> bool { + is_valid_topic_or_filter(&topic) && !has_wildcards(topic) +} - if topic.contains('#') { +pub const MAX_TOPIC_LEN: usize = 65535; + +/// Check if a topic is valid to qualify as a topic name or topic filter. +/// +/// According to MQTT v3 Spec, it has to follow the following rules: +/// 1. All Topic Names and Topic Filters MUST be at least one character long [MQTT-4.7.3-1] +/// 2. Topic Names and Topic Filters are case sensitive +/// 3. Topic Names and Topic Filters can include the space character +/// 4. A leading or trailing `/` creates a distinct Topic Name or Topic Filter +/// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid +/// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] +/// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. +fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { + let topic_or_filter = topic_or_filter.as_ref(); + if topic_or_filter.is_empty() + || topic_or_filter.len() > MAX_TOPIC_LEN + || topic_or_filter.contains('\0') + { return false; } @@ -616,32 +632,39 @@ pub fn valid_topic(topic: &str) -> bool { /// Checks if the filter is valid /// /// -pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { +pub fn valid_filter(filter: impl AsRef) -> bool { + let filter = filter.as_ref(); + if !is_valid_topic_or_filter(filter) { return false; } - let hirerarchy = filter.split('/').collect::>(); - if let Some((last, remaining)) = hirerarchy.split_last() { - for entry in remaining.iter() { - // # is not allowed in filter except as a last entry - // invalid: sport/tennis#/player - // invalid: sport/tennis/#/ranking - if entry.contains('#') { - return false; - } - - // + must occupy an entire level of the filter - // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { - return false; - } + // rev() is used so we can easily get the last entry + let mut hirerarchy = filter.split('/').rev(); + + // split will never return an empty iterator + // even if the pattern isn't matched, the original string will be there + // so it is safe to just unwrap here! + let last = hirerarchy.next().unwrap(); + + // only single '#" or '+' is allowed in last entry + // invalid: sport/tennis# + // invalid: sport/++ + if last.len() != 1 && (last.contains('#') || last.contains('+')) { + return false; + } + + // remaining entries + for entry in hirerarchy { + // # is not allowed in filter except as a last entry + // invalid: sport/tennis#/player + // invalid: sport/tennis/#/ranking + if entry.contains('#') { + return false; } - // only single '#" or '+' is allowed in last entry - // invalid: sport/tennis# - // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { + // + must occupy an entire level of the filter + // invalid: sport+ + if entry.len() > 1 && entry.contains('+') { return false; } } From 4ba27a2f4dcf7d75bc1de330b2e187c6efb44adc Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Thu, 14 Mar 2024 15:09:13 +0530 Subject: [PATCH 06/16] feat(rumqttd): add topic name and filter validation for appropriate packets --- rumqttd/src/protocol/v4/publish.rs | 5 +++++ rumqttd/src/protocol/v4/subscribe.rs | 4 ++++ rumqttd/src/protocol/v4/unsubscribe.rs | 3 +++ rumqttd/src/protocol/v5/publish.rs | 5 +++++ rumqttd/src/protocol/v5/subscribe.rs | 4 ++++ rumqttd/src/protocol/v5/unsubscribe.rs | 3 +++ 6 files changed, 24 insertions(+) diff --git a/rumqttd/src/protocol/v4/publish.rs b/rumqttd/src/protocol/v4/publish.rs index b64cb314..0839bde6 100644 --- a/rumqttd/src/protocol/v4/publish.rs +++ b/rumqttd/src/protocol/v4/publish.rs @@ -1,5 +1,6 @@ use super::*; use bytes::{Buf, Bytes}; +use core::str::from_utf8; fn len(publish: &Publish) -> usize { let len = 2 + publish.topic.len() + publish.payload.len(); @@ -19,6 +20,10 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result 0 let pkid = match qos { QoS::AtMostOnce => 0, diff --git a/rumqttd/src/protocol/v4/subscribe.rs b/rumqttd/src/protocol/v4/subscribe.rs index 2f75299e..853c3744 100644 --- a/rumqttd/src/protocol/v4/subscribe.rs +++ b/rumqttd/src/protocol/v4/subscribe.rs @@ -13,6 +13,10 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result Err(Error::EmptySubscription), _ => Ok(Subscribe { pkid, filters }), diff --git a/rumqttd/src/protocol/v4/unsubscribe.rs b/rumqttd/src/protocol/v4/unsubscribe.rs index c1fb8dd3..0760bad1 100644 --- a/rumqttd/src/protocol/v4/unsubscribe.rs +++ b/rumqttd/src/protocol/v4/unsubscribe.rs @@ -11,6 +11,9 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result 0 { let topic_filter = read_mqtt_string(&mut bytes)?; + if !valid_filter(&topic_filter) { + return Err(Error::MalformedPacket); + } payload_bytes -= topic_filter.len() + 2; filters.push(topic_filter); } diff --git a/rumqttd/src/protocol/v5/publish.rs b/rumqttd/src/protocol/v5/publish.rs index 05313367..a11468f6 100644 --- a/rumqttd/src/protocol/v5/publish.rs +++ b/rumqttd/src/protocol/v5/publish.rs @@ -1,6 +1,7 @@ use super::*; use bytes::{Buf, Bytes}; use core::fmt; +use std::str::from_utf8; pub fn len(publish: &Publish, properties: &Option) -> usize { let mut len = 2 + publish.topic.len(); @@ -34,6 +35,10 @@ pub fn read( bytes.advance(variable_header_index); let topic = read_mqtt_bytes(&mut bytes)?; + if !valid_topic(from_utf8(&topic).map_err(|_| Error::TopicNotUtf8)?) { + return Err(Error::MalformedPacket); + } + // Packet identifier exists where QoS > 0 let pkid = match qos { QoS::AtMostOnce => 0, diff --git a/rumqttd/src/protocol/v5/subscribe.rs b/rumqttd/src/protocol/v5/subscribe.rs index 686d3686..e7117c25 100644 --- a/rumqttd/src/protocol/v5/subscribe.rs +++ b/rumqttd/src/protocol/v5/subscribe.rs @@ -32,6 +32,10 @@ pub fn read( while bytes.has_remaining() { let path = read_mqtt_string(&mut bytes)?; + if !valid_filter(&path) { + return Err(Error::MalformedPacket); + } + let options = read_u8(&mut bytes)?; let requested_qos = options & 0b0000_0011; diff --git a/rumqttd/src/protocol/v5/unsubscribe.rs b/rumqttd/src/protocol/v5/unsubscribe.rs index 1245bf7f..6f2460c9 100644 --- a/rumqttd/src/protocol/v5/unsubscribe.rs +++ b/rumqttd/src/protocol/v5/unsubscribe.rs @@ -31,6 +31,9 @@ pub fn read( let mut filters = Vec::with_capacity(1); while bytes.has_remaining() { let filter = read_mqtt_string(&mut bytes)?; + if !valid_filter(&filter) { + return Err(Error::MalformedPacket); + } filters.push(filter); } From 090c663b8c74abe77c9fbd2f9802cfbd8a3d0066 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Thu, 14 Mar 2024 15:31:05 +0530 Subject: [PATCH 07/16] docs(changelog): mention changes in changelog [skip ci] --- rumqttc/CHANGELOG.md | 4 ++-- rumqttd/CHANGELOG.md | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 1045cfcf..a7eb3751 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -25,8 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -* Validate filters while creating subscription requests. -* Make v4::Connect::write return correct value +- Make v4::Connect::write return correct value +- Validate topic filter and topic name for MQTT v3 and v5 during `Publish`, `Subscribe`, and `Unsubscribe` operations. ### Security diff --git a/rumqttd/CHANGELOG.md b/rumqttd/CHANGELOG.md index 6d981ffd..68d05300 100644 --- a/rumqttd/CHANGELOG.md +++ b/rumqttd/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - record client id for remote link's span - session present flag in connack - Make write method return the number of bytes written correctly everywhere +- Validate topic name and topic filter in MQTT v3 and v5 `Publish`, `Subscribe` and `Unsubscribe` packets. ### Security - Implement constant-time password comparison in authentication logic @@ -177,7 +178,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [rumqttd 0.12.4] - 01-02-2023 ### Fixed -- Client id with tenant prefix should be set globally (#564) +- Client id with tenant prefix should be set globally (#564) ## [rumqttd 0.12.3] - 23-01-2023 From bc438a218d28340c3114980be01474d29f9edab3 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Wed, 20 Mar 2024 17:31:51 +0530 Subject: [PATCH 08/16] refactor(mqttbytes/topic): use a better function name --- rumqttc/src/mqttbytes/topic.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index da895ec8..a06f9f96 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -9,7 +9,7 @@ pub fn has_wildcards(s: impl AsRef) -> bool { /// Check if a topic is valid for PUBLISH packet. pub fn valid_topic(topic: impl AsRef) -> bool { - is_valid_topic_or_filter(&topic) && !has_wildcards(topic) + can_be_topic_or_filter(&topic) && !has_wildcards(topic) } /// Check if a topic is valid to qualify as a topic name or topic filter. @@ -22,7 +22,7 @@ pub fn valid_topic(topic: impl AsRef) -> bool { /// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid /// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] /// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. -fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { +fn can_be_topic_or_filter(topic_or_filter: impl AsRef) -> bool { let topic_or_filter = topic_or_filter.as_ref(); if topic_or_filter.is_empty() || topic_or_filter.len() > MAX_TOPIC_LEN @@ -39,7 +39,7 @@ fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { /// pub fn valid_filter(filter: impl AsRef) -> bool { let filter = filter.as_ref(); - if !is_valid_topic_or_filter(filter) { + if !can_be_topic_or_filter(filter) { return false; } From 9508de74dc67d7b156c647540baf2c7963c0a1fd Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Wed, 20 Mar 2024 18:06:40 +0530 Subject: [PATCH 09/16] feat(mqttbytes): accept reference to slice of `Filter`s instead of iterator --- rumqttc/src/mqttbytes/v4/subscribe.rs | 12 +++++------- rumqttc/src/v5/mqttbytes/v5/subscribe.rs | 10 +++++----- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 42ddb57b..a12b5981 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -21,13 +21,11 @@ impl Subscribe { } } - pub fn new_many(topics: T) -> Subscribe - where - T: IntoIterator, - { - let filters: Vec = topics.into_iter().collect(); - - Subscribe { pkid: 0, filters } + pub fn new_many(topics: impl AsRef<[SubscribeFilter]>) -> Subscribe { + Subscribe { + pkid: 0, + filters: topics.as_ref().to_vec(), + } } pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 4167cd67..6cf0167d 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -18,12 +18,12 @@ impl Subscribe { } } - pub fn new_many(filters: F, properties: Option) -> Self - where - F: IntoIterator, - { + pub fn new_many( + filters: impl AsRef<[Filter]>, + properties: Option, + ) -> Self { Self { - filters: filters.into_iter().collect(), + filters: filters.as_ref().to_vec(), properties, ..Default::default() } From 165514ff0403a477aa41a60dc06577d4e1656653 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Wed, 20 Mar 2024 18:08:57 +0530 Subject: [PATCH 10/16] fix(rumqttc/client): remove `Clone` trait requirement --- rumqttc/src/client.rs | 17 ++++++++++------- rumqttc/src/v5/client.rs | 31 +++++++++++++++++-------------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index f3cdd5d7..9eac40f7 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -179,9 +179,10 @@ impl AsyncClient { /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); if is_err { @@ -194,9 +195,10 @@ impl AsyncClient { /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); if is_err { @@ -369,9 +371,10 @@ impl Client { /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); if is_err { @@ -383,7 +386,7 @@ impl Client { pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.client.try_subscribe_many(topics) } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f5dbb18b..630bf12f 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -319,9 +319,10 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); if is_err { @@ -337,14 +338,14 @@ impl AsyncClient { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_subscribe_many(topics, None).await } @@ -356,9 +357,10 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); if is_err { @@ -374,14 +376,14 @@ impl AsyncClient { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_try_subscribe_many(topics, None) } @@ -661,9 +663,10 @@ impl Client { properties: Option, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { - let is_err = topics.clone().into_iter().any(|t| !valid_filter(t.path)); + let topics = topics.into_iter().collect::>(); + let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); let request = Request::Subscribe(subscribe); if is_err { @@ -679,14 +682,14 @@ impl Client { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.handle_subscribe_many(topics, None) } @@ -697,7 +700,7 @@ impl Client { properties: SubscribeProperties, ) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.client .try_subscribe_many_with_properties(topics, properties) @@ -705,7 +708,7 @@ impl Client { pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> where - T: IntoIterator + Clone, + T: IntoIterator, { self.client.try_subscribe_many(topics) } From 78d45d46b0ce953e4f7cd3c189d1642f661f1c6c Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Thu, 28 Mar 2024 22:32:18 +0530 Subject: [PATCH 11/16] perf(rumqttc): avoid creation of temporary vec while validating packet --- rumqttc/src/client.rs | 11 ++++------- rumqttc/src/mqttbytes/v4/subscribe.rs | 4 ++-- rumqttc/src/v5/client.rs | 9 +++------ rumqttc/src/v5/mqttbytes/v5/subscribe.rs | 4 ++-- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 9eac40f7..678dff4d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -181,9 +181,8 @@ impl AsyncClient { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); - let subscribe = Subscribe::new_many(topics); + let subscribe = Subscribe::new_many(topics.into_iter()); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::Request(request)); @@ -197,9 +196,8 @@ impl AsyncClient { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::Request(request)); @@ -373,9 +371,8 @@ impl Client { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::Request(request)); diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index a12b5981..c32ce27b 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -21,10 +21,10 @@ impl Subscribe { } } - pub fn new_many(topics: impl AsRef<[SubscribeFilter]>) -> Subscribe { + pub fn new_many(topics: impl IntoIterator) -> Subscribe { Subscribe { pkid: 0, - filters: topics.as_ref().to_vec(), + filters: topics.into_iter().collect(), } } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 630bf12f..d652c613 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -321,9 +321,8 @@ impl AsyncClient { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::Request(request)); @@ -359,9 +358,8 @@ impl AsyncClient { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::TryRequest(request)); @@ -665,9 +663,8 @@ impl Client { where T: IntoIterator, { - let topics = topics.into_iter().collect::>(); - let is_err = topics.iter().any(|t| !valid_filter(&t.path)); let subscribe = Subscribe::new_many(topics, properties); + let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { return Err(ClientError::Request(request)); diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 6cf0167d..7851ff4e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -19,11 +19,11 @@ impl Subscribe { } pub fn new_many( - filters: impl AsRef<[Filter]>, + filters: impl IntoIterator, properties: Option, ) -> Self { Self { - filters: filters.as_ref().to_vec(), + filters: filters.into_iter().collect(), properties, ..Default::default() } From 68f524a559dbca463b353d75af97039253751c10 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Thu, 28 Mar 2024 22:36:57 +0530 Subject: [PATCH 12/16] refactor(rumqttc/client): remove redundant call to `into_iter()` --- rumqttc/src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 678dff4d..9852bd5d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -181,7 +181,7 @@ impl AsyncClient { where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics.into_iter()); + let subscribe = Subscribe::new_many(topics); let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); if is_err { From 719efe89594d6562b4b633048593c63f15c88f84 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Tue, 2 Apr 2024 13:44:30 +0530 Subject: [PATCH 13/16] refactor(rumqttc): use descriptive variable names refactor(v5/client): use descriptive variable name --- rumqttc/src/client.rs | 12 ++++++------ rumqttc/src/v5/client.rs | 28 ++++++++++++++-------------- rumqttc/src/v5/state.rs | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 9852bd5d..1b15bf1d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -182,9 +182,9 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -197,9 +197,9 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.request_tx.try_send(request)?; @@ -372,9 +372,9 @@ impl Client { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index d652c613..551b080e 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -83,11 +83,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); - let is_ok = validate_topic_name_and_alias(&topic, &properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !is_ok { + if !is_valid { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; @@ -138,11 +138,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); - let is_ok = validate_topic_name_and_alias(&topic, &properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !is_ok { + if !is_valid { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; @@ -210,11 +210,11 @@ impl AsyncClient { S: Into, { let topic = topic.into(); - let is_ok = validate_topic_name_and_alias(&topic, &properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !is_ok { + if !is_valid { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; @@ -322,9 +322,9 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics, properties); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -359,9 +359,9 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics, properties); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; @@ -522,11 +522,11 @@ impl Client { P: Into, { let topic = topic.into(); - let is_ok = validate_topic_name_and_alias(&topic, &properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !is_ok { + if !is_valid { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; @@ -664,9 +664,9 @@ impl Client { T: IntoIterator, { let subscribe = Subscribe::new_many(topics, properties); - let is_err = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if is_err { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 854aa7b0..5d415dc3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -65,7 +65,7 @@ pub enum StateError { #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, #[error("Connection closed by peer abruptly")] - ConnectionAborted + ConnectionAborted, } impl From for StateError { From ee01feda9246b9a8afff221dce41bc19f85e5ee2 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Tue, 2 Apr 2024 13:04:47 +0530 Subject: [PATCH 14/16] build(deps): add memchr as dependency --- Cargo.lock | 6 ++++-- rumqttc/Cargo.toml | 1 + rumqttd/Cargo.toml | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7396d43..1f33eb37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1208,9 +1208,9 @@ checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "memmap2" @@ -1940,6 +1940,7 @@ dependencies = [ "http 1.0.0", "log", "matches", + "memchr", "native-tls", "pretty_assertions", "pretty_env_logger", @@ -1968,6 +1969,7 @@ dependencies = [ "config", "flume", "futures-util", + "memchr", "metrics", "metrics-exporter-prometheus", "parking_lot", diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index bba64822..4b219307 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -49,6 +49,7 @@ url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } tokio-stream = "0.1.15" +memchr = "2.7.2" [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttd/Cargo.toml b/rumqttd/Cargo.toml index 2de4f420..e1ac0902 100644 --- a/rumqttd/Cargo.toml +++ b/rumqttd/Cargo.toml @@ -39,6 +39,7 @@ axum = "0.7.4" rand = "0.8.5" uuid = { version = "1.7.0", features = ["v4", "fast-rng"] } subtle = "2.5" +memchr = "2.7.2" [features] default = ["use-rustls", "websocket"] From 8bdec24e7bf50e0e9daf95bdeebbf8490ff9b4af Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Tue, 2 Apr 2024 13:08:02 +0530 Subject: [PATCH 15/16] perf(rumqtt): use `memchr` instead of `str.contains` for better throughput in hotspots --- rumqttc/src/mqttbytes/topic.rs | 14 ++++++++------ rumqttc/src/v5/mqttbytes/mod.rs | 3 ++- rumqttd/src/protocol/mod.rs | 15 +++++++++------ 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index a06f9f96..f700d9c4 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -1,10 +1,12 @@ +use memchr::{memchr, memchr2}; + /// Maximum length of a topic or topic filter according to /// [MQTT-4.7.3-3](https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718109) pub const MAX_TOPIC_LEN: usize = 65535; /// Checks if a topic or topic filter has wildcards pub fn has_wildcards(s: impl AsRef) -> bool { - s.as_ref().contains(['+', '#']) + memchr2(b'+', b'#', s.as_ref().as_bytes()).is_some() } /// Check if a topic is valid for PUBLISH packet. @@ -26,7 +28,7 @@ fn can_be_topic_or_filter(topic_or_filter: impl AsRef) -> bool { let topic_or_filter = topic_or_filter.as_ref(); if topic_or_filter.is_empty() || topic_or_filter.len() > MAX_TOPIC_LEN - || topic_or_filter.contains('\0') + || memchr(b'\0', topic_or_filter.as_bytes()).is_some() { return false; } @@ -54,7 +56,7 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // only single '#" or '+' is allowed in last entry // invalid: sport/tennis# // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { + if last.len() != 1 && has_wildcards(last) { return false; } @@ -63,13 +65,13 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // # is not allowed in filter except as a last entry // invalid: sport/tennis#/player // invalid: sport/tennis/#/ranking - if entry.contains('#') { + if memchr(b'#', entry.as_bytes()).is_some() { return false; } // + must occupy an entire level of the filter // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { + if entry.len() > 1 && memchr(b'+', entry.as_bytes()).is_some() { return false; } } @@ -86,7 +88,7 @@ pub fn matches(topic: impl AsRef, filter: impl AsRef) -> bool { let topic = topic.as_ref(); let filter = filter.as_ref(); - if !topic.is_empty() && topic[..1].contains('$') { + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index b2b8054d..70b8b329 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -2,6 +2,7 @@ // both v3 and v5 use super::PublishProperties; pub use crate::mqttbytes::{valid_filter, valid_topic}; +use memchr::memchr; use std::{str::Utf8Error, vec}; /// This module is the place where all the protocol specifics gets abstracted @@ -64,7 +65,7 @@ pub(crate) fn validate_topic_name_and_alias( /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe pub fn matches(topic: &str, filter: &str) -> bool { - if !topic.is_empty() && topic[..1].contains('$') { + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } diff --git a/rumqttd/src/protocol/mod.rs b/rumqttd/src/protocol/mod.rs index 640f33a5..a05f3840 100644 --- a/rumqttd/src/protocol/mod.rs +++ b/rumqttd/src/protocol/mod.rs @@ -10,6 +10,7 @@ use std::{io, str::Utf8Error, string::FromUtf8Error}; /// MQTT is the core protocol that this broker supports, a lot of structs closely /// map to what MQTT specifies in its protocol use bytes::{Buf, BufMut, Bytes, BytesMut}; +use memchr::{memchr, memchr2}; use crate::Notification; @@ -597,7 +598,7 @@ pub fn qos(num: u8) -> Option { /// Checks if a topic or topic filter has wildcards pub fn has_wildcards(s: impl AsRef) -> bool { - s.as_ref().contains(['+', '#']) + memchr2(b'+', b'#', s.as_ref().as_bytes()).is_some() } /// Check if a topic is valid for PUBLISH packet. @@ -605,6 +606,8 @@ pub fn valid_topic(topic: impl AsRef) -> bool { is_valid_topic_or_filter(&topic) && !has_wildcards(topic) } +/// Maximum length of a topic or topic filter according to +/// [MQTT-4.7.3-3](https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718109) pub const MAX_TOPIC_LEN: usize = 65535; /// Check if a topic is valid to qualify as a topic name or topic filter. @@ -621,7 +624,7 @@ fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { let topic_or_filter = topic_or_filter.as_ref(); if topic_or_filter.is_empty() || topic_or_filter.len() > MAX_TOPIC_LEN - || topic_or_filter.contains('\0') + || memchr(b'\0', topic_or_filter.as_bytes()).is_some() { return false; } @@ -649,7 +652,7 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // only single '#" or '+' is allowed in last entry // invalid: sport/tennis# // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { + if last.len() != 1 && has_wildcards(last) { return false; } @@ -658,13 +661,13 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // # is not allowed in filter except as a last entry // invalid: sport/tennis#/player // invalid: sport/tennis/#/ranking - if entry.contains('#') { + if memchr(b'#', entry.as_bytes()).is_some() { return false; } // + must occupy an entire level of the filter // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { + if entry.len() > 1 && memchr(b'+', entry.as_bytes()).is_some() { return false; } } @@ -678,7 +681,7 @@ pub fn valid_filter(filter: impl AsRef) -> bool { /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe pub fn matches(topic: &str, filter: &str) -> bool { - if !topic.is_empty() && topic[..1].contains('$') { + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } From 6d307955ab94645df24f1a8cf81facc5452a6645 Mon Sep 17 00:00:00 2001 From: Arunanshu Biswas Date: Tue, 2 Apr 2024 13:20:00 +0530 Subject: [PATCH 16/16] refactor(rumqttd): avoid using unwrap --- rumqttc/src/mqttbytes/topic.rs | 4 +++- rumqttd/src/protocol/mod.rs | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index f700d9c4..f04bc9c7 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -51,7 +51,9 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // split will never return an empty iterator // even if the pattern isn't matched, the original string will be there // so it is safe to just unwrap here! - let last = hirerarchy.next().unwrap(); + let Some(last) = hirerarchy.next() else { + return false; + }; // only single '#" or '+' is allowed in last entry // invalid: sport/tennis# diff --git a/rumqttd/src/protocol/mod.rs b/rumqttd/src/protocol/mod.rs index a05f3840..85f3df2e 100644 --- a/rumqttd/src/protocol/mod.rs +++ b/rumqttd/src/protocol/mod.rs @@ -647,7 +647,9 @@ pub fn valid_filter(filter: impl AsRef) -> bool { // split will never return an empty iterator // even if the pattern isn't matched, the original string will be there // so it is safe to just unwrap here! - let last = hirerarchy.next().unwrap(); + let Some(last) = hirerarchy.next() else { + return false; + }; // only single '#" or '+' is allowed in last entry // invalid: sport/tennis#