diff --git a/CHANGES.md b/CHANGES.md index 4c2c75b..45d346d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.12.4] - 2023-10-03 + +* Fix nested error handling for control service + ## [0.12.3] - 2023-10-01 * Fix Publish and Control error type diff --git a/Cargo.toml b/Cargo.toml index 70443f8..9585184 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "0.12.3" +version = "0.12.4" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" diff --git a/src/io.rs b/src/io.rs index 5185a1c..3c1d7c0 100644 --- a/src/io.rs +++ b/src/io.rs @@ -268,9 +268,10 @@ where } loop { + // println!("IO-DISP state :{:?}: {:?}", inner.io.flags(), inner.st); + match inner.st { IoDispatcherState::Processing => { - // println!("IO-DISP state :{:?}:", io.flags()); let item = match ready!(inner.poll_service(this.service, cx,)) { PollService::Ready => { // decode incoming bytes stream @@ -449,17 +450,16 @@ where let mut state = self.state.borrow_mut(); Poll::Ready(if let Some(err) = state.error.take() { log::trace!("error occured, stopping dispatcher"); + self.st = IoDispatcherState::Stop; match err { IoDispatcherError::Encoder(err) => { - self.st = IoDispatcherState::Stop; PollService::Item(DispatchItem::EncoderError(err)) } IoDispatcherError::Service(err) => { state.error = Some(IoDispatcherError::Service(err)); - PollService::Ready + PollService::Continue } IoDispatcherError::KeepAlive => { - self.st = IoDispatcherState::Stop; PollService::Item(DispatchItem::KeepAliveTimeout) } } @@ -692,7 +692,7 @@ mod tests { assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); // write side must be closed, dispatcher waiting for read side to close - assert!(!client.is_closed()); + assert!(client.is_closed()); // close read side client.close().await; diff --git a/tests/test_server.rs b/tests/test_server.rs index ee15f41..cbc23ff 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -381,6 +381,31 @@ async fn test_handle_incoming() -> std::io::Result<()> { Ok(()) } +#[ntex::test] +async fn test_nested_errors() -> std::io::Result<()> { + let srv = server::test_server(move || { + MqttServer::new(handshake) + .publish(|_| Ready::Ok(())) + .control(move |msg| match msg { + ControlMessage::Disconnect(_) => Ready::Err(()), + ControlMessage::Error(_) => Ready::Err(()), + _ => Ready::Ok(msg.disconnect()), + }) + .finish() + }); + + let io = srv.connect().await.unwrap(); + let codec = codec::Codec::default(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); + let _ = io.recv(&codec).await.unwrap().unwrap(); + + // disconnect + io.send(codec::Packet::Disconnect, &codec).await.unwrap(); + assert!(io.recv(&codec).await.unwrap().is_none()); + + Ok(()) +} + #[ntex::test] async fn test_large_publish() -> std::io::Result<()> { let srv = server::test_server(move || { diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 8424ada..03757b2 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -117,7 +117,6 @@ async fn test_disconnect() -> std::io::Result<()> { client::MqttConnector::new(srv.addr()).client_id("user").connect().await.unwrap(); let sink = client.sink(); - ntex::rt::spawn(client.start_default()); let res = @@ -152,7 +151,6 @@ async fn test_disconnect_with_reason() -> std::io::Result<()> { client::MqttConnector::new(srv.addr()).client_id("user").connect().await.unwrap(); let sink = client.sink(); - ntex::rt::spawn(client.start_default()); let res = @@ -162,6 +160,63 @@ async fn test_disconnect_with_reason() -> std::io::Result<()> { Ok(()) } +#[ntex::test] +async fn test_nested_errors_handling() -> std::io::Result<()> { + let srv = server::test_server(|| { + MqttServer::new(handshake) + .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) + .control(move |msg| match msg { + ControlMessage::Disconnect(_) => Ready::Err(TestError), + ControlMessage::Error(_) => Ready::Err(TestError), + ControlMessage::Closed(m) => Ready::Ok(m.ack()), + _ => panic!("{:?}", msg), + }) + .finish() + }); + + // connect to server + let io = srv.connect().await.unwrap(); + let codec = codec::Codec::default(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); + let _ = io.recv(&codec).await.unwrap().unwrap(); + + // disconnect + io.send(codec::Disconnect::default().into(), &codec).await.unwrap(); + assert!(io.recv(&codec).await.unwrap().is_none()); + + Ok(()) +} + +#[ntex::test] +async fn test_disconnect_on_error() -> std::io::Result<()> { + let srv = server::test_server(|| { + MqttServer::new(handshake) + .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) + .control(move |msg| match msg { + ControlMessage::Disconnect(_) => Ready::Err(TestError), + ControlMessage::Error(m) => { + Ready::Ok(m.ack(codec::DisconnectReasonCode::ImplementationSpecificError)) + } + ControlMessage::Closed(m) => Ready::Ok(m.ack()), + _ => panic!("{:?}", msg), + }) + .finish() + }); + + // connect to server + let io = srv.connect().await.unwrap(); + let codec = codec::Codec::default(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); + let _ = io.recv(&codec).await.unwrap().unwrap(); + + // disconnect + io.send(codec::Disconnect::default().into(), &codec).await.unwrap(); + let res = io.recv(&codec).await.unwrap().unwrap(); + assert!(matches!(res.0, codec::Packet::Disconnect(_))); + + Ok(()) +} + #[ntex::test] async fn test_disconnect_after_control_error() -> std::io::Result<()> { env_logger::init(); @@ -207,10 +262,7 @@ async fn test_disconnect_after_control_error() -> std::io::Result<()> { .unwrap(); let result = io.recv(&codec).await.unwrap().unwrap(); - if let codec::Packet::Disconnect(_) = result.0 { - } else { - panic!(); - } + assert!(matches!(result.0, codec::Packet::Disconnect(_))); Ok(()) } @@ -238,12 +290,7 @@ async fn test_ping() -> std::io::Result<()> { let io = srv.connect().await.unwrap(); let codec = codec::Codec::new(); - io.send( - codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), - &codec, - ) - .await - .unwrap(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); let _ = io.recv(&codec).await.unwrap().unwrap(); io.send(codec::Packet::PingRequest, &codec).await.unwrap(); @@ -278,12 +325,7 @@ async fn test_ack_order() -> std::io::Result<()> { let io = srv.connect().await.unwrap(); let codec = codec::Codec::default(); - io.send( - codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), - &codec, - ) - .await - .unwrap(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); let _ = io.recv(&codec).await.unwrap().unwrap(); io.send( @@ -351,14 +393,9 @@ async fn test_dups() { let io = srv.connect().await.unwrap(); let codec = codec::Codec::default(); - io.send( - codec::Packet::Connect(Box::new( - codec::Connect::default().client_id("user").receive_max(2), - )), - &codec, - ) - .await - .unwrap(); + io.send(codec::Connect::default().client_id("user").receive_max(2).into(), &codec) + .await + .unwrap(); let _ = io.recv(&codec).await.unwrap().unwrap(); io.send(