Skip to content

Commit

Permalink
Merge #814: Refactor UdpClient to return errors instead of panicking
Browse files Browse the repository at this point in the history
effca56 refactor: [#681] udp return errors instead of panicking (ngthhu)

Pull request description:

  I replaced panics with error handling using `anyhow::Result<T>`, which is equivalent to `std::result::Result<T, anyhow::Error>`. Although I can run `cargo run` without errors, I'm not certain if there are any hidden bugs.

ACKs for top commit:
  josecelano:
    ACK effca56

Tree-SHA512: 30653458d1f2b3155dbdddd0f5197cee6746819122db45d58c0aaa4a0b9cae0fccee76b769afb22b57749c2140d042580125ad2689fafa0b65ec6847866f659b
  • Loading branch information
josecelano committed May 2, 2024
2 parents 92349d3 + effca56 commit 90c7780
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 109 deletions.
16 changes: 8 additions & 8 deletions src/console/clients/udp/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl Client {
let binding_address = local_bind_to.parse().context("binding local address")?;

debug!("Binding to: {local_bind_to}");
let udp_client = UdpClient::bind(&local_bind_to).await;
let udp_client = UdpClient::bind(&local_bind_to).await?;

let bound_to = udp_client.socket.local_addr().context("bound local address")?;
debug!("Bound to: {bound_to}");
Expand All @@ -88,7 +88,7 @@ impl Client {

match &self.udp_tracker_client {
Some(client) => {
client.udp_client.connect(&tracker_socket_addr.to_string()).await;
client.udp_client.connect(&tracker_socket_addr.to_string()).await?;
self.remote_socket = Some(*tracker_socket_addr);
Ok(())
}
Expand Down Expand Up @@ -116,9 +116,9 @@ impl Client {

match &self.udp_tracker_client {
Some(client) => {
client.send(connect_request.into()).await;
client.send(connect_request.into()).await?;

let response = client.receive().await;
let response = client.receive().await?;

debug!("connection request response:\n{response:#?}");

Expand Down Expand Up @@ -163,9 +163,9 @@ impl Client {

match &self.udp_tracker_client {
Some(client) => {
client.send(announce_request.into()).await;
client.send(announce_request.into()).await?;

let response = client.receive().await;
let response = client.receive().await?;

debug!("announce request response:\n{response:#?}");

Expand Down Expand Up @@ -200,9 +200,9 @@ impl Client {

match &self.udp_tracker_client {
Some(client) => {
client.send(scrape_request.into()).await;
client.send(scrape_request.into()).await?;

let response = client.receive().await;
let response = client.receive().await?;

debug!("scrape request response:\n{response:#?}");

Expand Down
215 changes: 128 additions & 87 deletions src/shared/bit_torrent/tracker/udp/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::result::Result::{Err, Ok};
use std::io::Cursor;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, Context, Result};
use aquatic_udp_protocol::{ConnectRequest, Request, Response, TransactionId};
use log::debug;
use tokio::net::UdpSocket;
Expand All @@ -25,99 +27,120 @@ pub struct UdpClient {
}

impl UdpClient {
/// # Panics
/// # Errors
///
/// Will return error if the local address can't be bound.
///
/// Will panic if the local address can't be bound.
pub async fn bind(local_address: &str) -> Self {
let valid_socket_addr = local_address
pub async fn bind(local_address: &str) -> Result<Self> {
let socket_addr = local_address
.parse::<SocketAddr>()
.unwrap_or_else(|_| panic!("{local_address} is not a valid socket address"));
.context(format!("{local_address} is not a valid socket address"))?;

let socket = UdpSocket::bind(valid_socket_addr).await.unwrap();
let socket = UdpSocket::bind(socket_addr).await?;

Self {
let udp_client = Self {
socket: Arc::new(socket),
timeout: DEFAULT_TIMEOUT,
}
};
Ok(udp_client)
}

/// # Panics
/// # Errors
///
/// Will panic if can't connect to the socket.
pub async fn connect(&self, remote_address: &str) {
let valid_socket_addr = remote_address
/// Will return error if can't connect to the socket.
pub async fn connect(&self, remote_address: &str) -> Result<()> {
let socket_addr = remote_address
.parse::<SocketAddr>()
.unwrap_or_else(|_| panic!("{remote_address} is not a valid socket address"));
.context(format!("{remote_address} is not a valid socket address"))?;

match self.socket.connect(valid_socket_addr).await {
Ok(()) => debug!("Connected successfully"),
Err(e) => panic!("Failed to connect: {e:?}"),
match self.socket.connect(socket_addr).await {
Ok(()) => {
debug!("Connected successfully");
Ok(())
}
Err(e) => Err(anyhow!("Failed to connect: {e:?}")),
}
}

/// # Panics
/// # Errors
///
/// Will panic if:
/// Will return error if:
///
/// - Can't write to the socket.
/// - Can't send data.
pub async fn send(&self, bytes: &[u8]) -> usize {
pub async fn send(&self, bytes: &[u8]) -> Result<usize> {
debug!(target: "UDP client", "sending {bytes:?} ...");

match time::timeout(self.timeout, self.socket.writable()).await {
Ok(writable_result) => match writable_result {
Ok(()) => (),
Err(e) => panic!("{}", format!("IO error waiting for the socket to become readable: {e:?}")),
},
Err(e) => panic!("{}", format!("Timeout waiting for the socket to become readable: {e:?}")),
Ok(writable_result) => {
match writable_result {
Ok(()) => (),
Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")),
};
}
Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")),
};

match time::timeout(self.timeout, self.socket.send(bytes)).await {
Ok(send_result) => match send_result {
Ok(size) => size,
Err(e) => panic!("{}", format!("IO error during send: {e:?}")),
Ok(size) => Ok(size),
Err(e) => Err(anyhow!("IO error during send: {e:?}")),
},
Err(e) => panic!("{}", format!("Send operation timed out: {e:?}")),
Err(e) => Err(anyhow!("Send operation timed out: {e:?}")),
}
}

/// # Panics
/// # Errors
///
/// Will panic if:
/// Will return error if:
///
/// - Can't read from the socket.
/// - Can't receive data.
pub async fn receive(&self, bytes: &mut [u8]) -> usize {
///
/// # Panics
///
pub async fn receive(&self, bytes: &mut [u8]) -> Result<usize> {
debug!(target: "UDP client", "receiving ...");

match time::timeout(self.timeout, self.socket.readable()).await {
Ok(readable_result) => match readable_result {
Ok(()) => (),
Err(e) => panic!("{}", format!("IO error waiting for the socket to become readable: {e:?}")),
},
Err(e) => panic!("{}", format!("Timeout waiting for the socket to become readable: {e:?}")),
Ok(readable_result) => {
match readable_result {
Ok(()) => (),
Err(e) => return Err(anyhow!("IO error waiting for the socket to become readable: {e:?}")),
};
}
Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")),
};

let size = match time::timeout(self.timeout, self.socket.recv(bytes)).await {
let size_result = match time::timeout(self.timeout, self.socket.recv(bytes)).await {
Ok(recv_result) => match recv_result {
Ok(size) => size,
Err(e) => panic!("{}", format!("IO error during send: {e:?}")),
Ok(size) => Ok(size),
Err(e) => Err(anyhow!("IO error during send: {e:?}")),
},
Err(e) => panic!("{}", format!("Receive operation timed out: {e:?}")),
Err(e) => Err(anyhow!("Receive operation timed out: {e:?}")),
};

debug!(target: "UDP client", "{size} bytes received {bytes:?}");

size
if size_result.is_ok() {
let size = size_result.as_ref().unwrap();
debug!(target: "UDP client", "{size} bytes received {bytes:?}");
size_result
} else {
size_result
}
}
}

/// Creates a new `UdpClient` connected to a Udp server
pub async fn new_udp_client_connected(remote_address: &str) -> UdpClient {
///
/// # Errors
///
/// Will return any errors present in the call stack
///
pub async fn new_udp_client_connected(remote_address: &str) -> Result<UdpClient> {
let port = 0; // Let OS choose an unused port.
let client = UdpClient::bind(&source_address(port)).await;
client.connect(remote_address).await;
client
let client = UdpClient::bind(&source_address(port)).await?;
client.connect(remote_address).await?;
Ok(client)
}

#[allow(clippy::module_name_repetitions)]
Expand All @@ -127,85 +150,103 @@ pub struct UdpTrackerClient {
}

impl UdpTrackerClient {
/// # Panics
/// # Errors
///
/// Will panic if can't write request to bytes.
pub async fn send(&self, request: Request) -> usize {
/// Will return error if can't write request to bytes.
pub async fn send(&self, request: Request) -> Result<usize> {
debug!(target: "UDP tracker client", "send request {request:?}");

// Write request into a buffer
let request_buffer = vec![0u8; MAX_PACKET_SIZE];
let mut cursor = Cursor::new(request_buffer);

let request_data = match request.write(&mut cursor) {
let request_data_result = match request.write(&mut cursor) {
Ok(()) => {
#[allow(clippy::cast_possible_truncation)]
let position = cursor.position() as usize;
let inner_request_buffer = cursor.get_ref();
// Return slice which contains written request data
&inner_request_buffer[..position]
Ok(&inner_request_buffer[..position])
}
Err(e) => panic!("could not write request to bytes: {e}."),
Err(e) => Err(anyhow!("could not write request to bytes: {e}.")),
};

let request_data = request_data_result?;

self.udp_client.send(request_data).await
}

/// # Panics
/// # Errors
///
/// Will panic if can't create response from the received payload (bytes buffer).
pub async fn receive(&self) -> Response {
/// Will return error if can't create response from the received payload (bytes buffer).
pub async fn receive(&self) -> Result<Response> {
let mut response_buffer = [0u8; MAX_PACKET_SIZE];

let payload_size = self.udp_client.receive(&mut response_buffer).await;
let payload_size = self.udp_client.receive(&mut response_buffer).await?;

debug!(target: "UDP tracker client", "received {payload_size} bytes. Response {response_buffer:?}");

Response::from_bytes(&response_buffer[..payload_size], true).unwrap()
let response = Response::from_bytes(&response_buffer[..payload_size], true)?;

Ok(response)
}
}

/// Creates a new `UdpTrackerClient` connected to a Udp Tracker server
pub async fn new_udp_tracker_client_connected(remote_address: &str) -> UdpTrackerClient {
let udp_client = new_udp_client_connected(remote_address).await;
UdpTrackerClient { udp_client }
///
/// # Errors
///
/// Will return any errors present in the call stack
///
pub async fn new_udp_tracker_client_connected(remote_address: &str) -> Result<UdpTrackerClient> {
let udp_client = new_udp_client_connected(remote_address).await?;
let udp_tracker_client = UdpTrackerClient { udp_client };
Ok(udp_tracker_client)
}

/// Helper Function to Check if a UDP Service is Connectable
///
/// # Errors
/// # Panics
///
/// It will return an error if unable to connect to the UDP service.
///
/// # Panics
/// # Errors
///
pub async fn check(binding: &SocketAddr) -> Result<String, String> {
debug!("Checking Service (detail): {binding:?}.");

let client = new_udp_tracker_client_connected(binding.to_string().as_str()).await;

let connect_request = ConnectRequest {
transaction_id: TransactionId(123),
};

client.send(connect_request.into()).await;

let process = move |response| {
if matches!(response, Response::Connect(_connect_response)) {
Ok("Connected".to_string())
} else {
Err("Did not Connect".to_string())
}
};

let sleep = time::sleep(Duration::from_millis(2000));
tokio::pin!(sleep);

tokio::select! {
() = &mut sleep => {
Err("Timed Out".to_string())
}
response = client.receive() => {
process(response)
match new_udp_tracker_client_connected(binding.to_string().as_str()).await {
Ok(client) => {
let connect_request = ConnectRequest {
transaction_id: TransactionId(123),
};

// client.send() return usize, but doesn't use here
match client.send(connect_request.into()).await {
Ok(_) => (),
Err(e) => debug!("Error: {e:?}."),
};

let process = move |response| {
if matches!(response, Response::Connect(_connect_response)) {
Ok("Connected".to_string())
} else {
Err("Did not Connect".to_string())
}
};

let sleep = time::sleep(Duration::from_millis(2000));
tokio::pin!(sleep);

tokio::select! {
() = &mut sleep => {
Err("Timed Out".to_string())
}
response = client.receive() => {
process(response.unwrap())
}
}
}
Err(e) => Err(format!("{e:?}")),
}
}
Loading

0 comments on commit 90c7780

Please sign in to comment.