diff --git a/Cargo.lock b/Cargo.lock index b7396d43..1f33eb37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1208,9 +1208,9 @@ checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "memmap2" @@ -1940,6 +1940,7 @@ dependencies = [ "http 1.0.0", "log", "matches", + "memchr", "native-tls", "pretty_assertions", "pretty_env_logger", @@ -1968,6 +1969,7 @@ dependencies = [ "config", "flume", "futures-util", + "memchr", "metrics", "metrics-exporter-prometheus", "parking_lot", diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 1045cfcf..a7eb3751 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -25,8 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -* Validate filters while creating subscription requests. -* Make v4::Connect::write return correct value +- Make v4::Connect::write return correct value +- Validate topic filter and topic name for MQTT v3 and v5 during `Publish`, `Subscribe`, and `Unsubscribe` operations. ### Security diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index bba64822..4b219307 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -49,6 +49,7 @@ url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } tokio-stream = "0.1.15" +memchr = "2.7.2" [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 15cd5f5a..1b15bf1d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -141,9 +141,13 @@ impl AsyncClient { where S: Into, { - let mut publish = Publish::from_bytes(topic, qos, payload); + let topic = topic.into(); + let mut publish = Publish::from_bytes(&topic, qos, payload); publish.retain = retain; let publish = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::Request(publish)); + } self.request_tx.send_async(publish).await?; Ok(()) } @@ -153,7 +157,7 @@ impl AsyncClient { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -165,8 +169,8 @@ impl AsyncClient { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::TryRequest(request)); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); } self.request_tx.try_send(request)?; Ok(()) @@ -177,11 +181,10 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -193,12 +196,11 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + if is_invalid_filter { + return Err(ClientError::Request(request)); } self.request_tx.try_send(request)?; Ok(()) @@ -206,16 +208,24 @@ impl AsyncClient { /// Sends a MQTT Unsubscribe to the `EventLoop` pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(&topic) { + return Err(ClientError::Request(request)); + } self.request_tx.send_async(request).await?; Ok(()) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(&topic) { + return Err(ClientError::Request(request)); + } self.request_tx.try_send(request)?; Ok(()) } @@ -319,8 +329,7 @@ impl Client { S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -344,7 +353,7 @@ impl Client { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -362,11 +371,10 @@ impl Client { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -382,8 +390,12 @@ impl Client { /// Sends a MQTT Unsubscribe to the `EventLoop` pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.client.request_tx.send(request)?; Ok(()) } diff --git a/rumqttc/src/mqttbytes/topic.rs b/rumqttc/src/mqttbytes/topic.rs index a66c35f8..f04bc9c7 100644 --- a/rumqttc/src/mqttbytes/topic.rs +++ b/rumqttc/src/mqttbytes/topic.rs @@ -1,12 +1,35 @@ +use memchr::{memchr, memchr2}; + +/// Maximum length of a topic or topic filter according to +/// [MQTT-4.7.3-3](https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718109) +pub const MAX_TOPIC_LEN: usize = 65535; + /// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') +pub fn has_wildcards(s: impl AsRef) -> bool { + memchr2(b'+', b'#', s.as_ref().as_bytes()).is_some() +} + +/// Check if a topic is valid for PUBLISH packet. +pub fn valid_topic(topic: impl AsRef) -> bool { + can_be_topic_or_filter(&topic) && !has_wildcards(topic) } -/// Checks if a topic is valid -pub fn valid_topic(topic: &str) -> bool { - // topic can't contain wildcards - if topic.contains('+') || topic.contains('#') { +/// Check if a topic is valid to qualify as a topic name or topic filter. +/// +/// According to MQTT v3 Spec, it has to follow the following rules: +/// 1. All Topic Names and Topic Filters MUST be at least one character long [MQTT-4.7.3-1] +/// 2. Topic Names and Topic Filters are case sensitive +/// 3. Topic Names and Topic Filters can include the space character +/// 4. A leading or trailing `/` creates a distinct Topic Name or Topic Filter +/// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid +/// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] +/// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. +fn can_be_topic_or_filter(topic_or_filter: impl AsRef) -> bool { + let topic_or_filter = topic_or_filter.as_ref(); + if topic_or_filter.is_empty() + || topic_or_filter.len() > MAX_TOPIC_LEN + || memchr(b'\0', topic_or_filter.as_bytes()).is_some() + { return false; } @@ -16,8 +39,9 @@ pub fn valid_topic(topic: &str) -> bool { /// Checks if the filter is valid /// /// -pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { +pub fn valid_filter(filter: impl AsRef) -> bool { + let filter = filter.as_ref(); + if !can_be_topic_or_filter(filter) { return false; } @@ -27,12 +51,14 @@ pub fn valid_filter(filter: &str) -> bool { // split will never return an empty iterator // even if the pattern isn't matched, the original string will be there // so it is safe to just unwrap here! - let last = hirerarchy.next().unwrap(); + let Some(last) = hirerarchy.next() else { + return false; + }; // only single '#" or '+' is allowed in last entry // invalid: sport/tennis# // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { + if last.len() != 1 && has_wildcards(last) { return false; } @@ -41,13 +67,13 @@ pub fn valid_filter(filter: &str) -> bool { // # is not allowed in filter except as a last entry // invalid: sport/tennis#/player // invalid: sport/tennis/#/ranking - if entry.contains('#') { + if memchr(b'#', entry.as_bytes()).is_some() { return false; } // + must occupy an entire level of the filter // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { + if entry.len() > 1 && memchr(b'+', entry.as_bytes()).is_some() { return false; } } @@ -60,8 +86,11 @@ pub fn valid_filter(filter: &str) -> bool { /// **NOTE**: 'topic' is a misnomer in the arg. this can also be used to match 2 wild subscriptions /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe -pub fn matches(topic: &str, filter: &str) -> bool { - if !topic.is_empty() && topic[..1].contains('$') { +pub fn matches(topic: impl AsRef, filter: impl AsRef) -> bool { + let topic = topic.as_ref(); + let filter = filter.as_ref(); + + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 42ddb57b..c32ce27b 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -21,13 +21,11 @@ impl Subscribe { } } - pub fn new_many(topics: T) -> Subscribe - where - T: IntoIterator, - { - let filters: Vec = topics.into_iter().collect(); - - Subscribe { pkid: 0, filters } + pub fn new_many(topics: impl IntoIterator) -> Subscribe { + Subscribe { + pkid: 0, + filters: topics.into_iter().collect(), + } } pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c..551b080e 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -6,9 +6,8 @@ use super::mqttbytes::v5::{ Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; -use super::mqttbytes::{valid_filter, QoS}; +use super::mqttbytes::{valid_filter, validate_topic_name_and_alias, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::valid_topic; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -84,10 +83,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); - let mut publish = Publish::new(&topic, qos, payload, properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_valid { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; @@ -138,10 +138,11 @@ impl AsyncClient { P: Into, { let topic = topic.into(); + let is_valid = validate_topic_name_and_alias(&topic, &properties); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_valid { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; @@ -209,10 +210,11 @@ impl AsyncClient { S: Into, { let topic = topic.into(); - let mut publish = Publish::new(&topic, qos, payload, properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_valid { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; @@ -255,11 +257,11 @@ impl AsyncClient { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); - let request: Request = Request::Subscribe(subscribe); - if !is_filter_valid { + let request = Request::Subscribe(subscribe); + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; @@ -286,11 +288,11 @@ impl AsyncClient { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); let request = Request::Subscribe(subscribe); - if !is_filter_valid { + if !valid_filter(topic) { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; @@ -319,14 +321,12 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics, properties); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_invalid_filter { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; Ok(()) } @@ -358,11 +358,10 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics, properties); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_invalid_filter { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; @@ -393,8 +392,12 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.request_tx.send_async(request).await?; Ok(()) } @@ -417,8 +420,12 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::TryRequest(request)); + } self.request_tx.try_send(request)?; Ok(()) } @@ -515,10 +522,11 @@ impl Client { P: Into, { let topic = topic.into(); - let mut publish = Publish::new(&topic, qos, payload, properties); + let is_valid = validate_topic_name_and_alias(&topic, &properties); + let mut publish = Publish::new(topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { + if !is_valid { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; @@ -607,11 +615,11 @@ impl Client { qos: QoS, properties: Option, ) -> Result<(), ClientError> { - let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); + let topic = topic.into(); + let filter = Filter::new(&topic, qos); let subscribe = Subscribe::new(filter, properties); let request = Request::Subscribe(subscribe); - if !is_filter_valid { + if !valid_filter(topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -655,11 +663,10 @@ impl Client { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics, properties); + let is_invalid_filter = subscribe.filters.iter().any(|t| !valid_filter(&t.path)); let request = Request::Subscribe(subscribe); - if !is_valid_filters { + if is_invalid_filter { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; @@ -709,8 +716,12 @@ impl Client { topic: S, properties: Option, ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + let topic = topic.into(); + let unsubscribe = Unsubscribe::new(&topic, properties); let request = Request::Unsubscribe(unsubscribe); + if !valid_filter(topic) { + return Err(ClientError::Request(request)); + } self.client.request_tx.send(request)?; Ok(()) } diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index c205aaa9..70b8b329 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -1,3 +1,8 @@ +// reexport valid_filter and valid_topic since they are identical in nature for +// both v3 and v5 +use super::PublishProperties; +pub use crate::mqttbytes::{valid_filter, valid_topic}; +use memchr::memchr; use std::{str::Utf8Error, vec}; /// This module is the place where all the protocol specifics gets abstracted @@ -32,61 +37,26 @@ pub fn qos(num: u8) -> Option { } } -/// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') -} - -/// Checks if a topic is valid -pub fn valid_topic(topic: &str) -> bool { - // topic can't contain wildcards - if topic.contains('+') || topic.contains('#') { - return false; - } - - true -} - -/// Checks if the filter is valid -/// -/// -pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { - return false; - } - - // rev() is used so we can easily get the last entry - let mut hirerarchy = filter.split('/').rev(); - - // split will never return an empty iterator - // even if the pattern isn't matched, the original string will be there - // so it is safe to just unwrap here! - let last = hirerarchy.next().unwrap(); - - // only single '#" or '+' is allowed in last entry - // invalid: sport/tennis# - // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { - return false; - } - - // remaining entries - for entry in hirerarchy { - // # is not allowed in filter except as a last entry - // invalid: sport/tennis#/player - // invalid: sport/tennis/#/ranking - if entry.contains('#') { - return false; - } - - // + must occupy an entire level of the filter - // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { - return false; - } - } - - true +pub(crate) fn validate_topic_name_and_alias( + topic: impl AsRef, + properties: &Option, +) -> bool { + let topic = topic.as_ref(); + let is_topic_empty = topic.is_empty(); + // The topic alias is considered valid only if it is greater than zero. + // If it is not supplied, it is still considered valid because in that + // case, the validity depends on the topic name itself. + let is_topic_valid = valid_topic(topic); + let is_alias_given = properties + .as_ref() + .map_or(false, |p| p.topic_alias.is_some()); + let is_alias_valid = properties + .as_ref() + .map_or(false, |p| p.topic_alias.map_or(false, |a| a > 0)); + + (is_topic_valid || is_topic_empty) + && (!is_topic_empty || is_alias_given) + && (!is_alias_given || is_alias_valid) } /// Checks if topic matches a filter. topic and filter validation isn't done here. @@ -95,7 +65,7 @@ pub fn valid_filter(filter: &str) -> bool { /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe pub fn matches(topic: &str, filter: &str) -> bool { - if !topic.is_empty() && topic[..1].contains('$') { + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 4167cd67..7851ff4e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -18,10 +18,10 @@ impl Subscribe { } } - pub fn new_many(filters: F, properties: Option) -> Self - where - F: IntoIterator, - { + pub fn new_many( + filters: impl IntoIterator, + properties: Option, + ) -> Self { Self { filters: filters.into_iter().collect(), properties, diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 854aa7b0..5d415dc3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -65,7 +65,7 @@ pub enum StateError { #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, #[error("Connection closed by peer abruptly")] - ConnectionAborted + ConnectionAborted, } impl From for StateError { diff --git a/rumqttd/CHANGELOG.md b/rumqttd/CHANGELOG.md index 6d981ffd..68d05300 100644 --- a/rumqttd/CHANGELOG.md +++ b/rumqttd/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - record client id for remote link's span - session present flag in connack - Make write method return the number of bytes written correctly everywhere +- Validate topic name and topic filter in MQTT v3 and v5 `Publish`, `Subscribe` and `Unsubscribe` packets. ### Security - Implement constant-time password comparison in authentication logic @@ -177,7 +178,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [rumqttd 0.12.4] - 01-02-2023 ### Fixed -- Client id with tenant prefix should be set globally (#564) +- Client id with tenant prefix should be set globally (#564) ## [rumqttd 0.12.3] - 23-01-2023 diff --git a/rumqttd/Cargo.toml b/rumqttd/Cargo.toml index 2de4f420..e1ac0902 100644 --- a/rumqttd/Cargo.toml +++ b/rumqttd/Cargo.toml @@ -39,6 +39,7 @@ axum = "0.7.4" rand = "0.8.5" uuid = { version = "1.7.0", features = ["v4", "fast-rng"] } subtle = "2.5" +memchr = "2.7.2" [features] default = ["use-rustls", "websocket"] diff --git a/rumqttd/src/protocol/mod.rs b/rumqttd/src/protocol/mod.rs index a27543d1..85f3df2e 100644 --- a/rumqttd/src/protocol/mod.rs +++ b/rumqttd/src/protocol/mod.rs @@ -10,6 +10,7 @@ use std::{io, str::Utf8Error, string::FromUtf8Error}; /// MQTT is the core protocol that this broker supports, a lot of structs closely /// map to what MQTT specifies in its protocol use bytes::{Buf, BufMut, Bytes, BytesMut}; +use memchr::{memchr, memchr2}; use crate::Notification; @@ -596,17 +597,35 @@ pub fn qos(num: u8) -> Option { } /// Checks if a topic or topic filter has wildcards -pub fn has_wildcards(s: &str) -> bool { - s.contains('+') || s.contains('#') +pub fn has_wildcards(s: impl AsRef) -> bool { + memchr2(b'+', b'#', s.as_ref().as_bytes()).is_some() } -/// Checks if a topic is valid -pub fn valid_topic(topic: &str) -> bool { - if topic.contains('+') { - return false; - } +/// Check if a topic is valid for PUBLISH packet. +pub fn valid_topic(topic: impl AsRef) -> bool { + is_valid_topic_or_filter(&topic) && !has_wildcards(topic) +} - if topic.contains('#') { +/// Maximum length of a topic or topic filter according to +/// [MQTT-4.7.3-3](https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718109) +pub const MAX_TOPIC_LEN: usize = 65535; + +/// Check if a topic is valid to qualify as a topic name or topic filter. +/// +/// According to MQTT v3 Spec, it has to follow the following rules: +/// 1. All Topic Names and Topic Filters MUST be at least one character long [MQTT-4.7.3-1] +/// 2. Topic Names and Topic Filters are case sensitive +/// 3. Topic Names and Topic Filters can include the space character +/// 4. A leading or trailing `/` creates a distinct Topic Name or Topic Filter +/// 5. A Topic Name or Topic Filter consisting only of the `/` character is valid +/// 6. Topic Names and Topic Filters MUST NOT include the null character (Unicode U+0000) [MQTT-4.7.3-2] +/// 7. Topic Names and Topic Filters are UTF-8 encoded strings, they MUST NOT encode to more than 65535 bytes. +fn is_valid_topic_or_filter(topic_or_filter: impl AsRef) -> bool { + let topic_or_filter = topic_or_filter.as_ref(); + if topic_or_filter.is_empty() + || topic_or_filter.len() > MAX_TOPIC_LEN + || memchr(b'\0', topic_or_filter.as_bytes()).is_some() + { return false; } @@ -616,32 +635,41 @@ pub fn valid_topic(topic: &str) -> bool { /// Checks if the filter is valid /// /// -pub fn valid_filter(filter: &str) -> bool { - if filter.is_empty() { +pub fn valid_filter(filter: impl AsRef) -> bool { + let filter = filter.as_ref(); + if !is_valid_topic_or_filter(filter) { + return false; + } + + // rev() is used so we can easily get the last entry + let mut hirerarchy = filter.split('/').rev(); + + // split will never return an empty iterator + // even if the pattern isn't matched, the original string will be there + // so it is safe to just unwrap here! + let Some(last) = hirerarchy.next() else { + return false; + }; + + // only single '#" or '+' is allowed in last entry + // invalid: sport/tennis# + // invalid: sport/++ + if last.len() != 1 && has_wildcards(last) { return false; } - let hirerarchy = filter.split('/').collect::>(); - if let Some((last, remaining)) = hirerarchy.split_last() { - for entry in remaining.iter() { - // # is not allowed in filter except as a last entry - // invalid: sport/tennis#/player - // invalid: sport/tennis/#/ranking - if entry.contains('#') { - return false; - } - - // + must occupy an entire level of the filter - // invalid: sport+ - if entry.len() > 1 && entry.contains('+') { - return false; - } + // remaining entries + for entry in hirerarchy { + // # is not allowed in filter except as a last entry + // invalid: sport/tennis#/player + // invalid: sport/tennis/#/ranking + if memchr(b'#', entry.as_bytes()).is_some() { + return false; } - // only single '#" or '+' is allowed in last entry - // invalid: sport/tennis# - // invalid: sport/++ - if last.len() != 1 && (last.contains('#') || last.contains('+')) { + // + must occupy an entire level of the filter + // invalid: sport+ + if entry.len() > 1 && memchr(b'+', entry.as_bytes()).is_some() { return false; } } @@ -655,7 +683,7 @@ pub fn valid_filter(filter: &str) -> bool { /// **NOTE**: make sure a topic is validated during a publish and filter is validated /// during a subscribe pub fn matches(topic: &str, filter: &str) -> bool { - if !topic.is_empty() && topic[..1].contains('$') { + if !topic.is_empty() && memchr(b'$', topic[..1].as_bytes()).is_some() { return false; } diff --git a/rumqttd/src/protocol/v4/publish.rs b/rumqttd/src/protocol/v4/publish.rs index b64cb314..0839bde6 100644 --- a/rumqttd/src/protocol/v4/publish.rs +++ b/rumqttd/src/protocol/v4/publish.rs @@ -1,5 +1,6 @@ use super::*; use bytes::{Buf, Bytes}; +use core::str::from_utf8; fn len(publish: &Publish) -> usize { let len = 2 + publish.topic.len() + publish.payload.len(); @@ -19,6 +20,10 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result 0 let pkid = match qos { QoS::AtMostOnce => 0, diff --git a/rumqttd/src/protocol/v4/subscribe.rs b/rumqttd/src/protocol/v4/subscribe.rs index 2f75299e..853c3744 100644 --- a/rumqttd/src/protocol/v4/subscribe.rs +++ b/rumqttd/src/protocol/v4/subscribe.rs @@ -13,6 +13,10 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result Err(Error::EmptySubscription), _ => Ok(Subscribe { pkid, filters }), diff --git a/rumqttd/src/protocol/v4/unsubscribe.rs b/rumqttd/src/protocol/v4/unsubscribe.rs index c1fb8dd3..0760bad1 100644 --- a/rumqttd/src/protocol/v4/unsubscribe.rs +++ b/rumqttd/src/protocol/v4/unsubscribe.rs @@ -11,6 +11,9 @@ pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result 0 { let topic_filter = read_mqtt_string(&mut bytes)?; + if !valid_filter(&topic_filter) { + return Err(Error::MalformedPacket); + } payload_bytes -= topic_filter.len() + 2; filters.push(topic_filter); } diff --git a/rumqttd/src/protocol/v5/publish.rs b/rumqttd/src/protocol/v5/publish.rs index 05313367..a11468f6 100644 --- a/rumqttd/src/protocol/v5/publish.rs +++ b/rumqttd/src/protocol/v5/publish.rs @@ -1,6 +1,7 @@ use super::*; use bytes::{Buf, Bytes}; use core::fmt; +use std::str::from_utf8; pub fn len(publish: &Publish, properties: &Option) -> usize { let mut len = 2 + publish.topic.len(); @@ -34,6 +35,10 @@ pub fn read( bytes.advance(variable_header_index); let topic = read_mqtt_bytes(&mut bytes)?; + if !valid_topic(from_utf8(&topic).map_err(|_| Error::TopicNotUtf8)?) { + return Err(Error::MalformedPacket); + } + // Packet identifier exists where QoS > 0 let pkid = match qos { QoS::AtMostOnce => 0, diff --git a/rumqttd/src/protocol/v5/subscribe.rs b/rumqttd/src/protocol/v5/subscribe.rs index 686d3686..e7117c25 100644 --- a/rumqttd/src/protocol/v5/subscribe.rs +++ b/rumqttd/src/protocol/v5/subscribe.rs @@ -32,6 +32,10 @@ pub fn read( while bytes.has_remaining() { let path = read_mqtt_string(&mut bytes)?; + if !valid_filter(&path) { + return Err(Error::MalformedPacket); + } + let options = read_u8(&mut bytes)?; let requested_qos = options & 0b0000_0011; diff --git a/rumqttd/src/protocol/v5/unsubscribe.rs b/rumqttd/src/protocol/v5/unsubscribe.rs index 1245bf7f..6f2460c9 100644 --- a/rumqttd/src/protocol/v5/unsubscribe.rs +++ b/rumqttd/src/protocol/v5/unsubscribe.rs @@ -31,6 +31,9 @@ pub fn read( let mut filters = Vec::with_capacity(1); while bytes.has_remaining() { let filter = read_mqtt_string(&mut bytes)?; + if !valid_filter(&filter) { + return Err(Error::MalformedPacket); + } filters.push(filter); }