From 08dde140e48249f4314445ef8ce3025c586c3551 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 25 Sep 2023 12:02:32 +0600 Subject: [PATCH 1/2] Change handshake timeout behavior --- CHANGES.md | 5 +++ src/server.rs | 43 +++++++++++------- src/v3/selector.rs | 94 ++++++++++++--------------------------- src/v3/server.rs | 89 +++++++++++++++++-------------------- src/v5/selector.rs | 94 ++++++++++++--------------------------- src/v5/server.rs | 108 ++++++++++++++++++++------------------------- 6 files changed, 176 insertions(+), 257 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index ffd6f5b..429198d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,10 @@ # Changes +## [0.12.1] - 2023-09-25 + +* Change handshake timeout behavior (renamed to connect timeout). + Timeout handles slow client's Control frame. + ## [0.12.0] - 2023-09-18 * Refactor MqttError type diff --git a/src/server.rs b/src/server.rs index c7d333c..3dede37 100644 --- a/src/server.rs +++ b/src/server.rs @@ -13,7 +13,7 @@ use crate::{error::HandshakeError, error::MqttError, v3, v5}; pub struct MqttServer { v3: V3, v5: V5, - handshake_timeout: Millis, + connect_timeout: Millis, _t: marker::PhantomData<(Err, InitErr)>, } @@ -30,7 +30,7 @@ impl MqttServer { v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3), v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5), - handshake_timeout: Millis(10000), + connect_timeout: Millis(10000), _t: marker::PhantomData, } } @@ -50,12 +50,23 @@ impl Default } impl MqttServer { - /// Set handshake timeout. + /// Set client timeout for first `Connect` frame. + /// + /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// the entire frame within this time, the connection is terminated with + /// Mqtt::Handshake(HandshakeError::Timeout) error. /// - /// Handshake includes `connect` packet. - /// By default handshake timeuot is 10 seconds. + /// By default, connect timeuot is 10 seconds. + pub fn conenct_timeout(mut self, timeout: Seconds) -> Self { + self.connect_timeout = timeout.into(); + self + } + + #[deprecated(since = "0.12.1")] + #[doc(hidden)] + /// Set handshake timeout. pub fn handshake_timeout(mut self, timeout: Seconds) -> Self { - self.handshake_timeout = timeout.into(); + self.connect_timeout = timeout.into(); self } } @@ -113,7 +124,7 @@ where MqttServer { v3: service.finish(), v5: self.v5, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, _t: marker::PhantomData, } } @@ -140,7 +151,7 @@ where MqttServer { v3: service, v5: self.v5, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, _t: marker::PhantomData, } } @@ -185,7 +196,7 @@ where MqttServer { v3: self.v3, v5: service.finish(), - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, _t: marker::PhantomData, } } @@ -212,7 +223,7 @@ where MqttServer { v3: self.v3, v5: service, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, _t: marker::PhantomData, } } @@ -241,7 +252,7 @@ where let v5 = v5?; Ok(MqttServerImpl { handlers: (v3, v5), - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, _t: marker::PhantomData, }) } @@ -311,7 +322,7 @@ where /// Mqtt Server pub struct MqttServerImpl { handlers: (V3, V5), - handshake_timeout: Millis, + connect_timeout: Millis, _t: marker::PhantomData, } @@ -353,7 +364,7 @@ where MqttServerImplResponse { ctx, state: MqttServerImplState::Version { - item: Some((req, VersionCodec, Deadline::new(self.handshake_timeout))), + item: Some((req, VersionCodec, Deadline::new(self.connect_timeout))), }, handlers: &self.handlers, } @@ -491,7 +502,7 @@ impl DefaultProtocolServer { } } -impl ServiceFactory<(IoBoxed, Deadline)> for DefaultProtocolServer { +impl ServiceFactory for DefaultProtocolServer { type Response = (); type Error = MqttError; type Service = DefaultProtocolServer; @@ -503,12 +514,12 @@ impl ServiceFactory<(IoBoxed, Deadline)> for DefaultProtocolServer } } -impl Service<(IoBoxed, Deadline)> for DefaultProtocolServer { +impl Service for DefaultProtocolServer { type Response = (); type Error = MqttError; type Future<'f> = Ready where Self: 'f; - fn call<'a>(&'a self, _: (IoBoxed, Deadline), _: ServiceCtx<'a, Self>) -> Self::Future<'a> { + fn call<'a>(&'a self, _: IoBoxed, _: ServiceCtx<'a, Self>) -> Self::Future<'a> { Ready::Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new( io::ErrorKind::Other, format!("Protocol is not supported: {:?}", self.ver), diff --git a/src/v3/selector.rs b/src/v3/selector.rs index 8c33feb..efea423 100644 --- a/src/v3/selector.rs +++ b/src/v3/selector.rs @@ -12,12 +12,10 @@ use super::handshake::{Handshake, HandshakeAck}; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, MqttServer, Publish, Session}; -pub(crate) type SelectItem = (Handshake, Deadline); - type ServerFactory = - boxed::BoxServiceFactory<(), SelectItem, Either, MqttError, InitErr>; + boxed::BoxServiceFactory<(), Handshake, Either, MqttError, InitErr>; -type Server = boxed::BoxService, MqttError>; +type Server = boxed::BoxService, MqttError>; /// Mqtt server selector /// @@ -26,7 +24,7 @@ type Server = boxed::BoxService, MqttErr pub struct Selector { servers: Vec>, max_size: u32, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, _t: marker::PhantomData<(Err, InitErr)>, } @@ -37,7 +35,7 @@ impl Selector { Selector { servers: Vec::new(), max_size: 0, - handshake_timeout: Millis(10000), + connect_timeout: Millis(10000), pool: Default::default(), _t: marker::PhantomData, } @@ -49,12 +47,23 @@ where Err: 'static, InitErr: 'static, { - /// Set handshake timeout. + /// Set client timeout for first `Connect` frame. /// - /// Handshake includes `connect` packet and response `connect-ack`. - /// By default handshake timeuot is 10 seconds. + /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// the entire frame within this time, the connection is terminated with + /// Mqtt::Handshake(HandshakeError::Timeout) error. + /// + /// By default, connect timeuot is 10 seconds. + pub fn conenct_timeout(mut self, timeout: Seconds) -> Self { + self.connect_timeout = timeout.into(); + self + } + + #[deprecated(since = "0.12.1")] + #[doc(hidden)] + /// Set handshake timeout. pub fn handshake_timeout(mut self, timeout: Seconds) -> Self { - self.handshake_timeout = timeout.into(); + self.connect_timeout = timeout.into(); self } @@ -111,7 +120,7 @@ where Ok(SelectorService { servers, max_size: self.max_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, pool: self.pool.clone(), }) } @@ -169,7 +178,7 @@ where pub struct SelectorService { servers: Vec>, max_size: u32, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, } @@ -234,58 +243,11 @@ where #[inline] fn call<'a>(&'a self, io: IoBoxed, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { - Box::pin(async move { - let codec = mqtt::Codec::default(); - codec.set_max_size(self.max_size); - let shared = Rc::new(MqttShared::new(io.clone(), codec, false, self.pool.clone())); - let mut timeout = Deadline::new(self.handshake_timeout); - - // read first packet - let result = select(&mut timeout, async { - io.recv(&shared.codec) - .await - .map_err(|err| { - log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::Handshake(HandshakeError::from(err)) - })? - .ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Handshake(HandshakeError::Disconnected(None)) - }) - }) - .await; - - let (packet, size) = match result { - Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), - Either::Right(item) => item, - }?; - - let connect = match packet { - mqtt::Packet::Connect(connect) => connect, - packet => { - log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Handshake(HandshakeError::Protocol( - ProtocolError::unexpected_packet( - packet.packet_type(), - "Expected CONNECT packet [MQTT-3.1.0-1]", - ), - ))); - } - }; - - // call servers - let mut item = (Handshake::new(connect, size, io, shared), timeout); - for srv in &self.servers { - match ctx.call(srv, item).await? { - Either::Left(result) => { - item = result; - } - Either::Right(_) => return Ok(()), - } - } - log::error!("Cannot handle CONNECT packet {:?}", item.0); - Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) - }) + Service::<(IoBoxed, Deadline)>::call( + self, + (io, Deadline::new(self.connect_timeout)), + ctx, + ) } } @@ -353,7 +315,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, size, io, shared), timeout); + let mut item = Handshake::new(connect, size, io, shared); for srv in &self.servers { match ctx.call(srv, item).await? { Either::Left(result) => { @@ -362,7 +324,7 @@ where Either::Right(_) => return Ok(()), } } - log::error!("Cannot handle CONNECT packet {:?}", item.0.packet()); + log::error!("Cannot handle CONNECT packet {:?}", item.packet()); Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } diff --git a/src/v3/server.rs b/src/v3/server.rs index c142271..e58eb5d 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -3,7 +3,7 @@ use std::{fmt, future::Future, marker::PhantomData, rc::Rc}; use ntex::io::{DispatchItem, IoBoxed}; use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; -use ntex::util::{select, BoxFuture, Either}; +use ntex::util::{BoxFuture, Either}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::{io::Dispatcher, service, types::QoS}; @@ -11,7 +11,6 @@ use crate::{io::Dispatcher, service, types::QoS}; use super::control::{ControlMessage, ControlResult}; use super::default::{DefaultControlService, DefaultPublishService}; use super::handshake::{Handshake, HandshakeAck}; -use super::selector::SelectItem; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttSink, Publish, Session}; @@ -50,7 +49,7 @@ pub struct MqttServer { max_size: u32, max_inflight: u16, max_inflight_size: usize, - handshake_timeout: Seconds, + connect_timeout: Seconds, disconnect_timeout: Seconds, pub(super) pool: Rc, _t: PhantomData, @@ -76,7 +75,7 @@ where max_size: 0, max_inflight: 16, max_inflight_size: 65535, - handshake_timeout: Seconds::ZERO, + connect_timeout: Seconds::ZERO, disconnect_timeout: Seconds(3), pool: Default::default(), _t: PhantomData, @@ -94,12 +93,23 @@ where H::Error: From + From + From + From + fmt::Debug, { - /// Set handshake timeout. + /// Set client timeout for first `Connect` frame. + /// + /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// the entire frame within this time, the connection is terminated with + /// Mqtt::Handshake(HandshakeError::Timeout) error. /// - /// Handshake includes `connect` packet and response `connect-ack`. - /// By default handshake timeuot is disabled. + /// By default, connect timeout is disabled. + pub fn conenct_timeout(mut self, timeout: Seconds) -> Self { + self.connect_timeout = timeout; + self + } + + #[deprecated(since = "0.12.1")] + #[doc(hidden)] + /// Set handshake timeout. pub fn handshake_timeout(mut self, timeout: Seconds) -> Self { - self.handshake_timeout = timeout; + self.connect_timeout = timeout; self } @@ -169,7 +179,7 @@ where max_size: self.max_size, max_inflight: self.max_inflight, max_inflight_size: self.max_inflight_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, _t: PhantomData, @@ -191,7 +201,7 @@ where max_size: self.max_size, max_inflight: self.max_inflight, max_inflight_size: self.max_inflight_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, _t: PhantomData, @@ -222,7 +232,7 @@ where HandshakeFactory { factory: self.handshake, max_size: self.max_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, pool: self.pool.clone(), _t: PhantomData, }, @@ -242,8 +252,8 @@ where self, check: F, ) -> impl ServiceFactory< - SelectItem, - Response = Either, + Handshake, + Response = Either, Error = MqttError, InitError = H::InitError, > @@ -271,7 +281,7 @@ where struct HandshakeFactory { factory: H, max_size: u32, - handshake_timeout: Seconds, + connect_timeout: Seconds, pool: Rc, _t: PhantomData, } @@ -294,7 +304,7 @@ where max_size: self.max_size, pool: self.pool.clone(), service: self.factory.create(()).await?, - handshake_timeout: self.handshake_timeout.into(), + connect_timeout: self.connect_timeout.into(), _t: PhantomData, }) }) @@ -305,7 +315,7 @@ struct HandshakeService { service: H, max_size: u32, pool: Rc, - handshake_timeout: Millis, + connect_timeout: Millis, _t: PhantomData, } @@ -324,16 +334,16 @@ where fn call<'a>(&'a self, io: IoBoxed, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { log::trace!("Starting mqtt v3 handshake"); - let f = async move { + Box::pin(async move { let codec = mqtt::Codec::default(); codec.set_max_size(self.max_size); let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone())); // read first packet - let packet = io - .recv(&shared.codec) + let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec)) .await + .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::Handshake(HandshakeError::from(err)) @@ -393,15 +403,6 @@ where ))) } } - }; - - let handshake_timeout = self.handshake_timeout; - Box::pin(async move { - if let Ok(val) = timeout_checked(handshake_timeout, f).await { - val - } else { - Err(MqttError::Handshake(HandshakeError::Timeout)) - } }) } } @@ -415,7 +416,7 @@ pub(crate) struct ServerSelector { _t: PhantomData<(St, R)>, } -impl ServiceFactory for ServerSelector +impl ServiceFactory for ServerSelector where St: 'static, F: Fn(&Handshake) -> R + 'static, @@ -430,7 +431,7 @@ where InitError = MqttError, > + 'static, { - type Response = Either; + type Response = Either; type Error = MqttError; type InitError = H::InitError; type Service = ServerSelectorImpl; @@ -460,7 +461,7 @@ pub(crate) struct ServerSelectorImpl { _t: PhantomData<(St, R)>, } -impl Service for ServerSelectorImpl +impl Service for ServerSelectorImpl where St: 'static, F: Fn(&Handshake) -> R + 'static, @@ -475,7 +476,7 @@ where InitError = MqttError, > + 'static, { - type Response = Either; + type Response = Either; type Error = MqttError; type Future<'f> = BoxFuture<'f, Result> where Self: 'f; @@ -483,29 +484,19 @@ where ntex::forward_poll_shutdown!(handshake); #[inline] - fn call<'a>(&'a self, req: SelectItem, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { + fn call<'a>(&'a self, hnd: Handshake, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { Box::pin(async move { log::trace!("Start connection handshake"); - let (hnd, mut delay) = req; - - let result = match select((*self.check)(&hnd), &mut delay).await { - Either::Left(res) => res, - Either::Right(_) => return Err(MqttError::Handshake(HandshakeError::Timeout)), - }; + let result = (*self.check)(&hnd).await; if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { - Ok(Either::Left((hnd, delay))) + Ok(Either::Left(hnd)) } else { // authenticate mqtt connection - let ack = match select(ctx.call(&self.handshake, hnd), delay).await { - Either::Left(res) => res.map_err(|e| { - log::trace!("Connection handshake failed: {:?}", e); - MqttError::Handshake(HandshakeError::Service(e)) - })?, - Either::Right(_) => { - return Err(MqttError::Handshake(HandshakeError::Timeout)) - } - }; + let ack = ctx.call(&self.handshake, hnd).await.map_err(|e| { + log::trace!("Connection handshake failed: {:?}", e); + MqttError::Handshake(HandshakeError::Service(e)) + })?; match ack.session { Some(session) => { diff --git a/src/v5/selector.rs b/src/v5/selector.rs index a53bd59..51cfb73 100644 --- a/src/v5/selector.rs +++ b/src/v5/selector.rs @@ -13,12 +13,10 @@ use super::publish::{Publish, PublishAck}; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, MqttServer, Session}; -pub(crate) type SelectItem = (Handshake, Deadline); - type ServerFactory = - boxed::BoxServiceFactory<(), SelectItem, Either, MqttError, InitErr>; + boxed::BoxServiceFactory<(), Handshake, Either, MqttError, InitErr>; -type Server = boxed::BoxService, MqttError>; +type Server = boxed::BoxService, MqttError>; /// Mqtt server selector /// @@ -27,7 +25,7 @@ type Server = boxed::BoxService, MqttErr pub struct Selector { servers: Vec>, max_size: u32, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, _t: marker::PhantomData<(Err, InitErr)>, } @@ -38,7 +36,7 @@ impl Selector { Selector { servers: Vec::new(), max_size: 0, - handshake_timeout: Millis(10000), + connect_timeout: Millis(10000), pool: Default::default(), _t: marker::PhantomData, } @@ -50,12 +48,23 @@ where Err: 'static, InitErr: 'static, { - /// Set handshake timeout. + /// Set client timeout for first `Connect` frame. /// - /// Handshake includes `connect` packet and response `connect-ack`. - /// By default handshake timeuot is 10 seconds. + /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// the entire frame within this time, the connection is terminated with + /// Mqtt::Handshake(HandshakeError::Timeout) error. + /// + /// By default, connect timeout is disabled. + pub fn conenct_timeout(mut self, timeout: Seconds) -> Self { + self.connect_timeout = timeout.into(); + self + } + + #[deprecated(since = "0.12.1")] + #[doc(hidden)] + /// Set handshake timeout. pub fn handshake_timeout(mut self, timeout: Seconds) -> Self { - self.handshake_timeout = timeout.into(); + self.connect_timeout = timeout.into(); self } @@ -115,7 +124,7 @@ where Ok(SelectorService { servers, max_size: self.max_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, pool: self.pool.clone(), }) } @@ -173,7 +182,7 @@ where pub struct SelectorService { servers: Vec>, max_size: u32, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, } @@ -238,58 +247,11 @@ where #[inline] fn call<'a>(&'a self, io: IoBoxed, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { - Box::pin(async move { - let codec = mqtt::Codec::default(); - codec.set_max_inbound_size(self.max_size); - let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone())); - let mut timeout = Deadline::new(self.handshake_timeout); - - // read first packet - let result = select(&mut timeout, async { - io.recv(&shared.codec) - .await - .map_err(|err| { - log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::Handshake(HandshakeError::from(err)) - })? - .ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Handshake(HandshakeError::Disconnected(None)) - }) - }) - .await; - - let (packet, size) = match result { - Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), - Either::Right(item) => item, - }?; - - let connect = match packet { - mqtt::Packet::Connect(connect) => connect, - packet => { - log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {}", 1); - return Err(MqttError::Handshake(HandshakeError::Protocol( - ProtocolError::unexpected_packet( - packet.packet_type(), - "MQTT-3.1.0-1: Expected CONNECT packet", - ), - ))); - } - }; - - // call servers - let mut item = (Handshake::new(connect, size, io, shared), timeout); - for srv in self.servers.iter() { - match ctx.call(&srv, item).await? { - Either::Left(result) => { - item = result; - } - Either::Right(_) => return Ok(()), - } - } - log::error!("Cannot handle CONNECT packet {:?}", item.0); - Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) - }) + Service::<(IoBoxed, Deadline)>::call( + self, + (io, Deadline::new(self.connect_timeout)), + ctx, + ) } } @@ -356,7 +318,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, size, io, shared), timeout); + let mut item = Handshake::new(connect, size, io, shared); for srv in self.servers.iter() { match ctx.call(srv, item).await? { Either::Left(result) => { @@ -365,7 +327,7 @@ where Either::Right(_) => return Ok(()), } } - log::error!("Cannot handle CONNECT packet {:?}", item.0); + log::error!("Cannot handle CONNECT packet {:?}", item); Err(MqttError::Handshake(HandshakeError::Server("Cannot handle CONNECT packet"))) }) } diff --git a/src/v5/server.rs b/src/v5/server.rs index 2449651..ec2019d 100644 --- a/src/v5/server.rs +++ b/src/v5/server.rs @@ -3,7 +3,7 @@ use std::{convert::TryFrom, fmt, future::Future, marker::PhantomData, rc::Rc}; use ntex::io::{DispatchItem, IoBoxed}; use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; -use ntex::util::{select, BoxFuture, Either}; +use ntex::util::{BoxFuture, Either}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::{io::Dispatcher, service, types::QoS}; @@ -12,7 +12,6 @@ use super::control::{ControlMessage, ControlResult}; use super::default::{DefaultControlService, DefaultPublishService}; use super::handshake::{Handshake, HandshakeAck}; use super::publish::{Publish, PublishAck}; -use super::selector::SelectItem; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttSink, Session}; @@ -25,7 +24,7 @@ pub struct MqttServer { max_receive: u16, max_qos: QoS, max_inflight_size: usize, - handshake_timeout: Seconds, + connect_timeout: Seconds, disconnect_timeout: Seconds, max_topic_alias: u16, pub(super) pool: Rc, @@ -51,7 +50,7 @@ where max_receive: 15, max_qos: QoS::AtLeastOnce, max_inflight_size: 65535, - handshake_timeout: Seconds::ZERO, + connect_timeout: Seconds::ZERO, disconnect_timeout: Seconds(3), max_topic_alias: 32, pool: Rc::new(MqttSinkPool::default()), @@ -69,12 +68,23 @@ where + 'static, P: ServiceFactory, Response = PublishAck> + 'static, { - /// Set handshake timeout. + /// Set client timeout for first `Connect` frame. + /// + /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// the entire frame within this time, the connection is terminated with + /// Mqtt::Handshake(HandshakeError::Timeout) error. /// - /// Handshake includes `connect` packet and response `connect-ack`. - /// By default handshake timeuot is disabled. + /// By default, connect timeout is disabled. + pub fn conenct_timeout(mut self, timeout: Seconds) -> Self { + self.connect_timeout = timeout; + self + } + + #[deprecated(since = "0.12.1")] + #[doc(hidden)] + /// Set handshake timeout. pub fn handshake_timeout(mut self, timeout: Seconds) -> Self { - self.handshake_timeout = timeout; + self.connect_timeout = timeout; self } @@ -153,7 +163,7 @@ where max_topic_alias: self.max_topic_alias, max_qos: self.max_qos, max_inflight_size: self.max_inflight_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, _t: PhantomData, @@ -178,7 +188,7 @@ where max_topic_alias: self.max_topic_alias, max_qos: self.max_qos, max_inflight_size: self.max_inflight_size, - handshake_timeout: self.handshake_timeout, + connect_timeout: self.connect_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, _t: PhantomData, @@ -228,7 +238,7 @@ where max_receive: self.max_receive, max_topic_alias: self.max_topic_alias, max_qos: self.max_qos, - handshake_timeout: self.handshake_timeout.into(), + connect_timeout: self.connect_timeout.into(), pool: self.pool, _t: PhantomData, }, @@ -242,8 +252,8 @@ where self, check: F, ) -> impl ServiceFactory< - SelectItem, - Response = Either, + Handshake, + Response = Either, Error = MqttError, InitError = C::InitError, > @@ -275,7 +285,7 @@ struct HandshakeFactory { max_receive: u16, max_topic_alias: u16, max_qos: QoS, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, _t: PhantomData, } @@ -299,7 +309,7 @@ where let max_topic_alias = self.max_topic_alias; let max_qos = self.max_qos; let pool = self.pool.clone(); - let handshake_timeout = self.handshake_timeout; + let connect_timeout = self.connect_timeout; Box::pin(async move { let service = fut.await?; @@ -309,7 +319,7 @@ where max_receive, max_topic_alias, max_qos, - handshake_timeout, + connect_timeout, pool, _t: PhantomData, }) @@ -323,7 +333,7 @@ struct HandshakeService { max_receive: u16, max_topic_alias: u16, max_qos: QoS, - handshake_timeout: Millis, + connect_timeout: Millis, pool: Rc, _t: PhantomData, } @@ -350,13 +360,11 @@ where shared.set_receive_max(self.max_receive); shared.set_topic_alias_max(self.max_topic_alias); - let handshake_timeout = self.handshake_timeout; - - let f = async move { + Box::pin(async move { // read first packet - let packet = io - .recv(&shared.codec) + let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec)) .await + .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::Handshake(HandshakeError::from(err)) @@ -442,14 +450,6 @@ where ))) } } - }; - - Box::pin(async move { - if let Ok(val) = timeout_checked(handshake_timeout, f).await { - val - } else { - Err(MqttError::Handshake(HandshakeError::Timeout)) - } }) } } @@ -466,7 +466,7 @@ pub(crate) struct ServerSelector { _t: PhantomData<(St, R)>, } -impl ServiceFactory for ServerSelector +impl ServiceFactory for ServerSelector where St: 'static, F: Fn(&Handshake) -> R + 'static, @@ -481,7 +481,7 @@ where InitError = MqttError, > + 'static, { - type Response = Either; + type Response = Either; type Error = MqttError; type InitError = C::InitError; type Service = ServerSelectorImpl; @@ -526,7 +526,7 @@ pub(crate) struct ServerSelectorImpl { _t: PhantomData<(St, R)>, } -impl Service for ServerSelectorImpl +impl Service for ServerSelectorImpl where St: 'static, F: Fn(&Handshake) -> R + 'static, @@ -541,34 +541,27 @@ where InitError = MqttError, > + 'static, { - type Response = Either; + type Response = Either; type Error = MqttError; type Future<'f> = BoxFuture<'f, Result> where Self: 'f; ntex::forward_poll_ready!(connect, MqttError::Service); ntex::forward_poll_shutdown!(connect); - #[inline] - fn call<'a>(&'a self, req: SelectItem, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { + fn call<'a>(&'a self, hnd: Handshake, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { Box::pin(async move { log::trace!("Start connection handshake"); - let timeout = self.disconnect_timeout; - - req.0.shared.codec.set_max_inbound_size(self.max_size); - req.0.shared.set_max_qos(self.max_qos); - req.0.shared.set_receive_max(self.max_receive); - req.0.shared.set_topic_alias_max(self.max_topic_alias); - let (hnd, mut delay) = req; - - let result = match select((*self.check)(&hnd), &mut delay).await { - Either::Left(res) => res, - Either::Right(_) => return Err(MqttError::Handshake(HandshakeError::Timeout)), - }; - + let result = (*self.check)(&hnd).await; if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { - Ok(Either::Left((hnd, delay))) + Ok(Either::Left(hnd)) } else { + // decoder config + hnd.shared.codec.set_max_inbound_size(self.max_size); + hnd.shared.set_max_qos(self.max_qos); + hnd.shared.set_receive_max(self.max_receive); + hnd.shared.set_topic_alias_max(self.max_topic_alias); + // set max outbound (encoder) packet size if let Some(size) = hnd.packet().max_packet_size { hnd.shared.codec.set_max_outbound_size(size.get()); @@ -578,15 +571,10 @@ where hnd.packet().receive_max.map(|v| v.get()).unwrap_or(16) as usize; // authenticate mqtt connection - let mut ack = match select(ctx.call(&self.connect, hnd), &mut delay).await { - Either::Left(res) => res.map_err(|e| { - log::trace!("Connection handshake failed: {:?}", e); - MqttError::Handshake(HandshakeError::Service(e)) - })?, - Either::Right(_) => { - return Err(MqttError::Handshake(HandshakeError::Timeout)) - } - }; + let mut ack = ctx.call(&self.connect, hnd).await.map_err(|e| { + log::trace!("Connection handshake failed: {:?}", e); + MqttError::Handshake(HandshakeError::Service(e)) + })?; match ack.session { Some(session) => { @@ -619,7 +607,7 @@ where Dispatcher::new(ack.io, shared, handler) .keepalive_timeout(Seconds(ack.keepalive)) - .disconnect_timeout(timeout) + .disconnect_timeout(self.disconnect_timeout) .await?; Ok(Either::Right(())) } From 002439e01a21e624da245c3b5f5ebb3219183541 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 25 Sep 2023 12:12:36 +0600 Subject: [PATCH 2/2] wip --- Cargo.toml | 2 +- src/io.rs | 2 +- src/server.rs | 6 +++--- src/topic.rs | 3 +-- src/v3/codec/encode.rs | 2 +- src/v5/codec/decode.rs | 8 ++------ src/v5/codec/encode.rs | 2 +- tests/test_server.rs | 7 ++----- tests/test_server_v5.rs | 7 ++----- 9 files changed, 14 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ed87d3e..bb891eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "0.12.0" +version = "0.12.1" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" diff --git a/src/io.rs b/src/io.rs index d9da52b..5185a1c 100644 --- a/src/io.rs +++ b/src/io.rs @@ -730,7 +730,7 @@ mod tests { let (disp, io) = Dispatcher::new_debug(nio::Io::new(server), BytesCodec, Srv(counter.clone())); - io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &mut BytesCodec).unwrap(); + io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap(); ntex::rt::spawn(async move { let _ = disp.await; }); diff --git a/src/server.rs b/src/server.rs index 3dede37..6565e28 100644 --- a/src/server.rs +++ b/src/server.rs @@ -502,7 +502,7 @@ impl DefaultProtocolServer { } } -impl ServiceFactory for DefaultProtocolServer { +impl ServiceFactory<(IoBoxed, Deadline)> for DefaultProtocolServer { type Response = (); type Error = MqttError; type Service = DefaultProtocolServer; @@ -514,12 +514,12 @@ impl ServiceFactory for DefaultProtocolServer Service for DefaultProtocolServer { +impl Service<(IoBoxed, Deadline)> for DefaultProtocolServer { type Response = (); type Error = MqttError; type Future<'f> = Ready where Self: 'f; - fn call<'a>(&'a self, _: IoBoxed, _: ServiceCtx<'a, Self>) -> Self::Future<'a> { + fn call<'a>(&'a self, _: (IoBoxed, Deadline), _: ServiceCtx<'a, Self>) -> Self::Future<'a> { Ready::Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new( io::ErrorKind::Other, format!("Protocol is not supported: {:?}", self.ver), diff --git a/src/topic.rs b/src/topic.rs index 7a0fc14..7a55b91 100644 --- a/src/topic.rs +++ b/src/topic.rs @@ -376,8 +376,7 @@ mod tests { #[test_case("/finance" => Ok(vec![TopicFilterLevel::Blank, lvl_normal("finance")]) ; "12")] #[test_case("finance/" => Ok(vec![lvl_normal("finance"), TopicFilterLevel::Blank]) ; "13")] fn parsing(input: &str) -> Result, TopicFilterError> { - TopicFilter::try_from(ByteString::from(input)) - .map(|t| t.levels().iter().cloned().collect()) + TopicFilter::try_from(ByteString::from(input)).map(|t| t.levels().to_vec()) } #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), lvl_normal("player1")] => true; "1")] diff --git a/src/v3/codec/encode.rs b/src/v3/codec/encode.rs index 90a2710..2e13a8f 100644 --- a/src/v3/codec/encode.rs +++ b/src/v3/codec/encode.rs @@ -262,7 +262,7 @@ mod tests { let mut v = BytesMut::with_capacity(1024); encode(packet, &mut v, get_encoded_size(packet) as u32).unwrap(); assert_eq!(expected.len(), v.len()); - assert_eq!(&expected[..], &v[..]); + assert_eq!(expected, &v[..]); } #[test] diff --git a/src/v5/codec/decode.rs b/src/v5/codec/decode.rs index e606198..5023b02 100644 --- a/src/v5/codec/decode.rs +++ b/src/v5/codec/decode.rs @@ -56,12 +56,8 @@ mod tests { let (_len, consumed) = decode_variable_length(&bytes[1..]).unwrap().unwrap(); let cur = Bytes::copy_from_slice(&bytes[consumed + 1..]); let mut tmp = BytesMut::with_capacity(4096); - ntex::codec::Encoder::encode( - &mut crate::v5::codec::Codec::new(), - res.clone(), - &mut tmp, - ) - .unwrap(); + ntex::codec::Encoder::encode(&crate::v5::codec::Codec::new(), res.clone(), &mut tmp) + .unwrap(); let decoded = decode_packet(cur, fixed); let res = Ok(res); if decoded != res { diff --git a/src/v5/codec/encode.rs b/src/v5/codec/encode.rs index aeccdf0..1deb1f9 100644 --- a/src/v5/codec/encode.rs +++ b/src/v5/codec/encode.rs @@ -311,7 +311,7 @@ mod tests { let mut v = BytesMut::with_capacity(1024); packet.encode(&mut v, packet.encoded_size(1024) as u32).unwrap(); assert_eq!(expected.len(), v.len()); - assert_eq!(&expected[..], &v[..]); + assert_eq!(expected, &v[..]); } #[test] diff --git a/tests/test_server.rs b/tests/test_server.rs index 2a12409..ee15f41 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -475,11 +475,8 @@ async fn test_max_qos() -> std::io::Result<()> { let violated = violated.clone(); match msg { ControlMessage::ProtocolError(err) => { - match err.get_ref() { - ProtocolError::ProtocolViolation(_) => { - violated.store(true, Relaxed); - } - _ => (), + if let ProtocolError::ProtocolViolation(_) = err.get_ref() { + violated.store(true, Relaxed); } Ready::Ok(err.ack()) } diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 6e61b66..8424ada 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -919,11 +919,8 @@ async fn test_max_qos() -> std::io::Result<()> { let violated = violated.clone(); match msg { ControlMessage::ProtocolError(msg) => { - match msg.get_ref() { - error::ProtocolError::ProtocolViolation(_) => { - violated.store(true, Relaxed); - } - _ => (), + if let error::ProtocolError::ProtocolViolation(_) = msg.get_ref() { + violated.store(true, Relaxed); } Ready::Ok::<_, TestError>(msg.ack()) }