Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix API endpoint: Return 401 when the token is missing #133

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions src/api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use log::error;
use serde::{Deserialize, Serialize};
use warp::{filters, reply, serve, Filter};
use warp::http::StatusCode;
use warp::{filters, reject, reply, serve, Filter, Rejection, Reply};

use super::resource::auth_key::AuthKey;
use super::resource::peer;
Expand All @@ -27,6 +29,11 @@ enum ActionStatus<'a> {
Err { reason: std::borrow::Cow<'a, str> },
}

#[derive(Debug)]
struct Unauthorized;

impl reject::Reject for Unauthorized {}

impl warp::reject::Reject for ActionStatus<'static> {}

fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Error = warp::reject::Rejection> + Clone {
Expand All @@ -52,9 +59,7 @@ fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Er

Ok(())
}
None => Err(warp::reject::custom(ActionStatus::Err {
reason: "unauthorized".into(),
})),
None => Err(warp::reject::custom(Unauthorized)),
}
})
.untuple_one()
Expand Down Expand Up @@ -317,11 +322,58 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc<tracker::Tracker>) -> impl w
.or(reload_keys),
);

let server = api_routes.and(authenticate(tracker.config.http_api.access_tokens.clone()));
let server = api_routes
.and(authenticate(tracker.config.http_api.access_tokens.clone()))
.recover(handle_rejection);

let (_addr, api_server) = serve(server).bind_with_graceful_shutdown(socket_addr, async move {
tokio::signal::ctrl_c().await.expect("Failed to listen to shutdown signal.");
});

api_server
}

#[allow(clippy::unused_async)]
async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
if let Some(_e) = err.find::<Unauthorized>() {
Ok(reply::with_status("unauthorized", StatusCode::UNAUTHORIZED))
} else if let Some(e) = err.find::<ActionStatus>() {
match e {
ActionStatus::Ok => Ok(reply::with_status("", StatusCode::OK)),
ActionStatus::Err { reason } => {
if reason == "token not valid" {
Ok(reply::with_status("token not valid", StatusCode::INTERNAL_SERVER_ERROR))
} else if reason == "failed to remove torrent from whitelist" {
Ok(reply::with_status(
"failed to remove torrent from whitelist",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to whitelist torrent" {
Ok(reply::with_status(
"failed to whitelist torrent",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to generate key" {
Ok(reply::with_status(
"failed to generate key",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to delete key" {
Ok(reply::with_status("failed to delete key", StatusCode::INTERNAL_SERVER_ERROR))
} else if reason == "failed to reload whitelist" {
Ok(reply::with_status(
"failed to reload whitelist",
StatusCode::INTERNAL_SERVER_ERROR,
))
} else if reason == "failed to reload keys" {
Ok(reply::with_status("failed to reload keys", StatusCode::INTERNAL_SERVER_ERROR))
} else {
Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR))
}
}
}
} else {
error!("unhandled rejection: {err:?}");
Ok(reply::with_status("internal server error", StatusCode::INTERNAL_SERVER_ERROR))
}
}
26 changes: 26 additions & 0 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ mod tracker_api {

use crate::common::ephemeral_random_port;

#[tokio::test]
async fn should_return_an_unauthorized_response_when_the_token_is_missing() {
let api_server = ApiServer::new_running_instance().await;

let url = format!("http://{}/api/torrents", api_server.connection_info.unwrap().bind_address);
let res = reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap();

assert_eq!(res.status(), 401);
}

#[tokio::test]
async fn should_return_an_internal_error_server_when_the_token_is_invalid() {
let api_server = ApiServer::new_running_instance().await;

let res = ApiClient::new(api_server.get_connection_info().unwrap())
.request("api/torrents", "invalid token")
.await;

assert_eq!(res.status(), 500);
}

#[tokio::test]
async fn should_allow_generating_a_new_auth_key() {
let api_server = ApiServer::new_running_instance().await;
Expand Down Expand Up @@ -376,5 +397,10 @@ mod tracker_api {
.await
.unwrap()
}

pub async fn request(&self, path: &str, token: &str) -> Response {
let url = format!("http://{}/{}?token={}", &self.connection_info.bind_address, path, token);
reqwest::Client::builder().build().unwrap().get(url).send().await.unwrap()
}
}
}