diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index eca367f..1ea778c 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -8,7 +8,7 @@ jobs: fail-fast: false matrix: version: - - 1.66.0 # MSRV + # - 1.66.0 # MSRV - stable - nightly diff --git a/CHANGES.md b/CHANGES.md index 004ed4e..ffd6f5b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.12.0] - 2023-09-18 + +* Refactor MqttError type + ## [0.11.4] - 2023-08-10 * Update ntex deps diff --git a/Cargo.toml b/Cargo.toml index f2ddb73..ed87d3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "0.11.4" +version = "0.12.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -9,12 +9,12 @@ categories = ["network-programming"] keywords = ["MQTT", "IoT", "messaging"] license = "MIT" exclude = [".gitignore", ".travis.yml", ".cargo/config"] -edition = "2018" +edition = "2021" [dependencies] -ntex = "0.7.3" -ntex-util = "0.3.1" -bitflags = "1.3" +ntex = "0.7.4" +ntex-util = "0.3.2" +bitflags = "2.4" log = "0.4" pin-project-lite = "0.2" serde = { version = "1.0", features = ["derive"] } @@ -23,15 +23,9 @@ thiserror = "1.0" [dev-dependencies] env_logger = "0.10" -ntex-tls = "0.3.0" +ntex-tls = "0.3.1" rustls = "0.21" rustls-pemfile = "1.0" openssl = "0.10" -ntex = { version = "0.7.3", features = ["tokio", "rustls", "openssl"] } -test-case = "3" - -[profile.dev] -lto = "off" # cannot build tests with "thin" - -[profile.test] -lto = "off" # cannot build tests with "thin" +test-case = "3.2" +ntex = { version = "0.7.4", features = ["tokio", "rustls", "openssl"] } diff --git a/examples/mqtt-ws-server.rs b/examples/mqtt-ws-server.rs index d61479a..3bf564c 100644 --- a/examples/mqtt-ws-server.rs +++ b/examples/mqtt-ws-server.rs @@ -6,7 +6,7 @@ use ntex::io::{Filter, Io}; use ntex::service::{chain_factory, ServiceFactory}; use ntex::util::{variant, Ready}; use ntex::ws; -use ntex_mqtt::{v3, v5, MqttError, MqttServer}; +use ntex_mqtt::{v3, v5, HandshakeError, MqttError, MqttServer}; use ntex_tls::openssl::Acceptor; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; @@ -101,9 +101,9 @@ async fn main() -> std::io::Result<()> { return match result { Some(Protocol::Mqtt) => Ok(variant::Variant2::V1(io)), Some(Protocol::Http) => Ok(variant::Variant2::V2(io)), - Some(Protocol::Unknown) => { - Err(MqttError::ServerError("Unsupported protocol")) - } + Some(Protocol::Unknown) => Err(MqttError::Handshake( + HandshakeError::Server("Unsupported protocol"), + )), None => { // need to read more data io.read_ready().await?; @@ -139,8 +139,10 @@ async fn main() -> std::io::Result<()> { &codec, ) .await?; - return Err(MqttError::ServerError( - "WebSockets handshake error", + return Err(MqttError::Handshake( + HandshakeError::Server( + "WebSockets handshake error", + ), )); } Ok(mut res) => { @@ -176,7 +178,9 @@ async fn main() -> std::io::Result<()> { // adapt service error to mqtt error .map_err(|e| { log::info!("Http server error: {:?}", e); - MqttError::ServerError("Http server error") + MqttError::Handshake(HandshakeError::Server( + "Http server error", + )) })), ) })? diff --git a/src/error.rs b/src/error.rs index 966726e..7b427c6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,18 +10,29 @@ pub enum MqttError { /// Publish handler service error #[error("Service error")] Service(E), + /// Handshake error + #[error("Mqtt handshake error: {}", _0)] + Handshake(#[from] HandshakeError), +} + +/// Errors which can occur during mqtt connection handshake. +#[derive(Debug, thiserror::Error)] +pub enum HandshakeError { + /// Handshake service error + #[error("Handshake service error")] + Service(E), /// Protocol error #[error("Mqtt protocol error: {}", _0)] Protocol(#[from] ProtocolError), /// Handshake timeout #[error("Handshake timeout")] - HandshakeTimeout, + Timeout, /// Peer disconnect #[error("Peer is disconnected, error: {:?}", _0)] Disconnected(Option), /// Server error #[error("Server error: {}", _0)] - ServerError(&'static str), + Server(&'static str), } /// Protocol level errors @@ -54,6 +65,7 @@ enum ViolationInner { #[error("{message}; received packet with type `{packet_type:b}`")] UnexpectedPacket { packet_type: u8, message: &'static str }, } + impl ProtocolViolationError { pub(crate) fn reason(&self) -> DisconnectReasonCode { match self.inner { @@ -87,21 +99,21 @@ impl ProtocolError { impl From for MqttError { fn from(err: io::Error) -> Self { - MqttError::Disconnected(Some(err)) + MqttError::Handshake(HandshakeError::Disconnected(Some(err))) } } impl From> for MqttError { fn from(err: Either) -> Self { - MqttError::Disconnected(Some(err.into_inner())) + MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner()))) } } -impl From> for MqttError { +impl From> for HandshakeError { fn from(err: Either) -> Self { match err { - Either::Left(err) => MqttError::Protocol(ProtocolError::Decode(err)), - Either::Right(err) => MqttError::Disconnected(Some(err)), + Either::Left(err) => HandshakeError::Protocol(ProtocolError::Decode(err)), + Either::Right(err) => HandshakeError::Disconnected(Some(err)), } } } @@ -109,8 +121,10 @@ impl From> for MqttError { impl From> for MqttError { fn from(err: Either) -> Self { match err { - Either::Left(err) => MqttError::Protocol(ProtocolError::Encode(err)), - Either::Right(err) => MqttError::Disconnected(Some(err)), + Either::Left(err) => { + MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(err))) + } + Either::Right(err) => MqttError::Handshake(HandshakeError::Disconnected(Some(err))), } } } diff --git a/src/inflight.rs b/src/inflight.rs index 727104a..2965428 100644 --- a/src/inflight.rs +++ b/src/inflight.rs @@ -14,7 +14,7 @@ pub(crate) struct InFlightService { } impl InFlightService { - pub fn new(max_cap: u16, max_size: usize, service: S) -> Self { + pub(crate) fn new(max_cap: u16, max_size: usize, service: S) -> Self { Self { service, count: Counter::new(max_cap, max_size) } } } @@ -101,7 +101,7 @@ impl Counter { CounterGuard::new(size, self.0.clone()) } - fn available(&self, cx: &mut Context<'_>) -> bool { + fn available(&self, cx: &Context<'_>) -> bool { self.0.available(cx) } } @@ -142,7 +142,7 @@ impl CounterInner { } } - fn available(&self, cx: &mut Context<'_>) -> bool { + fn available(&self, cx: &Context<'_>) -> bool { if (self.max_cap == 0 || self.cur_cap.get() < self.max_cap) && (self.max_size == 0 || self.cur_size.get() <= self.max_size) { diff --git a/src/lib.rs b/src/lib.rs index 0e35682..9139571 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(rust_2018_idioms)] +#![deny(rust_2018_idioms, warnings, unreachable_pub)] #![allow(clippy::type_complexity)] //! MQTT Client/Server framework @@ -19,7 +19,7 @@ mod session; mod types; mod version; -pub use self::error::MqttError; +pub use self::error::{HandshakeError, MqttError}; pub use self::server::MqttServer; pub use self::session::Session; pub use self::topic::{TopicFilter, TopicFilterError, TopicFilterLevel}; diff --git a/src/server.rs b/src/server.rs index be047a0..c7d333c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,7 +7,7 @@ use ntex::time::{Deadline, Millis, Seconds}; use ntex::util::{join, ready, BoxFuture, Ready}; use crate::version::{ProtocolVersion, VersionCodec}; -use crate::{error::MqttError, v3, v5}; +use crate::{error::HandshakeError, error::MqttError, v3, v5}; /// Mqtt Server pub struct MqttServer { @@ -437,7 +437,11 @@ where MqttServerImplStateProject::Version { ref mut item } => { match item.as_mut().unwrap().2.poll_elapsed(cx) { Poll::Pending => (), - Poll::Ready(_) => return Poll::Ready(Err(MqttError::HandshakeTimeout)), + Poll::Ready(_) => { + return Poll::Ready(Err(MqttError::Handshake( + HandshakeError::Timeout, + ))) + } } let st = item.as_mut().unwrap(); @@ -458,16 +462,17 @@ where unreachable!() } Err(RecvError::WriteBackpressure) => { - ready!(st.0.poll_flush(cx, false)) - .map_err(|e| MqttError::Disconnected(Some(e)))?; + ready!(st.0.poll_flush(cx, false)).map_err(|e| { + MqttError::Handshake(HandshakeError::Disconnected(Some(e))) + })?; continue; } - Err(RecvError::Decoder(err)) => { - Poll::Ready(Err(MqttError::Protocol(err.into()))) - } - Err(RecvError::PeerGone(err)) => { - Poll::Ready(Err(MqttError::Disconnected(err))) - } + Err(RecvError::Decoder(err)) => Poll::Ready(Err(MqttError::Handshake( + HandshakeError::Protocol(err.into()), + ))), + Err(RecvError::PeerGone(err)) => Poll::Ready(Err( + MqttError::Handshake(HandshakeError::Disconnected(err)), + )), }; } } @@ -504,9 +509,9 @@ impl Service<(IoBoxed, Deadline)> for DefaultProtocolServer = Ready where Self: 'f; fn call<'a>(&'a self, _: (IoBoxed, Deadline), _: ServiceCtx<'a, Self>) -> Self::Future<'a> { - Ready::Err(MqttError::Disconnected(Some(io::Error::new( + Ready::Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new( io::ErrorKind::Other, format!("Protocol is not supported: {:?}", self.ver), - )))) + ))))) } } diff --git a/src/topic.rs b/src/topic.rs index 9871b6b..7a0fc14 100644 --- a/src/topic.rs +++ b/src/topic.rs @@ -339,7 +339,7 @@ mod tests { is_valid(topic_filter) } - pub fn lvl_normal>(s: T) -> TopicFilterLevel { + fn lvl_normal>(s: T) -> TopicFilterLevel { if s.as_ref().contains(|c| c == '+' || c == '#') { panic!("invalid normal level `{}` contains +|#", s.as_ref()); } @@ -347,7 +347,7 @@ mod tests { TopicFilterLevel::Normal(s.as_ref().into()) } - pub fn lvl_sys>(s: T) -> TopicFilterLevel { + fn lvl_sys>(s: T) -> TopicFilterLevel { if s.as_ref().contains(|c| c == '+' || c == '#') { panic!("invalid normal level `{}` contains +|#", s.as_ref()); } @@ -359,7 +359,7 @@ mod tests { TopicFilterLevel::System(s.as_ref().into()) } - pub fn topic(topic: &'static str) -> TopicFilter { + fn topic(topic: &'static str) -> TopicFilter { TopicFilter::try_from(ByteString::from_static(topic)).unwrap() } diff --git a/src/types.rs b/src/types.rs index 9b67762..6dbb0f3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,10 +1,10 @@ -pub const MQTT: &[u8] = b"MQTT"; -pub const MQTT_LEVEL_3: u8 = 4; -pub const MQTT_LEVEL_5: u8 = 5; -pub const WILL_QOS_SHIFT: u8 = 3; +pub(crate) const MQTT: &[u8] = b"MQTT"; +pub(crate) const MQTT_LEVEL_3: u8 = 4; +pub(crate) const MQTT_LEVEL_5: u8 = 5; +pub(crate) const WILL_QOS_SHIFT: u8 = 3; /// Max possible packet size -pub const MAX_PACKET_SIZE: u32 = 0xF_FF_FF_FF; +pub(crate) const MAX_PACKET_SIZE: u32 = 0xF_FF_FF_FF; prim_enum! { /// Quality of Service @@ -32,6 +32,7 @@ prim_enum! { } bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ConnectFlags: u8 { const USERNAME = 0b1000_0000; const PASSWORD = 0b0100_0000; @@ -43,6 +44,7 @@ bitflags::bitflags! { } bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ConnectAckFlags: u8 { const SESSION_PRESENT = 0b0000_0001; } diff --git a/src/v3/client/connector.rs b/src/v3/client/connector.rs index a734657..481e261 100644 --- a/src/v3/client/connector.rs +++ b/src/v3/client/connector.rs @@ -205,11 +205,7 @@ where } async fn _connect(&self) -> Result> { - let io: IoBoxed = self - .connector - .call(Connect::new(self.address.clone())) - .await? - .into(); + let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into(); let pkt = self.pkt.clone(); let max_send = self.max_send; let max_receive = self.max_receive; diff --git a/src/v3/client/dispatcher.rs b/src/v3/client/dispatcher.rs index 911da23..2236278 100644 --- a/src/v3/client/dispatcher.rs +++ b/src/v3/client/dispatcher.rs @@ -6,9 +6,9 @@ use ntex::io::DispatchItem; use ntex::service::{Pipeline, Service, ServiceCall, ServiceCtx}; use ntex::util::{inflight::InFlightService, BoxFuture, Either, HashSet, Ready}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::v3::shared::{Ack, MqttShared}; use crate::v3::{codec, control::ControlResultKind, publish::Publish}; -use crate::{error::MqttError, error::ProtocolError}; use super::control::{ControlMessage, ControlResult}; @@ -90,8 +90,7 @@ where self.inner.sink.close(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = - Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; })); } @@ -120,9 +119,9 @@ where if let Some(pid) = packet_id { if !inner.inflight.borrow_mut().insert(pid) { log::trace!("Duplicated packet id for publish packet: {:?}", pid); - return Either::Right(Either::Left(Ready::Err( - MqttError::ServerError("Duplicated packet id for publish packet"), - ))); + return Either::Right(Either::Left(Ready::Err(MqttError::Handshake( + HandshakeError::Server("Duplicated packet id for publish packet"), + )))); } } Either::Left(PublishResponse { @@ -135,21 +134,27 @@ where } DispatchItem::Item((codec::Packet::PublishAck { packet_id }, _)) => { if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) { - Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) + Either::Right(Either::Left(Ready::Err(MqttError::Handshake( + HandshakeError::Protocol(e), + )))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } DispatchItem::Item((codec::Packet::SubscribeAck { packet_id, status }, _)) => { if let Err(e) = self.inner.sink.pkt_ack(Ack::Subscribe { packet_id, status }) { - Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) + Either::Right(Either::Left(Ready::Err(MqttError::Handshake( + HandshakeError::Protocol(e), + )))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } DispatchItem::Item((codec::Packet::UnsubscribeAck { packet_id }, _)) => { if let Err(e) = self.inner.sink.pkt_ack(Ack::Unsubscribe(packet_id)) { - Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) + Either::Right(Either::Left(Ready::Err(MqttError::Handshake( + HandshakeError::Protocol(e), + )))) } else { Either::Right(Either::Left(Ready::Ok(None))) } @@ -161,10 +166,10 @@ where | codec::Packet::Unsubscribe { .. }), _, )) => Either::Right(Either::Left(Ready::Err( - ProtocolError::unexpected_packet( + HandshakeError::Protocol(ProtocolError::unexpected_packet( pkt.packet_type(), "Packet of the type is not expected from server", - ) + )) .into(), ))), DispatchItem::Item((pkt, _)) => { @@ -377,7 +382,7 @@ mod tests { )))); let err = f.await.err().unwrap(); match err { - MqttError::ServerError(msg) => { + MqttError::Handshake(HandshakeError::Server(msg)) => { assert!(msg == "Duplicated packet id for publish packet") } _ => panic!(), diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index b723128..7df1e13 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -7,7 +7,7 @@ use ntex::service::{self, Pipeline, Service, ServiceCall, ServiceCtx, ServiceFac use ntex::util::buffer::{BufferService, BufferServiceError}; use ntex::util::{inflight::InFlightService, join, BoxFuture, Either, HashSet, Ready}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::QoS; use super::control::{ @@ -46,8 +46,10 @@ where let fut = join(factories.0.create(session.clone()), factories.1.create(session)); let (publish, control) = fut.await; - let publish = publish.map_err(|e| MqttError::Service(e.into()))?; - let control = control.map_err(|e| MqttError::Service(e.into()))?; + let publish = + publish.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?; + let control = + control.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?; let control = BufferService::new( 16, @@ -55,10 +57,12 @@ where InFlightService::new(1, control), ) .map_err(|err| match err { - BufferServiceError::Service(e) => MqttError::Service(E::from(e)), - BufferServiceError::RequestCanceled => { - MqttError::ServerError("Request handling has been canceled") + BufferServiceError::Service(e) => { + MqttError::Handshake(HandshakeError::Service(E::from(e))) } + BufferServiceError::RequestCanceled => MqttError::Handshake( + HandshakeError::Server("Request handling has been canceled"), + ), }); Ok( @@ -145,8 +149,7 @@ where self.inner.sink.close(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = - Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; })); } @@ -256,10 +259,14 @@ where } if !self.inner.inflight.borrow_mut().insert(packet_id) { - log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id); - return Either::Right(Either::Left(Ready::Err(MqttError::ServerError( - "Duplicated packet id for unsubscribe packet", - )))); + log::trace!("Duplicated packet id for subscribe packet: {:?}", packet_id); + return Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::generic_violation( + "Duplicated packet id for subscribe packet", + )), + &self.inner, + ctx, + ))); } Either::Right(Either::Right(ControlResponse::new( @@ -284,9 +291,13 @@ where if !self.inner.inflight.borrow_mut().insert(packet_id) { log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id); - return Either::Right(Either::Left(Ready::Err(MqttError::ServerError( - "Duplicated packet id for unsubscribe packet", - )))); + return Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::generic_violation( + "Duplicated packet id for unsubscribe packet", + )), + &self.inner, + ctx, + ))); } Either::Right(Either::Right(ControlResponse::new( diff --git a/src/v3/publish.rs b/src/v3/publish.rs index f9cebe2..6cb3649 100644 --- a/src/v3/publish.rs +++ b/src/v3/publish.rs @@ -7,6 +7,7 @@ use serde_json::Error as JsonError; use crate::v3::codec; +#[derive(Clone)] /// Publish message pub struct Publish { pkt: codec::Publish, @@ -14,10 +15,6 @@ pub struct Publish { topic: Path, } -#[derive(Debug)] -/// Publish ack -pub struct PublishAck; - impl Publish { /// Create a new `Publish` message from a PUBLISH /// packet diff --git a/src/v3/selector.rs b/src/v3/selector.rs index e9207f1..8c33feb 100644 --- a/src/v3/selector.rs +++ b/src/v3/selector.rs @@ -5,7 +5,7 @@ use ntex::service::{boxed, Service, ServiceCtx, ServiceFactory}; use ntex::time::{Deadline, Millis, Seconds}; use ntex::util::{select, BoxFuture, Either}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use super::control::{ControlMessage, ControlResult}; use super::handshake::{Handshake, HandshakeAck}; @@ -246,17 +246,17 @@ where .await .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) }) }) .await; let (packet, size) = match result { - Either::Left(_) => Err(MqttError::HandshakeTimeout), + Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), Either::Right(item) => item, }?; @@ -264,9 +264,11 @@ where mqtt::Packet::Connect(connect) => connect, packet => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "Expected CONNECT packet [MQTT-3.1.0-1]", + return Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "Expected CONNECT packet [MQTT-3.1.0-1]", + ), ))); } }; @@ -282,7 +284,7 @@ where } } log::error!("Cannot handle CONNECT packet {:?}", item.0); - Err(MqttError::ServerError("Cannot handle CONNECT packet")) + Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } } @@ -323,17 +325,17 @@ where .await .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) }) }) .await; let (packet, size) = match result { - Either::Left(_) => Err(MqttError::HandshakeTimeout), + Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), Either::Right(item) => item, }?; @@ -341,9 +343,11 @@ where mqtt::Packet::Connect(connect) => connect, packet => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "MQTT-3.1.0-1: Expected CONNECT packet", + return Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "MQTT-3.1.0-1: Expected CONNECT packet", + ), ))); } }; @@ -359,7 +363,7 @@ where } } log::error!("Cannot handle CONNECT packet {:?}", item.0.packet()); - Err(MqttError::ServerError("Cannot handle CONNECT packet")) + Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } } diff --git a/src/v3/server.rs b/src/v3/server.rs index 483c36e..c142271 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -5,7 +5,7 @@ use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; use ntex::util::{select, BoxFuture, Either}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::{io::Dispatcher, service, types::QoS}; use super::control::{ControlMessage, ControlResult}; @@ -336,11 +336,11 @@ where .await .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) })?; match packet { @@ -379,15 +379,17 @@ where ack.io.send(pkt, &ack.shared.codec).await?; let _ = ack.io.shutdown().await; - Err(MqttError::Disconnected(None)) + Err(MqttError::Handshake(HandshakeError::Disconnected(None))) } } } (packet, _) => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "MQTT-3.1.0-1: Expected CONNECT packet", + Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "MQTT-3.1.0-1: Expected CONNECT packet", + ), ))) } } @@ -398,7 +400,7 @@ where if let Ok(val) = timeout_checked(handshake_timeout, f).await { val } else { - Err(MqttError::HandshakeTimeout) + Err(MqttError::Handshake(HandshakeError::Timeout)) } }) } @@ -488,19 +490,21 @@ where let result = match select((*self.check)(&hnd), &mut delay).await { Either::Left(res) => res, - Either::Right(_) => return Err(MqttError::HandshakeTimeout), + Either::Right(_) => return Err(MqttError::Handshake(HandshakeError::Timeout)), }; - if !result.map_err(MqttError::Service)? { + if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { Ok(Either::Left((hnd, delay))) } else { // authenticate mqtt connection let ack = match select(ctx.call(&self.handshake, hnd), delay).await { Either::Left(res) => res.map_err(|e| { log::trace!("Connection handshake failed: {:?}", e); - MqttError::Service(e) + MqttError::Handshake(HandshakeError::Service(e)) })?, - Either::Right(_) => return Err(MqttError::HandshakeTimeout), + Either::Right(_) => { + return Err(MqttError::Handshake(HandshakeError::Timeout)) + } }; match ack.session { @@ -538,7 +542,7 @@ where ack.io.send(pkt, &ack.shared.codec).await?; let _ = ack.io.shutdown().await; - Err(MqttError::Disconnected(None)) + Err(MqttError::Handshake(HandshakeError::Disconnected(None))) } } } diff --git a/src/v3/shared.rs b/src/v3/shared.rs index 5d3d1ad..fed9aed 100644 --- a/src/v3/shared.rs +++ b/src/v3/shared.rs @@ -38,6 +38,7 @@ impl Default for MqttSinkPool { } bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct Flags: u8 { const CLIENT = 0b1000_0000; const WRB_ENABLED = 0b0100_0000; // write-backpressure diff --git a/src/v5/client/connector.rs b/src/v5/client/connector.rs index f7ca0f1..9b16d3a 100644 --- a/src/v5/client/connector.rs +++ b/src/v5/client/connector.rs @@ -209,11 +209,7 @@ where } async fn _connect(&self) -> Result>> { - let io: IoBoxed = self - .connector - .call(Connect::new(self.address.clone())) - .await? - .into(); + let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into(); let pkt = self.pkt.clone(); let keep_alive = pkt.keep_alive; let max_packet_size = pkt.max_packet_size.map(|v| v.get()).unwrap_or(0); diff --git a/src/v5/client/dispatcher.rs b/src/v5/client/dispatcher.rs index 608254c..5adccac 100644 --- a/src/v5/client/dispatcher.rs +++ b/src/v5/client/dispatcher.rs @@ -6,7 +6,7 @@ use ntex::io::DispatchItem; use ntex::service::{Pipeline, Service, ServiceCall, ServiceCtx}; use ntex::util::{BoxFuture, ByteString, Either, HashMap, HashSet, Ready}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::packet_type; use crate::v5::codec::DisconnectReasonCode; use crate::v5::shared::{Ack, MqttShared}; @@ -116,8 +116,7 @@ where self.inner.sink.drop_sink(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = - Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; })); } @@ -296,10 +295,10 @@ where | codec::Packet::Unsubscribe(_)), _, )) => Either::Right(Either::Left(Ready::Err( - ProtocolError::unexpected_packet( + HandshakeError::Protocol(ProtocolError::unexpected_packet( pkt.packet_type(), "Packet of the type is not expected from server", - ) + )) .into(), ))), DispatchItem::Item((codec::Packet::PingResponse, _)) => { diff --git a/src/v5/codec/codec.rs b/src/v5/codec/codec.rs index fb4b073..ca4f195 100644 --- a/src/v5/codec/codec.rs +++ b/src/v5/codec/codec.rs @@ -17,6 +17,7 @@ pub struct Codec { } bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct CodecFlags: u8 { const NO_PROBLEM_INFO = 0b0000_0001; const NO_RETAIN = 0b0000_0010; diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index 84043d2..771137e 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -7,7 +7,7 @@ use ntex::util::{buffer::BufferService, buffer::BufferServiceError}; use ntex::util::{join, BoxFuture, ByteString, Either, HashMap, HashSet, Ready}; use ntex::{service, Pipeline, Service, ServiceCall, ServiceCtx, ServiceFactory}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use super::control::{ControlMessage, ControlResult}; use super::publish::{Publish, PublishAck}; @@ -44,8 +44,10 @@ where let (publish, control) = join(factories.0.create(ses.clone()), factories.1.create(ses)).await; - let publish = publish.map_err(|e| MqttError::Service(e.into()))?; - let control = control.map_err(|e| MqttError::Service(e.into()))?; + let publish = + publish.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?; + let control = + control.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?; let control = BufferService::new( 16, @@ -53,10 +55,12 @@ where InFlightService::new(1, control), ) .map_err(|err| match err { - BufferServiceError::Service(e) => MqttError::Service(E::from(e)), - BufferServiceError::RequestCanceled => { - MqttError::ServerError("Request handling has been canceled") + BufferServiceError::Service(e) => { + MqttError::Handshake(HandshakeError::Service(E::from(e))) } + BufferServiceError::RequestCanceled => MqttError::Handshake( + HandshakeError::Server("Request handling has been canceled"), + ), }); Ok(crate::inflight::InFlightService::new( @@ -152,8 +156,7 @@ where self.inner.sink.drop_sink(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = - Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; })); } diff --git a/src/v5/mod.rs b/src/v5/mod.rs index 7a72d80..db16990 100644 --- a/src/v5/mod.rs +++ b/src/v5/mod.rs @@ -29,4 +29,4 @@ pub use crate::error; pub use crate::topic::{TopicFilter, TopicFilterError}; pub use crate::types::QoS; -pub(self) const RECEIVE_MAX_DEFAULT: NonZeroU16 = unsafe { NonZeroU16::new_unchecked(65_535) }; +const RECEIVE_MAX_DEFAULT: NonZeroU16 = unsafe { NonZeroU16::new_unchecked(65_535) }; diff --git a/src/v5/selector.rs b/src/v5/selector.rs index ef6df1c..a53bd59 100644 --- a/src/v5/selector.rs +++ b/src/v5/selector.rs @@ -5,7 +5,7 @@ use ntex::service::{boxed, Service, ServiceCtx, ServiceFactory}; use ntex::time::{Deadline, Millis, Seconds}; use ntex::util::{select, BoxFuture, Either}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use super::control::{ControlMessage, ControlResult}; use super::handshake::{Handshake, HandshakeAck}; @@ -250,17 +250,17 @@ where .await .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) }) }) .await; let (packet, size) = match result { - Either::Left(_) => Err(MqttError::HandshakeTimeout), + Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), Either::Right(item) => item, }?; @@ -268,9 +268,11 @@ where mqtt::Packet::Connect(connect) => connect, packet => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {}", 1); - return Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "MQTT-3.1.0-1: Expected CONNECT packet", + return Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "MQTT-3.1.0-1: Expected CONNECT packet", + ), ))); } }; @@ -286,7 +288,7 @@ where } } log::error!("Cannot handle CONNECT packet {:?}", item.0); - Err(MqttError::ServerError("Cannot handle CONNECT packet")) + Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } } @@ -326,17 +328,17 @@ where .await .map_err(|err| { // log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) }) }) .await; let (packet, size) = match result { - Either::Left(_) => Err(MqttError::HandshakeTimeout), + Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), Either::Right(item) => item, }?; @@ -344,9 +346,11 @@ where mqtt::Packet::Connect(connect) => connect, packet => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "Expected CONNECT packet [MQTT-3.1.0-1]", + return Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "Expected CONNECT packet [MQTT-3.1.0-1]", + ), ))); } }; @@ -362,7 +366,7 @@ where } } log::error!("Cannot handle CONNECT packet {:?}", item.0); - Err(MqttError::ServerError("Cannot handle CONNECT packet")) + Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } } diff --git a/src/v5/server.rs b/src/v5/server.rs index 84bf1c7..2449651 100644 --- a/src/v5/server.rs +++ b/src/v5/server.rs @@ -5,7 +5,7 @@ use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; use ntex::util::{select, BoxFuture, Either}; -use crate::error::{MqttError, ProtocolError}; +use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::{io::Dispatcher, service, types::QoS}; use super::control::{ControlMessage, ControlResult}; @@ -359,11 +359,11 @@ where .await .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::from(err) + MqttError::Handshake(HandshakeError::from(err)) })? .ok_or_else(|| { log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected(None) + MqttError::Handshake(HandshakeError::Disconnected(None)) })?; match packet { @@ -380,7 +380,7 @@ where let mut ack = ctx .call(&self.service, Handshake::new(connect, size, io, shared)) .await - .map_err(MqttError::Service)?; + .map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))?; match ack.session { Some(session) => { @@ -428,15 +428,17 @@ where ) .await?; let _ = ack.io.shutdown().await; - Err(MqttError::Disconnected(None)) + Err(MqttError::Handshake(HandshakeError::Disconnected(None))) } } } (packet, _) => { log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {}", 1); - Err(MqttError::Protocol(ProtocolError::unexpected_packet( - packet.packet_type(), - "Expected CONNECT packet [MQTT-3.1.0-1]", + Err(MqttError::Handshake(HandshakeError::Protocol( + ProtocolError::unexpected_packet( + packet.packet_type(), + "Expected CONNECT packet [MQTT-3.1.0-1]", + ), ))) } } @@ -446,7 +448,7 @@ where if let Ok(val) = timeout_checked(handshake_timeout, f).await { val } else { - Err(MqttError::HandshakeTimeout) + Err(MqttError::Handshake(HandshakeError::Timeout)) } }) } @@ -561,10 +563,10 @@ where let result = match select((*self.check)(&hnd), &mut delay).await { Either::Left(res) => res, - Either::Right(_) => return Err(MqttError::HandshakeTimeout), + Either::Right(_) => return Err(MqttError::Handshake(HandshakeError::Timeout)), }; - if !result.map_err(MqttError::Service)? { + if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { Ok(Either::Left((hnd, delay))) } else { // set max outbound (encoder) packet size @@ -579,9 +581,11 @@ where let mut ack = match select(ctx.call(&self.connect, hnd), &mut delay).await { Either::Left(res) => res.map_err(|e| { log::trace!("Connection handshake failed: {:?}", e); - MqttError::Service(e) + MqttError::Handshake(HandshakeError::Service(e)) })?, - Either::Right(_) => return Err(MqttError::HandshakeTimeout), + Either::Right(_) => { + return Err(MqttError::Handshake(HandshakeError::Timeout)) + } }; match ack.session { @@ -629,7 +633,7 @@ where ) .await?; let _ = ack.io.shutdown().await; - Err(MqttError::Disconnected(None)) + Err(MqttError::Handshake(HandshakeError::Disconnected(None))) } } } diff --git a/src/v5/shared.rs b/src/v5/shared.rs index 97d4d4d..5f4e31e 100644 --- a/src/v5/shared.rs +++ b/src/v5/shared.rs @@ -7,6 +7,7 @@ use ntex::{channel::pool, io::IoRef}; use crate::{error, error::SendPacketError, types::packet_type, v5::codec, QoS}; bitflags::bitflags! { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct Flags: u8 { const WRB_ENABLED = 0b0100_0000; // write-backpressure const ON_PUBLISH_ACK = 0b0010_0000; // on-publish-ack callback