diff --git a/Cargo.lock b/Cargo.lock index 8323dc93..c6c41260 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -695,7 +695,7 @@ checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" dependencies = [ "futures-core", "futures-sink", - "spin", + "spin 0.9.8", ] [[package]] @@ -1905,6 +1905,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + [[package]] name = "ring" version = "0.17.8" @@ -1915,8 +1930,8 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin", - "untrusted", + "spin 0.9.8", + "untrusted 0.9.0", "windows-sys 0.52.0", ] @@ -1953,6 +1968,7 @@ dependencies = [ "rustls-native-certs", "rustls-pemfile", "rustls-webpki", + "scram", "serde", "thiserror", "tokio", @@ -2053,7 +2069,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", - "ring", + "ring 0.17.8", "rustls-pki-types", "rustls-webpki", "subtle", @@ -2095,9 +2111,9 @@ version = "0.102.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" dependencies = [ - "ring", + "ring 0.17.8", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -2127,6 +2143,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scram" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7679a5e6b97bac99b2c208894ba0d34b17d9657f0b728c1cd3bf1c5f7f6ebe88" +dependencies = [ + "base64 0.13.1", + "rand", + "ring 0.16.20", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -2300,6 +2327,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spin" version = "0.9.8" @@ -2774,6 +2807,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index ce94f005..e5f6df44 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`. * `Auth` packet as per MQTT5 standards * Allow configuring the `nodelay` property of underlying TCP client with the `tcp_nodelay` field in `NetworkOptions` +* `MqttOptions::set_auth_manager` that allows users to set their own authentication manager that implements the `AuthManager` trait. +* `Client::reauth` that enables users to send `AUTH` packet for re-authentication purposes. + ### Changed diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 72dca5e8..63e2032c 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -21,6 +21,7 @@ use-rustls = ["dep:tokio-rustls", "dep:rustls-webpki", "dep:rustls-pemfile", "de use-native-tls = ["dep:tokio-native-tls", "dep:native-tls"] websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] +auth-scram = ["dep:scram"] [dependencies] futures-util = { version = "0.3", default-features = false, features = ["std", "sink"] } @@ -50,6 +51,8 @@ url = { version = "2", default-features = false, optional = true } async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } tokio-stream = "0.1.15" fixedbitset = "0.5.7" +#auth +scram = { version = "0.6.0", optional = true } [dev-dependencies] bincode = "1.3.3" @@ -59,6 +62,11 @@ pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +[[example]] +name = "async_auth_oauth" +path = "examples/async_auth_oauth.rs" +required-features = ["use-rustls"] + [[example]] name = "tls" path = "examples/tls.rs" diff --git a/rumqttc/examples/async_auth_oauth.rs b/rumqttc/examples/async_auth_oauth.rs new file mode 100644 index 00000000..f854e514 --- /dev/null +++ b/rumqttc/examples/async_auth_oauth.rs @@ -0,0 +1,64 @@ +use rumqttc::v5::mqttbytes::v5::AuthProperties; +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use rumqttc::{TlsConfiguration, Transport}; +use std::error::Error; +use std::sync::Arc; +use tokio::task; +use tokio_rustls::rustls::ClientConfig; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let pubsub_access_token = ""; + + let mut mqttoptions = MqttOptions::new("client1-session1", "MQTT hostname", 8883); + mqttoptions.set_authentication_method(Some("OAUTH2-JWT".to_string())); + mqttoptions.set_authentication_data(Some(pubsub_access_token.into())); + + // Use rustls-native-certs to load root certificates from the operating system. + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + root_cert_store.add_parsable_certificates( + rustls_native_certs::load_native_certs().expect("could not load platform certs"), + ); + + let client_config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let transport = Transport::Tls(TlsConfiguration::Rustls(Arc::new(client_config.into()))); + + mqttoptions.set_transport(transport); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + + task::spawn(async move { + client.subscribe("topic1", QoS::AtLeastOnce).await.unwrap(); + client + .publish("topic1", QoS::AtLeastOnce, false, "hello world") + .await + .unwrap(); + + // Re-authentication test. + let props = AuthProperties { + method: Some("OAUTH2-JWT".to_string()), + data: Some(pubsub_access_token.into()), + reason: None, + user_properties: Vec::new(), + }; + + client.reauth(Some(props)).await.unwrap(); + }); + + loop { + let notification = eventloop.poll().await; + + match notification { + Ok(event) => println!("{:?}", event), + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} diff --git a/rumqttc/examples/async_auth_scram.rs b/rumqttc/examples/async_auth_scram.rs new file mode 100644 index 00000000..6aaff04d --- /dev/null +++ b/rumqttc/examples/async_auth_scram.rs @@ -0,0 +1,162 @@ +use bytes::Bytes; +use flume::bounded; +use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; +use rumqttc::v5::{AsyncClient, AuthManager, MqttOptions}; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; +use std::error::Error; +use std::sync::{Arc, Mutex}; +use tokio::task; + +#[derive(Debug)] +struct ScramAuthManager<'a> { + #[allow(dead_code)] + user: &'a str, + #[allow(dead_code)] + password: &'a str, + #[cfg(feature = "auth-scram")] + scram: Option>, +} + +impl<'a> ScramAuthManager<'a> { + fn new(user: &'a str, password: &'a str) -> ScramAuthManager<'a> { + ScramAuthManager { + user, + password, + #[cfg(feature = "auth-scram")] + scram: None, + } + } + + fn auth_start(&mut self) -> Result, String> { + #[cfg(feature = "auth-scram")] + { + let scram = ScramClient::new(self.user, self.password, None); + let (scram, client_first) = scram.client_first(); + self.scram = Some(scram); + + Ok(Some(client_first.into())) + } + + #[cfg(not(feature = "auth-scram"))] + Ok(Some("client first message".into())) + } +} + +impl<'a> AuthManager for ScramAuthManager<'a> { + fn auth_continue( + &mut self, + #[allow(unused_variables)] + auth_prop: Option, + ) -> Result, String> { + #[cfg(feature = "auth-scram")] + { + // Unwrap the properties. + let prop = auth_prop.unwrap(); + + // Check if the authentication method is SCRAM-SHA-256 + if let Some(auth_method) = &prop.method { + if auth_method != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } + } else { + return Err("Invalid authentication method".to_string()); + } + + if self.scram.is_none() { + return Err("Invalid state".to_string()); + } + + let scram = self.scram.take().unwrap(); + + let auth_data = String::from_utf8(prop.data.unwrap().to_vec()).unwrap(); + + // Process the server first message and reassign the SCRAM state. + let scram = match scram.handle_server_first(&auth_data) { + Ok(scram) => scram, + Err(e) => return Err(e.to_string()), + }; + + // Get the client final message and reassign the SCRAM state. + let (_, client_final) = scram.client_final(); + + Ok(Some(AuthProperties{ + method: Some("SCRAM-SHA-256".to_string()), + data: Some(client_final.into()), + reason: None, + user_properties: Vec::new(), + })) + } + + #[cfg(not(feature = "auth-scram"))] + Ok(Some(AuthProperties { + method: Some("SCRAM-SHA-256".to_string()), + data: Some("client final message".into()), + reason: None, + user_properties: Vec::new(), + })) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let mut authmanager = ScramAuthManager::new("user1", "123456"); + let client_first = authmanager.auth_start().unwrap(); + let authmanager = Arc::new(Mutex::new(authmanager)); + + let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); + mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); + mqttoptions.set_authentication_data(client_first); + mqttoptions.set_auth_manager(authmanager.clone()); + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + + let (tx, rx) = bounded(1); + + task::spawn(async move { + client + .subscribe("rumqtt_auth/topic", QoS::AtLeastOnce) + .await + .unwrap(); + client + .publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world") + .await + .unwrap(); + + // Wait for the connection to be established. + rx.recv_async().await.unwrap(); + + // Reauthenticate using SCRAM-SHA-256 + let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); + let properties = AuthProperties { + method: Some("SCRAM-SHA-256".to_string()), + data: client_first, + reason: None, + user_properties: Vec::new(), + }; + client.reauth(Some(properties)).await.unwrap(); + }); + + loop { + let notification = eventloop.poll().await; + + match notification { + Ok(event) => { + println!("Event = {:?}", event); + match event { + rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { + tx.send_async("Connected").await.unwrap(); + } + _ => {} + } + } + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} diff --git a/rumqttc/examples/sync_auth_scram.rs b/rumqttc/examples/sync_auth_scram.rs new file mode 100644 index 00000000..c96b42f1 --- /dev/null +++ b/rumqttc/examples/sync_auth_scram.rs @@ -0,0 +1,157 @@ +use bytes::Bytes; +use flume::bounded; +use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; +use rumqttc::v5::{AuthManager, Client, MqttOptions}; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; +use std::error::Error; +use std::sync::{Arc, Mutex}; +use std::thread; + +#[derive(Debug)] +struct ScramAuthManager<'a> { + #[allow(dead_code)] + user: &'a str, + #[allow(dead_code)] + password: &'a str, + #[cfg(feature = "auth-scram")] + scram: Option>, +} + +impl<'a> ScramAuthManager<'a> { + fn new(user: &'a str, password: &'a str) -> ScramAuthManager<'a> { + ScramAuthManager { + user, + password, + #[cfg(feature = "auth-scram")] + scram: None, + } + } + + fn auth_start(&mut self) -> Result, String> { + #[cfg(feature = "auth-scram")] + { + let scram = ScramClient::new(self.user, self.password, None); + let (scram, client_first) = scram.client_first(); + self.scram = Some(scram); + + Ok(Some(client_first.into())) + } + + #[cfg(not(feature = "auth-scram"))] + Ok(Some("client first message".into())) + } +} + +impl<'a> AuthManager for ScramAuthManager<'a> { + fn auth_continue( + &mut self, + #[allow(unused_variables)] + auth_prop: Option, + ) -> Result, String> { + #[cfg(feature = "auth-scram")] + { + // Unwrap the properties. + let prop = auth_prop.unwrap(); + + // Check if the authentication method is SCRAM-SHA-256 + if let Some(auth_method) = &prop.method { + if auth_method != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } + } else { + return Err("Invalid authentication method".to_string()); + } + + if self.scram.is_none() { + return Err("Invalid state".to_string()); + } + + let scram = self.scram.take().unwrap(); + + let auth_data = String::from_utf8(prop.data.unwrap().to_vec()).unwrap(); + + // Process the server first message and reassign the SCRAM state. + let scram = match scram.handle_server_first(&auth_data) { + Ok(scram) => scram, + Err(e) => return Err(e.to_string()), + }; + + // Get the client final message and reassign the SCRAM state. + let (_, client_final) = scram.client_final(); + + Ok(Some(AuthProperties{ + method: Some("SCRAM-SHA-256".to_string()), + data: Some(client_final.into()), + reason: None, + user_properties: Vec::new(), + })) + } + + #[cfg(not(feature = "auth-scram"))] + Ok(Some(AuthProperties { + method: Some("SCRAM-SHA-256".to_string()), + data: Some("client final message".into()), + reason: None, + user_properties: Vec::new(), + })) + } +} + +fn main() -> Result<(), Box> { + let mut authmanager = ScramAuthManager::new("user1", "123456"); + let client_first = authmanager.auth_start().unwrap(); + let authmanager = Arc::new(Mutex::new(authmanager)); + + let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); + mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); + mqttoptions.set_authentication_data(client_first); + mqttoptions.set_auth_manager(authmanager.clone()); + let (client, mut connection) = Client::new(mqttoptions, 10); + + let (tx, rx) = bounded(1); + + thread::spawn(move || { + client + .subscribe("rumqtt_auth/topic", QoS::AtLeastOnce) + .unwrap(); + client + .publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world") + .unwrap(); + + // Wait for the connection to be established. + rx.recv().unwrap(); + + // Reauthenticate using SCRAM-SHA-256 + let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); + let properties = AuthProperties { + method: Some("SCRAM-SHA-256".to_string()), + data: client_first, + reason: None, + user_properties: Vec::new(), + }; + client.reauth(Some(properties)).unwrap(); + }); + + for (_, notification) in connection.iter().enumerate() { + match notification { + Ok(event) => { + println!("Event = {:?}", event); + match event { + rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { + tx.send("Connected").unwrap(); + } + _ => {} + } + } + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 93887bb1..3020933d 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -180,6 +180,8 @@ pub enum Outgoing { Disconnect, /// Await for an ack for more outgoing progress AwaitAck(u16), + /// Auth packet + Auth, } /// Requests by the client to mqtt event loop. Request are diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c..b1786e14 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -3,8 +3,8 @@ use std::time::Duration; use super::mqttbytes::v5::{ - Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, - Unsubscribe, UnsubscribeProperties, + Auth, AuthProperties, AuthReasonCode, Filter, PubAck, PubRec, Publish, PublishProperties, + Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; use super::mqttbytes::{valid_filter, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; @@ -56,7 +56,6 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_tx = eventloop.requests_tx.clone(); - let client = AsyncClient { request_tx }; (client, eventloop) @@ -196,6 +195,22 @@ impl AsyncClient { Ok(()) } + /// Sends a MQTT AUTH to `EventLoop` for authentication. + pub async fn reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::ReAuthenticate, properties); + let auth = Request::Auth(auth); + self.request_tx.send_async(auth).await?; + Ok(()) + } + + /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. + pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::ReAuthenticate, properties); + let auth = Request::Auth(auth); + self.request_tx.try_send(auth)?; + Ok(()) + } + /// Sends a MQTT Publish to the `EventLoop` async fn handle_publish_bytes( &self, @@ -600,6 +615,22 @@ impl Client { Ok(()) } + /// Sends a MQTT AUTH to `EventLoop` for authentication. + pub fn reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::ReAuthenticate, properties); + let auth = Request::Auth(auth); + self.client.request_tx.send(auth)?; + Ok(()) + } + + /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. + pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::ReAuthenticate, properties); + let auth = Request::Auth(auth); + self.client.request_tx.try_send(auth)?; + Ok(()) + } + /// Sends a MQTT Subscribe to the `EventLoop` fn handle_subscribe>( &self, @@ -896,4 +927,21 @@ mod test { .expect("Should be able to publish"); let _ = rx.try_recv().expect("Should have message"); } + + #[test] + fn test_reauth() { + let (client, mut connection) = + Client::new(MqttOptions::new("test-1", "localhost", 1883), 10); + let props = AuthProperties { + method: Some("test".to_string()), + data: Some(Bytes::from("test")), + reason: None, + user_properties: vec![], + }; + let _ = client.reauth(Some(props.clone())).expect("Should be able to reauth"); + let _ = connection.iter().next().expect("Should have event"); + + let _ = client.try_reauth(Some(props.clone())).expect("Should be able to reauth"); + let _ = connection.iter().next().expect("Should have event"); + } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index cd0568ad..a3cf2379 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -55,6 +55,8 @@ pub enum ConnectionError { NotConnAck(Box), #[error("Requests done")] RequestsDone, + #[error("Auth processing error")] + AuthProcessingError, #[cfg(feature = "websocket")] #[error("Invalid Url: {0}")] InvalidUrl(#[from] UrlError), @@ -102,9 +104,11 @@ impl EventLoop { let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; + let auth_manager = options.auth_manager(); + EventLoop { options, - state: MqttState::new(inflight_limit, manual_acks), + state: MqttState::new(inflight_limit, manual_acks, auth_manager), requests_tx, requests_rx, pending, @@ -146,7 +150,7 @@ impl EventLoop { if self.network.is_none() { let (network, connack) = time::timeout( Duration::from_secs(self.options.connection_timeout()), - connect(&mut self.options), + connect(&mut self.options, &mut self.state), ) .await??; // Last session might contain packets which aren't acked. If it's a new session, clear the pending packets. @@ -278,12 +282,15 @@ impl EventLoop { /// the stream. /// This function (for convenience) includes internal delays for users to perform internal sleeps /// between re-connections so that cancel semantics can be used during this sleep -async fn connect(options: &mut MqttOptions) -> Result<(Network, ConnAck), ConnectionError> { +async fn connect( + options: &mut MqttOptions, + state: &mut MqttState, +) -> Result<(Network, ConnAck), ConnectionError> { // connect to the broker let mut network = network_connect(options).await?; // make MQTT connection request (which internally awaits for ack) - let connack = mqtt_connect(options, &mut network).await?; + let connack = mqtt_connect(options, &mut network, state).await?; Ok((network, connack)) } @@ -397,6 +404,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result { let packet = Packet::Connect( Connect { @@ -414,22 +422,34 @@ async fn mqtt_connect( network.flush().await?; // validate connack - match network.read().await? { - Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - if let Some(props) = &connack.properties { - if let Some(keep_alive) = props.server_keep_alive { - options.keep_alive = Duration::from_secs(keep_alive as u64); - } - network.set_max_outgoing_size(props.max_packet_size); + loop { + match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + if let Some(props) = &connack.properties { + if let Some(keep_alive) = props.server_keep_alive { + options.keep_alive = Duration::from_secs(keep_alive as u64); + } + network.set_max_outgoing_size(props.max_packet_size); - // Override local session_expiry_interval value if set by server. - if props.session_expiry_interval.is_some() { - options.set_session_expiry_interval(props.session_expiry_interval); + // Override local session_expiry_interval value if set by server. + if props.session_expiry_interval.is_some() { + options.set_session_expiry_interval(props.session_expiry_interval); + } + } + return Ok(connack); + } + Incoming::ConnAck(connack) => { + return Err(ConnectionError::ConnectionRefused(connack.code)) + } + Incoming::Auth(auth) => { + if let Some(outgoing) = state.handle_incoming_packet(Incoming::Auth(auth))? { + network.write(outgoing).await?; + network.flush().await?; + } else { + return Err(ConnectionError::AuthProcessingError); } } - Ok(connack) + packet => return Err(ConnectionError::NotConnAck(Box::new(packet))), } - Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), - packet => Err(ConnectionError::NotConnAck(Box::new(packet))), } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2518a93f..d6280def 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,11 +1,12 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; +use std::sync::{Arc, Mutex}; use std::time::Duration; + #[cfg(feature = "websocket")] use std::{ future::{Future, IntoFuture}, pin::Pin, - sync::Arc, }; mod client; @@ -31,6 +32,24 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +pub trait AuthManager: std::fmt::Debug + Send { + /// Process authentication data received from the server and generate authentication data to be sent back. + /// + /// # Arguments + /// + /// * `auth_prop` - The authentication Properties received from the server. + /// + /// # Returns + /// + /// * `Ok(auth_prop)` - The authentication Properties to be sent back to the server. + /// * `Err(error_message)` - An error indicating that the authentication process has failed or terminated. + + fn auth_continue( + &mut self, + auth_prop: Option, + ) -> Result, String>; +} + /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Clone, Debug, PartialEq, Eq)] @@ -47,6 +66,7 @@ pub enum Request { Unsubscribe(Unsubscribe), UnsubAck(UnsubAck), Disconnect, + Auth(Auth), } #[cfg(feature = "websocket")] @@ -104,6 +124,8 @@ pub struct MqttOptions { outgoing_inflight_upper_limit: Option, #[cfg(feature = "websocket")] request_modifier: Option, + + auth_manager: Option>>, } impl MqttOptions { @@ -139,6 +161,7 @@ impl MqttOptions { outgoing_inflight_upper_limit: None, #[cfg(feature = "websocket")] request_modifier: None, + auth_manager: None, } } @@ -547,6 +570,17 @@ impl MqttOptions { pub fn get_outgoing_inflight_upper_limit(&self) -> Option { self.outgoing_inflight_upper_limit } + + pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { + self.auth_manager = Some(auth_manager); + self + } + + pub fn auth_manager(&self) -> Option>> { + self.auth_manager.as_ref()?; + + self.auth_manager.clone() + } } #[cfg(feature = "url")] diff --git a/rumqttc/src/v5/mqttbytes/v5/auth.rs b/rumqttc/src/v5/mqttbytes/v5/auth.rs index 52ff6284..978e6212 100644 --- a/rumqttc/src/v5/mqttbytes/v5/auth.rs +++ b/rumqttc/src/v5/mqttbytes/v5/auth.rs @@ -10,7 +10,7 @@ use super::{ pub enum AuthReasonCode { Success, Continue, - ReAuthentivate, + ReAuthenticate, } impl AuthReasonCode { @@ -19,7 +19,7 @@ impl AuthReasonCode { let code = match reason_code { 0x00 => AuthReasonCode::Success, 0x18 => AuthReasonCode::Continue, - 0x19 => AuthReasonCode::ReAuthentivate, + 0x19 => AuthReasonCode::ReAuthenticate, _ => return Err(Error::MalformedPacket), }; @@ -30,7 +30,7 @@ impl AuthReasonCode { let reason_code = match self { AuthReasonCode::Success => 0x00, AuthReasonCode::Continue => 0x18, - AuthReasonCode::ReAuthentivate => 0x19, + AuthReasonCode::ReAuthenticate => 0x19, }; buffer.put_u8(reason_code); @@ -47,12 +47,20 @@ pub struct Auth { } impl Auth { + pub fn new(code: AuthReasonCode, properties: Option) -> Self { + Self { code, properties } + } + fn len(&self) -> usize { - let mut len = 1 // reason code - + 1; // property len + let mut len = 1; - if let Some(properties) = &self.properties { - len += properties.len(); + if let Some(p) = &self.properties { + let properties_len = p.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; } len @@ -107,22 +115,22 @@ impl AuthProperties { if let Some(method) = &self.method { let m_len = method.len(); - len += 1 + m_len; + len += 1 + 2 + m_len; } if let Some(data) = &self.data { let d_len = data.len(); - len += 1 + len_len(d_len) + d_len; + len += 1 + 2 + d_len; } if let Some(reason) = &self.reason { let r_len = reason.len(); - len += 1 + r_len; + len += 1 + 2 + r_len; } for (key, value) in self.user_properties.iter() { let p_len = key.len() + value.len(); - len += 1 + p_len; + len += 1 + 4 + p_len; } len @@ -146,7 +154,7 @@ impl AuthProperties { match property(prop)? { PropertyType::AuthenticationMethod => { let method = read_mqtt_string(bytes)?; - cursor += method.len(); + cursor += 2 + method.len(); props.method = Some(method); } PropertyType::AuthenticationData => { @@ -156,7 +164,7 @@ impl AuthProperties { } PropertyType::ReasonString => { let reason = read_mqtt_string(bytes)?; - cursor += reason.len(); + cursor += 2 + reason.len(); props.reason = Some(reason); } PropertyType::UserProperty => { @@ -200,3 +208,33 @@ impl AuthProperties { Ok(()) } } + +#[cfg(test)] +mod test { + use super::super::test::{USER_PROP_KEY, USER_PROP_VAL}; + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + #[test] + fn length_calculation() { + let mut dummy_bytes = BytesMut::new(); + // Use user_properties to pad the size to exceed ~128 bytes to make the + // remaining_length field in the packet be 2 bytes long. + let auth_props = AuthProperties { + method: Some("Authentication Method".into()), + data: Some("Authentication Data".into()), + reason: None, + user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())], + }; + + let auth_pkt = Auth::new(AuthReasonCode::Continue, Some(auth_props)); + + let size_from_size = auth_pkt.size(); + let size_from_write = auth_pkt.write(&mut dummy_bytes).unwrap(); + let size_from_bytes = dummy_bytes.len(); + + assert_eq!(size_from_write, size_from_bytes); + assert_eq!(size_from_size, size_from_bytes); + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 3256721d..48b33dd5 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,7 +1,7 @@ use std::slice::Iter; pub use self::{ - auth::Auth, + auth::{Auth, AuthProperties, AuthReasonCode}, codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, @@ -126,6 +126,10 @@ impl Packet { let disconnect = Disconnect::read(fixed_header, packet)?; Packet::Disconnect(disconnect) } + PacketType::Auth => { + let auth = Auth::read(fixed_header, packet)?; + Packet::Auth(auth) + } }; Ok(packet) @@ -199,6 +203,7 @@ pub enum PacketType { PingReq, PingResp, Disconnect, + Auth, } #[repr(u8)] @@ -285,6 +290,7 @@ impl FixedHeader { 12 => Ok(PacketType::PingReq), 13 => Ok(PacketType::PingResp), 14 => Ok(PacketType::Disconnect), + 15 => Ok(PacketType::Auth), _ => Err(Error::InvalidPacketType(num)), } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 6f37a471..22b7d8e5 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,15 +1,17 @@ use super::mqttbytes::v5::{ - ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, - PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, + Auth, AuthReasonCode, ConnAck, ConnectReturnCode, Disconnect, + DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, + PubRecReason, PubRel, PubRelReason, Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, + UnsubAckReason, Unsubscribe, }; use super::mqttbytes::{self, Error as MqttError, QoS}; -use super::{Event, Incoming, Outgoing, Request}; +use super::{AuthManager, Event, Incoming, Outgoing, Request}; use bytes::Bytes; use fixedbitset::FixedBitSet; use std::collections::{HashMap, VecDeque}; +use std::sync::{Arc, Mutex}; use std::{io, time::Instant}; /// Errors during state handling @@ -67,6 +69,10 @@ pub enum StateError { ConnFail { reason: ConnectReturnCode }, #[error("Connection closed by peer abruptly")] ConnectionAborted, + #[error("Authentication error: {0}")] + AuthError(String), + #[error("Auth Manager not set")] + AuthManagerNotSet, } impl From for StateError { @@ -122,13 +128,19 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, + /// Authentication manager + auth_manager: Option>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + pub fn new( + max_inflight: u16, + manual_acks: bool, + auth_manager: Option>>, + ) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -149,6 +161,7 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, + auth_manager, } } @@ -200,6 +213,7 @@ impl MqttState { } Request::PubAck(puback) => self.outgoing_puback(puback)?, Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + Request::Auth(auth) => self.outgoing_auth(auth)?, _ => unimplemented!(), }; @@ -228,6 +242,7 @@ impl MqttState { Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, + Incoming::Auth(auth) => self.handle_incoming_auth(auth)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); @@ -471,6 +486,37 @@ impl MqttState { Ok(None) } + fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { + match auth.code { + AuthReasonCode::Success => Ok(None), + AuthReasonCode::Continue => { + let props = auth.properties.clone(); + + // Check if auth manager is set + if self.auth_manager.is_none() { + return Err(StateError::AuthManagerNotSet); + } + + let auth_manager = self.auth_manager.clone().unwrap(); + + // Call auth_continue method of auth manager + let out_auth_props = match auth_manager + .lock() + .unwrap() + .auth_continue(props) + { + Ok(data) => data, + Err(err) => return Err(StateError::AuthError(err)), + }; + + let client_auth = Auth::new(AuthReasonCode::Continue, out_auth_props); + + self.outgoing_auth(client_auth) + } + _ => Err(StateError::AuthError("Authentication Failed!".to_string())), + } + } + /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { @@ -639,6 +685,17 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } + fn outgoing_auth(&mut self, auth: Auth) -> Result, StateError> { + let props = auth.properties.as_ref().unwrap(); + debug!( + "Auth packet sent. Auth Method: {:?}. Auth Data: {:?}", + props.method, props.data + ); + let event = Event::Outgoing(Outgoing::Auth); + self.events.push_back(event); + Ok(Some(Packet::Auth(auth))) + } + fn check_collision(&mut self, pkid: u16) -> Option { if let Some(publish) = &self.collision { if publish.pkid == pkid { @@ -711,7 +768,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(u16::MAX, false) + MqttState::new(u16::MAX, false, None) } #[test] @@ -772,7 +829,7 @@ mod test { #[test] fn outgoing_publish_with_max_inflight_is_ok() { - let mut mqtt = MqttState::new(2, false); + let mut mqtt = MqttState::new(2, false, None); // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce);