Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MqttError type #149

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
fail-fast: false
matrix:
version:
- 1.66.0 # MSRV
# - 1.66.0 # MSRV
- stable
- nightly

Expand Down
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.12.0] - 2023-09-18

* Refactor MqttError type

## [0.11.4] - 2023-08-10

* Update ntex deps
Expand Down
22 changes: 8 additions & 14 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.11.4"
version = "0.12.0"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand All @@ -9,12 +9,12 @@ categories = ["network-programming"]
keywords = ["MQTT", "IoT", "messaging"]
license = "MIT"
exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2018"
edition = "2021"

[dependencies]
ntex = "0.7.3"
ntex-util = "0.3.1"
bitflags = "1.3"
ntex = "0.7.4"
ntex-util = "0.3.2"
bitflags = "2.4"
log = "0.4"
pin-project-lite = "0.2"
serde = { version = "1.0", features = ["derive"] }
Expand All @@ -23,15 +23,9 @@ thiserror = "1.0"

[dev-dependencies]
env_logger = "0.10"
ntex-tls = "0.3.0"
ntex-tls = "0.3.1"
rustls = "0.21"
rustls-pemfile = "1.0"
openssl = "0.10"
ntex = { version = "0.7.3", features = ["tokio", "rustls", "openssl"] }
test-case = "3"

[profile.dev]
lto = "off" # cannot build tests with "thin"

[profile.test]
lto = "off" # cannot build tests with "thin"
test-case = "3.2"
ntex = { version = "0.7.4", features = ["tokio", "rustls", "openssl"] }
18 changes: 11 additions & 7 deletions examples/mqtt-ws-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ntex::io::{Filter, Io};
use ntex::service::{chain_factory, ServiceFactory};
use ntex::util::{variant, Ready};
use ntex::ws;
use ntex_mqtt::{v3, v5, MqttError, MqttServer};
use ntex_mqtt::{v3, v5, HandshakeError, MqttError, MqttServer};
use ntex_tls::openssl::Acceptor;
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};

Expand Down Expand Up @@ -101,9 +101,9 @@ async fn main() -> std::io::Result<()> {
return match result {
Some(Protocol::Mqtt) => Ok(variant::Variant2::V1(io)),
Some(Protocol::Http) => Ok(variant::Variant2::V2(io)),
Some(Protocol::Unknown) => {
Err(MqttError::ServerError("Unsupported protocol"))
}
Some(Protocol::Unknown) => Err(MqttError::Handshake(
HandshakeError::Server("Unsupported protocol"),
)),
None => {
// need to read more data
io.read_ready().await?;
Expand Down Expand Up @@ -139,8 +139,10 @@ async fn main() -> std::io::Result<()> {
&codec,
)
.await?;
return Err(MqttError::ServerError(
"WebSockets handshake error",
return Err(MqttError::Handshake(
HandshakeError::Server(
"WebSockets handshake error",
),
));
}
Ok(mut res) => {
Expand Down Expand Up @@ -176,7 +178,9 @@ async fn main() -> std::io::Result<()> {
// adapt service error to mqtt error
.map_err(|e| {
log::info!("Http server error: {:?}", e);
MqttError::ServerError("Http server error")
MqttError::Handshake(HandshakeError::Server(
"Http server error",
))
})),
)
})?
Expand Down
32 changes: 23 additions & 9 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,29 @@
/// Publish handler service error
#[error("Service error")]
Service(E),
/// Handshake error
#[error("Mqtt handshake error: {}", _0)]
Handshake(#[from] HandshakeError<E>),
}

/// Errors which can occur during mqtt connection handshake.
#[derive(Debug, thiserror::Error)]

Check warning on line 19 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L19

Added line #L19 was not covered by tests
pub enum HandshakeError<E> {
/// Handshake service error
#[error("Handshake service error")]
Service(E),
/// Protocol error
#[error("Mqtt protocol error: {}", _0)]
Protocol(#[from] ProtocolError),
/// Handshake timeout
#[error("Handshake timeout")]
HandshakeTimeout,
Timeout,
/// Peer disconnect
#[error("Peer is disconnected, error: {:?}", _0)]
Disconnected(Option<io::Error>),
/// Server error
#[error("Server error: {}", _0)]
ServerError(&'static str),
Server(&'static str),
}

/// Protocol level errors
Expand Down Expand Up @@ -54,6 +65,7 @@
#[error("{message}; received packet with type `{packet_type:b}`")]
UnexpectedPacket { packet_type: u8, message: &'static str },
}

impl ProtocolViolationError {
pub(crate) fn reason(&self) -> DisconnectReasonCode {
match self.inner {
Expand Down Expand Up @@ -87,30 +99,32 @@

impl<E> From<io::Error> for MqttError<E> {
fn from(err: io::Error) -> Self {
MqttError::Disconnected(Some(err))
MqttError::Handshake(HandshakeError::Disconnected(Some(err)))

Check warning on line 102 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L102

Added line #L102 was not covered by tests
}
}

impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
fn from(err: Either<io::Error, io::Error>) -> Self {
MqttError::Disconnected(Some(err.into_inner()))
MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))

Check warning on line 108 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L108

Added line #L108 was not covered by tests
}
}

impl<E> From<Either<DecodeError, io::Error>> for MqttError<E> {
impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
fn from(err: Either<DecodeError, io::Error>) -> Self {
match err {
Either::Left(err) => MqttError::Protocol(ProtocolError::Decode(err)),
Either::Right(err) => MqttError::Disconnected(Some(err)),
Either::Left(err) => HandshakeError::Protocol(ProtocolError::Decode(err)),
Either::Right(err) => HandshakeError::Disconnected(Some(err)),

Check warning on line 116 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L115-L116

Added lines #L115 - L116 were not covered by tests
}
}
}

impl<E> From<Either<EncodeError, io::Error>> for MqttError<E> {
fn from(err: Either<EncodeError, io::Error>) -> Self {
match err {
Either::Left(err) => MqttError::Protocol(ProtocolError::Encode(err)),
Either::Right(err) => MqttError::Disconnected(Some(err)),
Either::Left(err) => {
MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(err)))

Check warning on line 125 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L124-L125

Added lines #L124 - L125 were not covered by tests
}
Either::Right(err) => MqttError::Handshake(HandshakeError::Disconnected(Some(err))),

Check warning on line 127 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L127

Added line #L127 was not covered by tests
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/inflight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub(crate) struct InFlightService<S> {
}

impl<S> InFlightService<S> {
pub fn new(max_cap: u16, max_size: usize, service: S) -> Self {
pub(crate) fn new(max_cap: u16, max_size: usize, service: S) -> Self {
Self { service, count: Counter::new(max_cap, max_size) }
}
}
Expand Down Expand Up @@ -101,7 +101,7 @@ impl Counter {
CounterGuard::new(size, self.0.clone())
}

fn available(&self, cx: &mut Context<'_>) -> bool {
fn available(&self, cx: &Context<'_>) -> bool {
self.0.available(cx)
}
}
Expand Down Expand Up @@ -142,7 +142,7 @@ impl CounterInner {
}
}

fn available(&self, cx: &mut Context<'_>) -> bool {
fn available(&self, cx: &Context<'_>) -> bool {
if (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
&& (self.max_size == 0 || self.cur_size.get() <= self.max_size)
{
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![deny(rust_2018_idioms)]
#![deny(rust_2018_idioms, warnings, unreachable_pub)]
#![allow(clippy::type_complexity)]

//! MQTT Client/Server framework
Expand All @@ -19,7 +19,7 @@ mod session;
mod types;
mod version;

pub use self::error::MqttError;
pub use self::error::{HandshakeError, MqttError};
pub use self::server::MqttServer;
pub use self::session::Session;
pub use self::topic::{TopicFilter, TopicFilterError, TopicFilterLevel};
Expand Down
29 changes: 17 additions & 12 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use ntex::util::{join, ready, BoxFuture, Ready};

use crate::version::{ProtocolVersion, VersionCodec};
use crate::{error::MqttError, v3, v5};
use crate::{error::HandshakeError, error::MqttError, v3, v5};

/// Mqtt Server
pub struct MqttServer<V3, V5, Err, InitErr> {
Expand Down Expand Up @@ -437,7 +437,11 @@
MqttServerImplStateProject::Version { ref mut item } => {
match item.as_mut().unwrap().2.poll_elapsed(cx) {
Poll::Pending => (),
Poll::Ready(_) => return Poll::Ready(Err(MqttError::HandshakeTimeout)),
Poll::Ready(_) => {
return Poll::Ready(Err(MqttError::Handshake(
HandshakeError::Timeout,
)))

Check warning on line 443 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L441-L443

Added lines #L441 - L443 were not covered by tests
}
}

let st = item.as_mut().unwrap();
Expand All @@ -458,16 +462,17 @@
unreachable!()
}
Err(RecvError::WriteBackpressure) => {
ready!(st.0.poll_flush(cx, false))
.map_err(|e| MqttError::Disconnected(Some(e)))?;
ready!(st.0.poll_flush(cx, false)).map_err(|e| {
MqttError::Handshake(HandshakeError::Disconnected(Some(e)))
})?;

Check warning on line 467 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L465-L467

Added lines #L465 - L467 were not covered by tests
continue;
}
Err(RecvError::Decoder(err)) => {
Poll::Ready(Err(MqttError::Protocol(err.into())))
}
Err(RecvError::PeerGone(err)) => {
Poll::Ready(Err(MqttError::Disconnected(err)))
}
Err(RecvError::Decoder(err)) => Poll::Ready(Err(MqttError::Handshake(
HandshakeError::Protocol(err.into()),
))),
Err(RecvError::PeerGone(err)) => Poll::Ready(Err(
MqttError::Handshake(HandshakeError::Disconnected(err)),
)),

Check warning on line 475 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L470-L475

Added lines #L470 - L475 were not covered by tests
};
}
}
Expand Down Expand Up @@ -504,9 +509,9 @@
type Future<'f> = Ready<Self::Response, Self::Error> where Self: 'f;

fn call<'a>(&'a self, _: (IoBoxed, Deadline), _: ServiceCtx<'a, Self>) -> Self::Future<'a> {
Ready::Err(MqttError::Disconnected(Some(io::Error::new(
Ready::Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new(

Check warning on line 512 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L512

Added line #L512 was not covered by tests
io::ErrorKind::Other,
format!("Protocol is not supported: {:?}", self.ver),
))))
)))))

Check warning on line 515 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L515

Added line #L515 was not covered by tests
}
}
6 changes: 3 additions & 3 deletions src/topic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,15 @@ mod tests {
is_valid(topic_filter)
}

pub fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
if s.as_ref().contains(|c| c == '+' || c == '#') {
panic!("invalid normal level `{}` contains +|#", s.as_ref());
}

TopicFilterLevel::Normal(s.as_ref().into())
}

pub fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
if s.as_ref().contains(|c| c == '+' || c == '#') {
panic!("invalid normal level `{}` contains +|#", s.as_ref());
}
Expand All @@ -359,7 +359,7 @@ mod tests {
TopicFilterLevel::System(s.as_ref().into())
}

pub fn topic(topic: &'static str) -> TopicFilter {
fn topic(topic: &'static str) -> TopicFilter {
TopicFilter::try_from(ByteString::from_static(topic)).unwrap()
}

Expand Down
12 changes: 7 additions & 5 deletions src/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
pub const MQTT: &[u8] = b"MQTT";
pub const MQTT_LEVEL_3: u8 = 4;
pub const MQTT_LEVEL_5: u8 = 5;
pub const WILL_QOS_SHIFT: u8 = 3;
pub(crate) const MQTT: &[u8] = b"MQTT";
pub(crate) const MQTT_LEVEL_3: u8 = 4;
pub(crate) const MQTT_LEVEL_5: u8 = 5;
pub(crate) const WILL_QOS_SHIFT: u8 = 3;

/// Max possible packet size
pub const MAX_PACKET_SIZE: u32 = 0xF_FF_FF_FF;
pub(crate) const MAX_PACKET_SIZE: u32 = 0xF_FF_FF_FF;

prim_enum! {
/// Quality of Service
Expand Down Expand Up @@ -32,6 +32,7 @@
}

bitflags::bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]

Check warning on line 35 in src/types.rs

View check run for this annotation

Codecov / codecov/patch

src/types.rs#L35

Added line #L35 was not covered by tests
pub struct ConnectFlags: u8 {
const USERNAME = 0b1000_0000;
const PASSWORD = 0b0100_0000;
Expand All @@ -43,6 +44,7 @@
}

bitflags::bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]

Check warning on line 47 in src/types.rs

View check run for this annotation

Codecov / codecov/patch

src/types.rs#L47

Added line #L47 was not covered by tests
pub struct ConnectAckFlags: u8 {
const SESSION_PRESENT = 0b0000_0001;
}
Expand Down
6 changes: 1 addition & 5 deletions src/v3/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,7 @@ where
}

async fn _connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
let io: IoBoxed = self
.connector
.call(Connect::new(self.address.clone()))
.await?
.into();
let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
let pkt = self.pkt.clone();
let max_send = self.max_send;
let max_receive = self.max_receive;
Expand Down
Loading
Loading