Skip to content

Commit

Permalink
fix(api): [#58] return 401 when the token is missing
Browse files Browse the repository at this point in the history
  • Loading branch information
josecelano committed Dec 21, 2022
1 parent 8982159 commit 19b4dcb
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 5 deletions.
62 changes: 57 additions & 5 deletions src/api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use log::error;
use serde::{Deserialize, Serialize};
use warp::{filters, reply, serve, Filter};
use warp::http::StatusCode;
use warp::{filters, reject, reply, serve, Filter, Rejection, Reply};

use super::resource::auth_key::AuthKey;
use super::resource::peer;
Expand All @@ -27,6 +29,11 @@ enum ActionStatus<'a> {
Err { reason: std::borrow::Cow<'a, str> },
}

#[derive(Debug)]
struct Unauthorized;

impl reject::Reject for Unauthorized {}

impl warp::reject::Reject for ActionStatus<'static> {}

fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Error = warp::reject::Rejection> + Clone {
Expand All @@ -52,9 +59,7 @@ fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Er

Ok(())
}
None => Err(warp::reject::custom(ActionStatus::Err {
reason: "unauthorized".into(),
})),
None => Err(warp::reject::custom(Unauthorized)),
}
})
.untuple_one()
Expand Down Expand Up @@ -317,11 +322,58 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc<tracker::Tracker>) -> impl w
.or(reload_keys),
);

let server = api_routes.and(authenticate(tracker.config.http_api.access_tokens.clone()));
let server = api_routes
.and(authenticate(tracker.config.http_api.access_tokens.clone()))
.recover(handle_rejection);

let (_addr, api_server) = serve(server).bind_with_graceful_shutdown(socket_addr, async move {
tokio::signal::ctrl_c().await.expect("Failed to listen to shutdown signal.");
});

api_server
}

#[allow(clippy::unused_async)]
async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
if let Some(_e) = err.find::<Unauthorized>() {
Ok(reply::with_status("unauthorized", StatusCode::UNAUTHORIZED))
} else if let Some(e) = err.find::<ActionStatus>() {
match e {
ActionStatus::Ok => Ok(reply::with_status("", StatusCode::OK)),
ActionStatus::Err { reason } => {
if reason == "token not valid" {
Ok(reply::with_status("token not valid", StatusCode::INTERNAL_SERVER_ERROR))
} else if reason == "failed to remove torrent from whitelist" {
Ok(reply::with_status(
"failed to remove torrent from whitelist",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to whitelist torrent" {
Ok(reply::with_status(
"failed to whitelist torrent",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to generate key" {
Ok(reply::with_status(
"failed to generate key",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to delete key" {
Ok(reply::with_status("failed to delete key", StatusCode::INTERNAL_SERVER_ERROR))
} else if reason == "failed to reload whitelist" {
Ok(reply::with_status(
"failed to reload whitelist",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to reload keys" {
Ok(reply::with_status("failed to reload keys", StatusCode::INTERNAL_SERVER_ERROR))
} else {
Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR))
}
}
}
} else {
error!("unhandled rejection: {err:?}");
Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR))
}
}
26 changes: 26 additions & 0 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ mod tracker_api {

use crate::common::ephemeral_random_port;

#[tokio::test]
async fn should_return_an_unauthorized_response_when_the_token_is_missing() {
let api_server = ApiServer::new_running_instance().await;

let url = format!("http://{}/api/torrents", api_server.connection_info.unwrap().bind_address);
let res = reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap();

assert_eq!(res.status(), 401);
}

#[tokio::test]
async fn should_return_an_internal_error_server_when_the_token_is_invalid() {
let api_server = ApiServer::new_running_instance().await;

let res = ApiClient::new(api_server.get_connection_info().unwrap())
.request("api/torrents", "invalid token")
.await;

assert_eq!(res.status(), 500);
}

#[tokio::test]
async fn should_allow_generating_a_new_auth_key() {
let api_server = ApiServer::new_running_instance().await;
Expand Down Expand Up @@ -376,5 +397,10 @@ mod tracker_api {
.await
.unwrap()
}

pub async fn request(&self, path: &str, token: &str) -> Response {
let url = format!("http://{}/{}?token={}", &self.connection_info.bind_address, path, token);
reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap()
}
}
}

0 comments on commit 19b4dcb

Please sign in to comment.