Skip to content

Commit

Permalink
Fix nested error handling for control service (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 authored Oct 3, 2023
1 parent 9636bf3 commit 2bab9df
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.12.4] - 2023-10-03

* Fix nested error handling for control service

## [0.12.3] - 2023-10-01

* Fix Publish and Control error type
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.12.3"
version = "0.12.4"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand Down
10 changes: 5 additions & 5 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,10 @@ where
}

loop {
// println!("IO-DISP state :{:?}: {:?}", inner.io.flags(), inner.st);

match inner.st {
IoDispatcherState::Processing => {
// println!("IO-DISP state :{:?}:", io.flags());
let item = match ready!(inner.poll_service(this.service, cx,)) {
PollService::Ready => {
// decode incoming bytes stream
Expand Down Expand Up @@ -449,17 +450,16 @@ where
let mut state = self.state.borrow_mut();
Poll::Ready(if let Some(err) = state.error.take() {
log::trace!("error occured, stopping dispatcher");
self.st = IoDispatcherState::Stop;
match err {
IoDispatcherError::Encoder(err) => {
self.st = IoDispatcherState::Stop;
PollService::Item(DispatchItem::EncoderError(err))
}
IoDispatcherError::Service(err) => {
state.error = Some(IoDispatcherError::Service(err));
PollService::Ready
PollService::Continue
}
IoDispatcherError::KeepAlive => {
self.st = IoDispatcherState::Stop;
PollService::Item(DispatchItem::KeepAliveTimeout)
}
}
Expand Down Expand Up @@ -692,7 +692,7 @@ mod tests {
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

// write side must be closed, dispatcher waiting for read side to close
assert!(!client.is_closed());
assert!(client.is_closed());

// close read side
client.close().await;
Expand Down
25 changes: 25 additions & 0 deletions tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,31 @@ async fn test_handle_incoming() -> std::io::Result<()> {
Ok(())
}

#[ntex::test]
async fn test_nested_errors() -> std::io::Result<()> {
let srv = server::test_server(move || {
MqttServer::new(handshake)
.publish(|_| Ready::Ok(()))
.control(move |msg| match msg {
ControlMessage::Disconnect(_) => Ready::Err(()),
ControlMessage::Error(_) => Ready::Err(()),
_ => Ready::Ok(msg.disconnect()),
})
.finish()
});

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

// disconnect
io.send(codec::Packet::Disconnect, &codec).await.unwrap();
assert!(io.recv(&codec).await.unwrap().is_none());

Ok(())
}

#[ntex::test]
async fn test_large_publish() -> std::io::Result<()> {
let srv = server::test_server(move || {
Expand Down
89 changes: 63 additions & 26 deletions tests/test_server_v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ async fn test_disconnect() -> std::io::Result<()> {
client::MqttConnector::new(srv.addr()).client_id("user").connect().await.unwrap();

let sink = client.sink();

ntex::rt::spawn(client.start_default());

let res =
Expand Down Expand Up @@ -152,7 +151,6 @@ async fn test_disconnect_with_reason() -> std::io::Result<()> {
client::MqttConnector::new(srv.addr()).client_id("user").connect().await.unwrap();

let sink = client.sink();

ntex::rt::spawn(client.start_default());

let res =
Expand All @@ -162,6 +160,63 @@ async fn test_disconnect_with_reason() -> std::io::Result<()> {
Ok(())
}

#[ntex::test]
async fn test_nested_errors_handling() -> std::io::Result<()> {
let srv = server::test_server(|| {
MqttServer::new(handshake)
.publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack()))
.control(move |msg| match msg {
ControlMessage::Disconnect(_) => Ready::Err(TestError),
ControlMessage::Error(_) => Ready::Err(TestError),
ControlMessage::Closed(m) => Ready::Ok(m.ack()),
_ => panic!("{:?}", msg),
})
.finish()
});

// connect to server
let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

// disconnect
io.send(codec::Disconnect::default().into(), &codec).await.unwrap();
assert!(io.recv(&codec).await.unwrap().is_none());

Ok(())
}

#[ntex::test]
async fn test_disconnect_on_error() -> std::io::Result<()> {
let srv = server::test_server(|| {
MqttServer::new(handshake)
.publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack()))
.control(move |msg| match msg {
ControlMessage::Disconnect(_) => Ready::Err(TestError),
ControlMessage::Error(m) => {
Ready::Ok(m.ack(codec::DisconnectReasonCode::ImplementationSpecificError))
}
ControlMessage::Closed(m) => Ready::Ok(m.ack()),
_ => panic!("{:?}", msg),
})
.finish()
});

// connect to server
let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

// disconnect
io.send(codec::Disconnect::default().into(), &codec).await.unwrap();
let res = io.recv(&codec).await.unwrap().unwrap();
assert!(matches!(res.0, codec::Packet::Disconnect(_)));

Ok(())
}

#[ntex::test]
async fn test_disconnect_after_control_error() -> std::io::Result<()> {
env_logger::init();
Expand Down Expand Up @@ -207,10 +262,7 @@ async fn test_disconnect_after_control_error() -> std::io::Result<()> {
.unwrap();

let result = io.recv(&codec).await.unwrap().unwrap();
if let codec::Packet::Disconnect(_) = result.0 {
} else {
panic!();
}
assert!(matches!(result.0, codec::Packet::Disconnect(_)));
Ok(())
}

Expand Down Expand Up @@ -238,12 +290,7 @@ async fn test_ping() -> std::io::Result<()> {

let io = srv.connect().await.unwrap();
let codec = codec::Codec::new();
io.send(
codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))),
&codec,
)
.await
.unwrap();
io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

io.send(codec::Packet::PingRequest, &codec).await.unwrap();
Expand Down Expand Up @@ -278,12 +325,7 @@ async fn test_ack_order() -> std::io::Result<()> {

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.send(
codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))),
&codec,
)
.await
.unwrap();
io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

io.send(
Expand Down Expand Up @@ -351,14 +393,9 @@ async fn test_dups() {

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.send(
codec::Packet::Connect(Box::new(
codec::Connect::default().client_id("user").receive_max(2),
)),
&codec,
)
.await
.unwrap();
io.send(codec::Connect::default().client_id("user").receive_max(2).into(), &codec)
.await
.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

io.send(
Expand Down

0 comments on commit 2bab9df

Please sign in to comment.