From effca568032eaffcdb9d3f984000b238962ce438 Mon Sep 17 00:00:00 2001 From: ngthhu Date: Thu, 25 Apr 2024 17:22:01 +0700 Subject: [PATCH] refactor: [#681] udp return errors instead of panicking --- src/console/clients/udp/checker.rs | 16 +- src/shared/bit_torrent/tracker/udp/client.rs | 215 +++++++++++-------- tests/servers/udp/contract.rs | 71 ++++-- 3 files changed, 193 insertions(+), 109 deletions(-) diff --git a/src/console/clients/udp/checker.rs b/src/console/clients/udp/checker.rs index 12b8d764..9b2a9011 100644 --- a/src/console/clients/udp/checker.rs +++ b/src/console/clients/udp/checker.rs @@ -64,7 +64,7 @@ impl Client { let binding_address = local_bind_to.parse().context("binding local address")?; debug!("Binding to: {local_bind_to}"); - let udp_client = UdpClient::bind(&local_bind_to).await; + let udp_client = UdpClient::bind(&local_bind_to).await?; let bound_to = udp_client.socket.local_addr().context("bound local address")?; debug!("Bound to: {bound_to}"); @@ -88,7 +88,7 @@ impl Client { match &self.udp_tracker_client { Some(client) => { - client.udp_client.connect(&tracker_socket_addr.to_string()).await; + client.udp_client.connect(&tracker_socket_addr.to_string()).await?; self.remote_socket = Some(*tracker_socket_addr); Ok(()) } @@ -116,9 +116,9 @@ impl Client { match &self.udp_tracker_client { Some(client) => { - client.send(connect_request.into()).await; + client.send(connect_request.into()).await?; - let response = client.receive().await; + let response = client.receive().await?; debug!("connection request response:\n{response:#?}"); @@ -163,9 +163,9 @@ impl Client { match &self.udp_tracker_client { Some(client) => { - client.send(announce_request.into()).await; + client.send(announce_request.into()).await?; - let response = client.receive().await; + let response = client.receive().await?; debug!("announce request response:\n{response:#?}"); @@ -200,9 +200,9 @@ impl Client { match &self.udp_tracker_client { Some(client) => { - client.send(scrape_request.into()).await; + client.send(scrape_request.into()).await?; - let response = client.receive().await; + let response = client.receive().await?; debug!("scrape request response:\n{response:#?}"); diff --git a/src/shared/bit_torrent/tracker/udp/client.rs b/src/shared/bit_torrent/tracker/udp/client.rs index 11c8d8f6..9af9571b 100644 --- a/src/shared/bit_torrent/tracker/udp/client.rs +++ b/src/shared/bit_torrent/tracker/udp/client.rs @@ -1,8 +1,10 @@ +use core::result::Result::{Err, Ok}; use std::io::Cursor; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use anyhow::{anyhow, Context, Result}; use aquatic_udp_protocol::{ConnectRequest, Request, Response, TransactionId}; use log::debug; use tokio::net::UdpSocket; @@ -25,99 +27,120 @@ pub struct UdpClient { } impl UdpClient { - /// # Panics + /// # Errors + /// + /// Will return error if the local address can't be bound. /// - /// Will panic if the local address can't be bound. - pub async fn bind(local_address: &str) -> Self { - let valid_socket_addr = local_address + pub async fn bind(local_address: &str) -> Result { + let socket_addr = local_address .parse::() - .unwrap_or_else(|_| panic!("{local_address} is not a valid socket address")); + .context(format!("{local_address} is not a valid socket address"))?; - let socket = UdpSocket::bind(valid_socket_addr).await.unwrap(); + let socket = UdpSocket::bind(socket_addr).await?; - Self { + let udp_client = Self { socket: Arc::new(socket), timeout: DEFAULT_TIMEOUT, - } + }; + Ok(udp_client) } - /// # Panics + /// # Errors /// - /// Will panic if can't connect to the socket. - pub async fn connect(&self, remote_address: &str) { - let valid_socket_addr = remote_address + /// Will return error if can't connect to the socket. + pub async fn connect(&self, remote_address: &str) -> Result<()> { + let socket_addr = remote_address .parse::() - .unwrap_or_else(|_| panic!("{remote_address} is not a valid socket address")); + .context(format!("{remote_address} is not a valid socket address"))?; - match self.socket.connect(valid_socket_addr).await { - Ok(()) => debug!("Connected successfully"), - Err(e) => panic!("Failed to connect: {e:?}"), + match self.socket.connect(socket_addr).await { + Ok(()) => { + debug!("Connected successfully"); + Ok(()) + } + Err(e) => Err(anyhow!("Failed to connect: {e:?}")), } } - /// # Panics + /// # Errors /// - /// Will panic if: + /// Will return error if: /// /// - Can't write to the socket. /// - Can't send data. - pub async fn send(&self, bytes: &[u8]) -> usize { + pub async fn send(&self, bytes: &[u8]) -> Result { debug!(target: "UDP client", "sending {bytes:?} ..."); match time::timeout(self.timeout, self.socket.writable()).await { - Ok(writable_result) => match writable_result { - Ok(()) => (), - Err(e) => panic!("{}", format!("IO error waiting for the socket to become readable: {e:?}")), - }, - Err(e) => panic!("{}", format!("Timeout waiting for the socket to become readable: {e:?}")), + Ok(writable_result) => { + match writable_result { + Ok(()) => (), + Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")), + }; + } + Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")), }; match time::timeout(self.timeout, self.socket.send(bytes)).await { Ok(send_result) => match send_result { - Ok(size) => size, - Err(e) => panic!("{}", format!("IO error during send: {e:?}")), + Ok(size) => Ok(size), + Err(e) => Err(anyhow!("IO error during send: {e:?}")), }, - Err(e) => panic!("{}", format!("Send operation timed out: {e:?}")), + Err(e) => Err(anyhow!("Send operation timed out: {e:?}")), } } - /// # Panics + /// # Errors /// - /// Will panic if: + /// Will return error if: /// /// - Can't read from the socket. /// - Can't receive data. - pub async fn receive(&self, bytes: &mut [u8]) -> usize { + /// + /// # Panics + /// + pub async fn receive(&self, bytes: &mut [u8]) -> Result { debug!(target: "UDP client", "receiving ..."); match time::timeout(self.timeout, self.socket.readable()).await { - Ok(readable_result) => match readable_result { - Ok(()) => (), - Err(e) => panic!("{}", format!("IO error waiting for the socket to become readable: {e:?}")), - }, - Err(e) => panic!("{}", format!("Timeout waiting for the socket to become readable: {e:?}")), + Ok(readable_result) => { + match readable_result { + Ok(()) => (), + Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")), + }; + } + Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")), }; - let size = match time::timeout(self.timeout, self.socket.recv(bytes)).await { + let size_result = match time::timeout(self.timeout, self.socket.recv(bytes)).await { Ok(recv_result) => match recv_result { - Ok(size) => size, - Err(e) => panic!("{}", format!("IO error during send: {e:?}")), + Ok(size) => Ok(size), + Err(e) => Err(anyhow!("IO error during send: {e:?}")), }, - Err(e) => panic!("{}", format!("Receive operation timed out: {e:?}")), + Err(e) => Err(anyhow!("Receive operation timed out: {e:?}")), }; - debug!(target: "UDP client", "{size} bytes received {bytes:?}"); - - size + if size_result.is_ok() { + let size = size_result.as_ref().unwrap(); + debug!(target: "UDP client", "{size} bytes received {bytes:?}"); + size_result + } else { + size_result + } } } /// Creates a new `UdpClient` connected to a Udp server -pub async fn new_udp_client_connected(remote_address: &str) -> UdpClient { +/// +/// # Errors +/// +/// Will return any errors present in the call stack +/// +pub async fn new_udp_client_connected(remote_address: &str) -> Result { let port = 0; // Let OS choose an unused port. - let client = UdpClient::bind(&source_address(port)).await; - client.connect(remote_address).await; - client + let client = UdpClient::bind(&source_address(port)).await?; + client.connect(remote_address).await?; + Ok(client) } #[allow(clippy::module_name_repetitions)] @@ -127,85 +150,103 @@ pub struct UdpTrackerClient { } impl UdpTrackerClient { - /// # Panics + /// # Errors /// - /// Will panic if can't write request to bytes. - pub async fn send(&self, request: Request) -> usize { + /// Will return error if can't write request to bytes. + pub async fn send(&self, request: Request) -> Result { debug!(target: "UDP tracker client", "send request {request:?}"); // Write request into a buffer let request_buffer = vec![0u8; MAX_PACKET_SIZE]; let mut cursor = Cursor::new(request_buffer); - let request_data = match request.write(&mut cursor) { + let request_data_result = match request.write(&mut cursor) { Ok(()) => { #[allow(clippy::cast_possible_truncation)] let position = cursor.position() as usize; let inner_request_buffer = cursor.get_ref(); // Return slice which contains written request data - &inner_request_buffer[..position] + Ok(&inner_request_buffer[..position]) } - Err(e) => panic!("could not write request to bytes: {e}."), + Err(e) => Err(anyhow!("could not write request to bytes: {e}.")), }; + let request_data = request_data_result?; + self.udp_client.send(request_data).await } - /// # Panics + /// # Errors /// - /// Will panic if can't create response from the received payload (bytes buffer). - pub async fn receive(&self) -> Response { + /// Will return error if can't create response from the received payload (bytes buffer). + pub async fn receive(&self) -> Result { let mut response_buffer = [0u8; MAX_PACKET_SIZE]; - let payload_size = self.udp_client.receive(&mut response_buffer).await; + let payload_size = self.udp_client.receive(&mut response_buffer).await?; debug!(target: "UDP tracker client", "received {payload_size} bytes. Response {response_buffer:?}"); - Response::from_bytes(&response_buffer[..payload_size], true).unwrap() + let response = Response::from_bytes(&response_buffer[..payload_size], true)?; + + Ok(response) } } /// Creates a new `UdpTrackerClient` connected to a Udp Tracker server -pub async fn new_udp_tracker_client_connected(remote_address: &str) -> UdpTrackerClient { - let udp_client = new_udp_client_connected(remote_address).await; - UdpTrackerClient { udp_client } +/// +/// # Errors +/// +/// Will return any errors present in the call stack +/// +pub async fn new_udp_tracker_client_connected(remote_address: &str) -> Result { + let udp_client = new_udp_client_connected(remote_address).await?; + let udp_tracker_client = UdpTrackerClient { udp_client }; + Ok(udp_tracker_client) } /// Helper Function to Check if a UDP Service is Connectable /// -/// # Errors +/// # Panics /// /// It will return an error if unable to connect to the UDP service. /// -/// # Panics +/// # Errors +/// pub async fn check(binding: &SocketAddr) -> Result { debug!("Checking Service (detail): {binding:?}."); - let client = new_udp_tracker_client_connected(binding.to_string().as_str()).await; - - let connect_request = ConnectRequest { - transaction_id: TransactionId(123), - }; - - client.send(connect_request.into()).await; - - let process = move |response| { - if matches!(response, Response::Connect(_connect_response)) { - Ok("Connected".to_string()) - } else { - Err("Did not Connect".to_string()) - } - }; - - let sleep = time::sleep(Duration::from_millis(2000)); - tokio::pin!(sleep); - - tokio::select! { - () = &mut sleep => { - Err("Timed Out".to_string()) - } - response = client.receive() => { - process(response) + match new_udp_tracker_client_connected(binding.to_string().as_str()).await { + Ok(client) => { + let connect_request = ConnectRequest { + transaction_id: TransactionId(123), + }; + + // client.send() return usize, but doesn't use here + match client.send(connect_request.into()).await { + Ok(_) => (), + Err(e) => debug!("Error: {e:?}."), + }; + + let process = move |response| { + if matches!(response, Response::Connect(_connect_response)) { + Ok("Connected".to_string()) + } else { + Err("Did not Connect".to_string()) + } + }; + + let sleep = time::sleep(Duration::from_millis(2000)); + tokio::pin!(sleep); + + tokio::select! { + () = &mut sleep => { + Err("Timed Out".to_string()) + } + response = client.receive() => { + process(response.unwrap()) + } + } } + Err(e) => Err(format!("{e:?}")), } } diff --git a/tests/servers/udp/contract.rs b/tests/servers/udp/contract.rs index 91dca4d4..56e400f8 100644 --- a/tests/servers/udp/contract.rs +++ b/tests/servers/udp/contract.rs @@ -24,9 +24,15 @@ fn empty_buffer() -> [u8; MAX_PACKET_SIZE] { async fn send_connection_request(transaction_id: TransactionId, client: &UdpTrackerClient) -> ConnectionId { let connect_request = ConnectRequest { transaction_id }; - client.send(connect_request.into()).await; + match client.send(connect_request.into()).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; - let response = client.receive().await; + let response = match client.receive().await { + Ok(response) => response, + Err(err) => panic!("{err}"), + }; match response { Response::Connect(connect_response) => connect_response.connection_id, @@ -38,12 +44,22 @@ async fn send_connection_request(transaction_id: TransactionId, client: &UdpTrac async fn should_return_a_bad_request_response_when_the_client_sends_an_empty_request() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = new_udp_client_connected(&env.bind_address().to_string()).await; + let client = match new_udp_client_connected(&env.bind_address().to_string()).await { + Ok(udp_client) => udp_client, + Err(err) => panic!("{err}"), + }; - client.send(&empty_udp_request()).await; + match client.send(&empty_udp_request()).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; let mut buffer = empty_buffer(); - client.receive(&mut buffer).await; + match client.receive(&mut buffer).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; + let response = Response::from_bytes(&buffer, true).unwrap(); assert!(is_error_response(&response, "bad request")); @@ -63,15 +79,24 @@ mod receiving_a_connection_request { async fn should_return_a_connect_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = new_udp_tracker_client_connected(&env.bind_address().to_string()).await; + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; let connect_request = ConnectRequest { transaction_id: TransactionId(123), }; - client.send(connect_request.into()).await; + match client.send(connect_request.into()).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; - let response = client.receive().await; + let response = match client.receive().await { + Ok(response) => response, + Err(err) => panic!("{err}"), + }; assert!(is_connect_response(&response, TransactionId(123))); @@ -97,7 +122,10 @@ mod receiving_an_announce_request { async fn should_return_an_announce_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = new_udp_tracker_client_connected(&env.bind_address().to_string()).await; + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; let connection_id = send_connection_request(TransactionId(123), &client).await; @@ -118,9 +146,15 @@ mod receiving_an_announce_request { port: Port(client.udp_client.socket.local_addr().unwrap().port()), }; - client.send(announce_request.into()).await; + match client.send(announce_request.into()).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; - let response = client.receive().await; + let response = match client.receive().await { + Ok(response) => response, + Err(err) => panic!("{err}"), + }; println!("test response {response:?}"); @@ -143,7 +177,10 @@ mod receiving_an_scrape_request { async fn should_return_a_scrape_response() { let env = Started::new(&configuration::ephemeral().into()).await; - let client = new_udp_tracker_client_connected(&env.bind_address().to_string()).await; + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; let connection_id = send_connection_request(TransactionId(123), &client).await; @@ -159,9 +196,15 @@ mod receiving_an_scrape_request { info_hashes, }; - client.send(scrape_request.into()).await; + match client.send(scrape_request.into()).await { + Ok(_) => (), + Err(err) => panic!("{err}"), + }; - let response = client.receive().await; + let response = match client.receive().await { + Ok(response) => response, + Err(err) => panic!("{err}"), + }; assert!(is_scrape_response(&response));