Skip to content

Commit

Permalink
refactor: rename vars and extract constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
josecelano committed Jun 25, 2024
1 parent 7ff0cd2 commit 16ae4fd
Showing 1 changed file with 36 additions and 33 deletions.
69 changes: 36 additions & 33 deletions src/servers/udp/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use derive_more::Constructor;
use futures::{Stream, StreamExt};
use ringbuf::traits::{Consumer, Observer, Producer};
use ringbuf::StaticRb;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::oneshot;
use tokio::task::{AbortHandle, JoinHandle};
Expand Down Expand Up @@ -255,6 +254,10 @@ impl BoundSocket {
socket: Arc::new(socket),
})
}

fn local_addr(&self) -> SocketAddr {
self.socket.local_addr().expect("it should get local address")
}
}

impl Deref for BoundSocket {
Expand All @@ -277,19 +280,29 @@ impl Debug for BoundSocket {
}

struct Receiver {
socket: Arc<UdpSocket>,
bound_socket: Arc<BoundSocket>,
tracker: Arc<Tracker>,
data: RefCell<[u8; MAX_PACKET_SIZE]>,
}

impl Receiver {
pub fn new(bound_socket: Arc<BoundSocket>, tracker: Arc<Tracker>) -> Self {
Receiver {
bound_socket,
tracker,
data: RefCell::new([0; MAX_PACKET_SIZE]),
}
}
}

impl Stream for Receiver {
type Item = std::io::Result<AbortHandle>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut buf = *self.data.borrow_mut();
let mut buf = tokio::io::ReadBuf::new(&mut buf);

let Poll::Ready(ready) = self.socket.poll_recv_from(cx, &mut buf) else {
let Poll::Ready(ready) = self.bound_socket.poll_recv_from(cx, &mut buf) else {
return Poll::Pending;
};

Expand All @@ -301,7 +314,7 @@ impl Stream for Receiver {
Some(Ok(tokio::task::spawn(Udp::process_request(
request,
self.tracker.clone(),
self.socket.clone(),
self.bound_socket.clone(),
))
.abort_handle()))
}
Expand Down Expand Up @@ -338,34 +351,29 @@ impl Udp {
.await
.expect("it should bind to the socket within five seconds");

let socket = match socket {
let bound_socket = match socket {
Ok(socket) => socket,
Err(e) => {
tracing::error!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", addr = %bind_to, err = %e, "panic! (error when building socket)" );
panic!("could not bind to socket!");
}
};

let address = socket.local_addr().expect("it should get the locally bound address");
let local_addr = format!("udp://{address}");
let address = bound_socket.local_addr();
let local_udp_url = format!("udp://{address}");

// note: this log message is parsed by our container. i.e:
//
// `[UDP TRACKER][INFO] Starting on: udp://`
//
tracing::info!(target: "UDP TRACKER", "Starting on: {local_addr}");
tracing::info!(target: "UDP TRACKER", "Starting on: {local_udp_url}");

let socket = socket.socket;
let receiver = Receiver::new(bound_socket.into(), tracker);

let receiver = Receiver {
socket,
tracker,
data: RefCell::new([0; MAX_PACKET_SIZE]),
};
tracing::trace!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_udp_url, "(spawning main loop)");

tracing::trace!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(spawning main loop)");
let running = {
let local_addr = local_addr.clone();
let local_addr = local_udp_url.clone();
tokio::task::spawn(async move {
tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown::task", local_addr, "(listening...)");
let () = Self::run_udp_server_main(receiver).await;
Expand All @@ -376,29 +384,29 @@ impl Udp {
.send(Started { address })
.expect("the UDP Tracker service should not be dropped");

tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(started)");
tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_udp_url, "(started)");

let stop = running.abort_handle();

select! {
_ = running => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(stopped)"); },
_ = halt_task => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown",local_addr, "(halting)"); }
_ = running => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_udp_url, "(stopped)"); },
_ = halt_task => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown",local_udp_url, "(halting)"); }
}
stop.abort();

tokio::task::yield_now().await; // lets allow the other threads to complete.
}

async fn run_udp_server_main(mut direct: Receiver) {
async fn run_udp_server_main(mut receiver: Receiver) {
let reqs = &mut ActiveRequests::default();

let addr = direct.socket.local_addr().expect("it should get local address");
let addr = receiver.bound_socket.local_addr();
let local_addr = format!("udp://{addr}");

loop {
if let Some(req) = {
tracing::trace!(target: "UDP TRACKER: Udp::run_udp_server", local_addr, "(wait for request)");
direct.next().await
receiver.next().await
} {
tracing::trace!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, "(in)");

Expand Down Expand Up @@ -474,24 +482,19 @@ impl Udp {
}
}

async fn process_request(request: UdpRequest, tracker: Arc<Tracker>, socket: Arc<UdpSocket>) {
async fn process_request(request: UdpRequest, tracker: Arc<Tracker>, socket: Arc<BoundSocket>) {
tracing::trace!(target: "UDP TRACKER: Udp::process_request", request = %request.from, "(receiving)");
Self::process_valid_request(tracker, socket, request).await;
}

async fn process_valid_request(tracker: Arc<Tracker>, socket: Arc<UdpSocket>, udp_request: UdpRequest) {
async fn process_valid_request(tracker: Arc<Tracker>, socket: Arc<BoundSocket>, udp_request: UdpRequest) {
tracing::trace!(target: "UDP TRACKER: Udp::process_valid_request", "Making Response to {udp_request:?}");
let from = udp_request.from;
let response = handlers::handle_packet(
udp_request,
&tracker.clone(),
socket.local_addr().expect("it should get the local address"),
)
.await;
let response = handlers::handle_packet(udp_request, &tracker.clone(), socket.local_addr()).await;
Self::send_response(&socket.clone(), from, response).await;
}

async fn send_response(socket: &Arc<UdpSocket>, to: SocketAddr, response: Response) {
async fn send_response(bound_socket: &Arc<BoundSocket>, to: SocketAddr, response: Response) {
let response_type = match &response {
Response::Connect(_) => "Connect".to_string(),
Response::AnnounceIpv4(_) => "AnnounceIpv4".to_string(),
Expand All @@ -514,7 +517,7 @@ impl Udp {
tracing::debug!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), "(sending...)" );
tracing::trace!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), payload = ?&inner[..position], "(sending...)");

Self::send_packet(socket, &to, &inner[..position]).await;
Self::send_packet(bound_socket, &to, &inner[..position]).await;

tracing::trace!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), "(sent)");
}
Expand All @@ -524,7 +527,7 @@ impl Udp {
}
}

async fn send_packet(socket: &Arc<UdpSocket>, remote_addr: &SocketAddr, payload: &[u8]) {
async fn send_packet(socket: &Arc<BoundSocket>, remote_addr: &SocketAddr, payload: &[u8]) {
tracing::trace!(target: "UDP TRACKER: Udp::send_response", to = %remote_addr, ?payload, "(sending)");

// doesn't matter if it reaches or not
Expand Down

0 comments on commit 16ae4fd

Please sign in to comment.