From 19b4dcbe07cab8a8440cf179132946987106d0a1 Mon Sep 17 00:00:00 2001 From: Jose Celano Date: Wed, 21 Dec 2022 15:40:16 +0000 Subject: [PATCH] fix(api): [#58] return 401 when the token is missing --- src/api/server.rs | 62 +++++++++++++++++++++++++++++++++++++++++++---- tests/api.rs | 26 ++++++++++++++++++++ 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/src/api/server.rs b/src/api/server.rs index 5967a8be..8208f431 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -4,8 +4,10 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use log::error; use serde::{Deserialize, Serialize}; -use warp::{filters, reply, serve, Filter}; +use warp::http::StatusCode; +use warp::{filters, reject, reply, serve, Filter, Rejection, Reply}; use super::resource::auth_key::AuthKey; use super::resource::peer; @@ -27,6 +29,11 @@ enum ActionStatus<'a> { Err { reason: std::borrow::Cow<'a, str> }, } +#[derive(Debug)] +struct Unauthorized; + +impl reject::Reject for Unauthorized {} + impl warp::reject::Reject for ActionStatus<'static> {} fn authenticate(tokens: HashMap) -> impl Filter + Clone { @@ -52,9 +59,7 @@ fn authenticate(tokens: HashMap) -> impl Filter Err(warp::reject::custom(ActionStatus::Err { - reason: "unauthorized".into(), - })), + None => Err(warp::reject::custom(Unauthorized)), } }) .untuple_one() @@ -317,7 +322,9 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc) -> impl w .or(reload_keys), ); - let server = api_routes.and(authenticate(tracker.config.http_api.access_tokens.clone())); + let server = api_routes + .and(authenticate(tracker.config.http_api.access_tokens.clone())) + .recover(handle_rejection); let (_addr, api_server) = serve(server).bind_with_graceful_shutdown(socket_addr, async move { tokio::signal::ctrl_c().await.expect("Failed to listen to shutdown signal."); @@ -325,3 +332,48 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc) -> impl w api_server } + +#[allow(clippy::unused_async)] +async fn handle_rejection(err: Rejection) -> Result { + if let Some(_e) = err.find::() { + Ok(reply::with_status("unauthorized", StatusCode::UNAUTHORIZED)) + } else if let Some(e) = err.find::() { + match e { + ActionStatus::Ok => Ok(reply::with_status("", StatusCode::OK)), + ActionStatus::Err { reason } => { + if reason == "token not valid" { + Ok(reply::with_status("token not valid", StatusCode::INTERNAL_SERVER_ERROR)) + } else if reason == "failed to remove torrent from whitelist" { + Ok(reply::with_status( + "failed to remove torrent from whitelist", + StatusCode::INTERNAL_SERVER_ERROR, + )) + } else if reason == "failed to whitelist torrent" { + Ok(reply::with_status( + "failed to whitelist torrent", + StatusCode::INTERNAL_SERVER_ERROR, + )) + } else if reason == "failed to generate key" { + Ok(reply::with_status( + "failed to generate key", + StatusCode::INTERNAL_SERVER_ERROR, + )) + } else if reason == "failed to delete key" { + Ok(reply::with_status("failed to delete key", StatusCode::INTERNAL_SERVER_ERROR)) + } else if reason == "failed to reload whitelist" { + Ok(reply::with_status( + "failed to reload whitelist", + StatusCode::INTERNAL_SERVER_ERROR, + )) + } else if reason == "failed to reload keys" { + Ok(reply::with_status("failed to reload keys", StatusCode::INTERNAL_SERVER_ERROR)) + } else { + Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR)) + } + } + } + } else { + error!("unhandled rejection: {err:?}"); + Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR)) + } +} diff --git a/tests/api.rs b/tests/api.rs index 706cd0b8..b8de057c 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -30,6 +30,27 @@ mod tracker_api { use crate::common::ephemeral_random_port; + #[tokio::test] + async fn should_return_an_unauthorized_response_when_the_token_is_missing() { + let api_server = ApiServer::new_running_instance().await; + + let url = format!("http://{}/api/torrents", api_server.connection_info.unwrap().bind_address); + let res = reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap(); + + assert_eq!(res.status(), 401); + } + + #[tokio::test] + async fn should_return_an_internal_error_server_when_the_token_is_invalid() { + let api_server = ApiServer::new_running_instance().await; + + let res = ApiClient::new(api_server.get_connection_info().unwrap()) + .request("api/torrents", "invalid token") + .await; + + assert_eq!(res.status(), 500); + } + #[tokio::test] async fn should_allow_generating_a_new_auth_key() { let api_server = ApiServer::new_running_instance().await; @@ -376,5 +397,10 @@ mod tracker_api { .await .unwrap() } + + pub async fn request(&self, path: &str, token: &str) -> Response { + let url = format!("http://{}/{}?token={}", &self.connection_info.bind_address, path, token); + reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap() + } } }