Skip to content

Commit

Permalink
fix(api): [torrust#143] fix new Axum API enpoint when URL params are …
Browse files Browse the repository at this point in the history
…invalid
  • Loading branch information
josecelano committed Jan 12, 2023
1 parent c502c1d commit 517ffde
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 48 deletions.
127 changes: 85 additions & 42 deletions src/apis/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::api::resource::auth_key::AuthKey;
use crate::api::resource::stats::Stats;
use crate::api::resource::torrent::{ListItem, Torrent};
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth::KeyId;
use crate::tracker::services::statistics::get_metrics;
use crate::tracker::services::torrent::{get_torrent_info, get_torrents, Pagination};
use crate::tracker::Tracker;
Expand All @@ -37,7 +38,7 @@ pub fn router(tracker: &Arc<Tracker>) -> Router {
)
.route(
"/api/whitelist/:info_hash",
delete(delete_torrent_from_whitelist_handler).with_state(tracker.clone()),
delete(remove_torrent_from_whitelist_handler).with_state(tracker.clone()),
)
// Whitelist command
.route(
Expand Down Expand Up @@ -68,6 +69,19 @@ pub enum ActionStatus<'a> {
Err { reason: std::borrow::Cow<'a, str> },
}

// Resource responses

fn response_auth_key(auth_key: &AuthKey) -> Response {
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json; charset=utf-8")],
serde_json::to_string(auth_key).unwrap(),
)
.into_response()
}

// OK response

fn response_ok() -> Response {
(
StatusCode::OK,
Expand All @@ -77,20 +91,29 @@ fn response_ok() -> Response {
.into_response()
}

fn response_err(reason: String) -> Response {
// Error responses

fn response_invalid_info_hash_param(info_hash: &str) -> Response {
response_bad_request(&format!(
"Invalid URL: invalid infohash param: string \"{}\", expected expected a 40 character long string",
info_hash
))
}

fn response_bad_request(body: &str) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
format!("Unhandled rejection: {:?}", ActionStatus::Err { reason: reason.into() }),
body.to_owned(),
)
.into_response()
}

fn response_auth_key(auth_key: &AuthKey) -> Response {
fn response_err(reason: String) -> Response {
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json; charset=utf-8")],
serde_json::to_string(auth_key).unwrap(),
StatusCode::INTERNAL_SERVER_ERROR,
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
format!("Unhandled rejection: {:?}", ActionStatus::Err { reason: reason.into() }),
)
.into_response()
}
Expand All @@ -99,15 +122,22 @@ pub async fn get_stats_handler(State(tracker): State<Arc<Tracker>>) -> Json<Stat
Json(Stats::from(get_metrics(tracker.clone()).await))
}

/// # Panics
///
/// Will panic if it can't parse the infohash in the request
pub async fn get_torrent_handler(State(tracker): State<Arc<Tracker>>, Path(info_hash): Path<String>) -> Response {
let optional_torrent_info = get_torrent_info(tracker.clone(), &InfoHash::from_str(&info_hash).unwrap()).await;
#[derive(Deserialize)]
pub struct InfoHashParam(String);

pub async fn get_torrent_handler(State(tracker): State<Arc<Tracker>>, Path(info_hash): Path<InfoHashParam>) -> Response {
let parsing_info_hash_result = InfoHash::from_str(&info_hash.0);

match optional_torrent_info {
Some(info) => Json(Torrent::from(info)).into_response(),
None => Json(json!("torrent not known")).into_response(),
match parsing_info_hash_result {
Err(_) => response_invalid_info_hash_param(&info_hash.0),
Ok(info_hash) => {
let optional_torrent_info = get_torrent_info(tracker.clone(), &info_hash).await;

match optional_torrent_info {
Some(info) => Json(Torrent::from(info)).into_response(),
None => Json(json!("torrent not known")).into_response(),
}
}
}
}

Expand All @@ -131,32 +161,33 @@ pub async fn get_torrents_handler(
))
}

/// # Panics
///
/// Will panic if it can't parse the infohash in the request
pub async fn add_torrent_to_whitelist_handler(State(tracker): State<Arc<Tracker>>, Path(info_hash): Path<String>) -> Response {
match tracker
.add_torrent_to_whitelist(&InfoHash::from_str(&info_hash).unwrap())
.await
{
Ok(..) => response_ok(),
Err(..) => response_err("failed to whitelist torrent".to_string()),
pub async fn add_torrent_to_whitelist_handler(
State(tracker): State<Arc<Tracker>>,
Path(info_hash): Path<InfoHashParam>,
) -> Response {
let parsing_info_hash_result = InfoHash::from_str(&info_hash.0);

match parsing_info_hash_result {
Err(_) => response_invalid_info_hash_param(&info_hash.0),
Ok(info_hash) => match tracker.add_torrent_to_whitelist(&info_hash).await {
Ok(..) => response_ok(),
Err(..) => response_err("failed to whitelist torrent".to_string()),
},
}
}

/// # Panics
///
/// Will panic if it can't parse the infohash in the request
pub async fn delete_torrent_from_whitelist_handler(
pub async fn remove_torrent_from_whitelist_handler(
State(tracker): State<Arc<Tracker>>,
Path(info_hash): Path<String>,
Path(info_hash): Path<InfoHashParam>,
) -> Response {
match tracker
.remove_torrent_from_whitelist(&InfoHash::from_str(&info_hash).unwrap())
.await
{
Ok(..) => response_ok(),
Err(..) => response_err("failed to remove torrent from whitelist".to_string()),
let parsing_info_hash_result = InfoHash::from_str(&info_hash.0);

match parsing_info_hash_result {
Err(_) => response_invalid_info_hash_param(&info_hash.0),
Ok(info_hash) => match tracker.remove_torrent_from_whitelist(&info_hash).await {
Ok(..) => response_ok(),
Err(..) => response_err("failed to remove torrent from whitelist".to_string()),
},
}
}

Expand All @@ -168,16 +199,28 @@ pub async fn reload_whitelist_handler(State(tracker): State<Arc<Tracker>>) -> Re
}

pub async fn generate_auth_key_handler(State(tracker): State<Arc<Tracker>>, Path(seconds_valid_or_key): Path<u64>) -> Response {
match tracker.generate_auth_key(Duration::from_secs(seconds_valid_or_key)).await {
let seconds_valid = seconds_valid_or_key;
match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await {
Ok(auth_key) => response_auth_key(&AuthKey::from(auth_key)),
Err(_) => response_err("failed to generate key".to_string()),
}
}

pub async fn delete_auth_key_handler(State(tracker): State<Arc<Tracker>>, Path(seconds_valid_or_key): Path<String>) -> Response {
match tracker.remove_auth_key(&seconds_valid_or_key).await {
Ok(_) => response_ok(),
Err(_) => response_err("failed to delete key".to_string()),
#[derive(Deserialize)]
pub struct KeyIdParam(String);

pub async fn delete_auth_key_handler(
State(tracker): State<Arc<Tracker>>,
Path(seconds_valid_or_key): Path<KeyIdParam>,
) -> Response {
let key_id = KeyId::from_str(&seconds_valid_or_key.0);

match key_id {
Err(_) => response_bad_request(&format!("Invalid auth key id param \"{}\"", seconds_valid_or_key.0)),
Ok(key_id) => match tracker.remove_auth_key(&key_id.to_string()).await {
Ok(_) => response_ok(),
Err(_) => response_err("failed to delete key".to_string()),
},
}
}

Expand Down
31 changes: 31 additions & 0 deletions src/tracker/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use std::time::Duration;

use derive_more::{Display, Error};
Expand Down Expand Up @@ -50,6 +51,8 @@ pub fn verify(auth_key: &Key) -> Result<(), Error> {

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct Key {
// todo: replace key field definition with:
// pub key: KeyId,
pub key: String,
pub valid_until: Option<DurationSinceUnixEpoch>,
}
Expand Down Expand Up @@ -77,6 +80,24 @@ impl Key {
}
}

#[derive(Debug, Display, PartialEq, Clone)]
pub struct KeyId(String);

#[derive(Debug, PartialEq, Eq)]
pub struct ParseKeyIdError;

impl FromStr for KeyId {
type Err = ParseKeyIdError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() != AUTH_KEY_LENGTH {
return Err(ParseKeyIdError);
}

Ok(Self(s.to_string()))
}
}

#[derive(Debug, Display, PartialEq, Eq, Error)]
#[allow(dead_code)]
pub enum Error {
Expand All @@ -97,6 +118,7 @@ impl From<r2d2_sqlite::rusqlite::Error> for Error {

#[cfg(test)]
mod tests {
use std::str::FromStr;
use std::time::Duration;

use crate::protocol::clock::{Current, StoppedTime};
Expand All @@ -122,6 +144,15 @@ mod tests {
assert_eq!(auth_key.unwrap().key, key_string);
}

#[test]
fn auth_key_id_from_string() {
let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ";
let auth_key_id = auth::KeyId::from_str(key_string);

assert!(auth_key_id.is_ok());
assert_eq!(auth_key_id.unwrap().to_string(), key_string);
}

#[test]
fn generate_valid_auth_key() {
let auth_key = auth::generate(Duration::new(9999, 0));
Expand Down
1 change: 1 addition & 0 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl Tracker {
///
/// Will return a `key::Error` if unable to get any `auth_key`.
pub async fn verify_auth_key(&self, auth_key: &auth::Key) -> Result<(), auth::Error> {
// todo: use auth::KeyId for the function argument `auth_key`
match self.keys.read().await.get(&auth_key.key) {
None => Err(auth::Error::KeyInvalid),
Some(key) => auth::verify(key),
Expand Down
12 changes: 12 additions & 0 deletions tests/api/asserts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ pub async fn assert_ok(response: Response) {

// Error responses

pub async fn assert_bad_request(response: Response, body: &str) {
assert_eq!(response.status(), 400);
assert_eq!(response.headers().get("content-type").unwrap(), "text/plain; charset=utf-8");
assert_eq!(response.text().await.unwrap(), body);
}

pub async fn assert_method_not_allowed(response: Response) {
assert_eq!(response.status(), 405);
assert_eq!(response.headers().get("content-type").unwrap(), "text/plain; charset=utf-8");
assert_eq!(response.text().await.unwrap(), "HTTP method not allowed");
}

pub async fn assert_torrent_not_known(response: Response) {
assert_eq!(response.status(), 200);
assert_eq!(response.headers().get("content-type").unwrap(), "application/json");
Expand Down
Loading

0 comments on commit 517ffde

Please sign in to comment.