diff --git a/Cargo.toml b/Cargo.toml index 2f0a18c..9d00046 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ chrono = { version ="0.4.23", features = ["serde", "rustc-serialize"] } reqwest = { version = "0.11.13", features = ["json"] } serde = { version = "1.0.151", features=["derive"] } tokio = { version = "1.23.0", features = ["full"] } +serde_with = "2.1.0" diff --git a/src/lib.rs b/src/lib.rs index cf60010..965cd7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,7 @@ impl DuneClient { #[cfg(test)] mod tests { use super::*; + use crate::response::ExecutionStatus; use dotenv::dotenv; use serde::Deserialize; use std::env; @@ -183,7 +184,7 @@ mod tests { async fn get_status() { let dune = get_dune(); let status = dune.get_status(JOB_ID).await.unwrap(); - assert_eq!(status.state, "QUERY_STATE_COMPLETED") + assert_eq!(status.state, ExecutionStatus::Complete) } #[tokio::test] diff --git a/src/response.rs b/src/response.rs index 4ecd7a0..cfd5326 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,21 +1,59 @@ use chrono::{DateTime, NaiveDateTime, Utc}; use serde::{de, Deserialize, Deserializer}; +use serde_with::DeserializeFromStr; +use std::str::FromStr; #[derive(Deserialize, Debug)] pub struct DuneError { - pub(crate) error: String, + pub error: String, } #[derive(Deserialize, Debug)] pub struct ExecutionResponse { - pub(crate) execution_id: String, + pub execution_id: String, // TODO use ExecutionState Enum pub state: String, } +#[derive(DeserializeFromStr, Debug, PartialEq)] +pub enum ExecutionStatus { + Complete, + Executing, + Pending, + Cancelled, + Failed, +} + +impl FromStr for ExecutionStatus { + type Err = String; + + fn from_str(input: &str) -> Result { + match input { + "QUERY_STATE_COMPLETED" => Ok(ExecutionStatus::Complete), + "QUERY_STATE_EXECUTING" => Ok(ExecutionStatus::Executing), + "QUERY_STATE_PENDING" => Ok(ExecutionStatus::Pending), + "QUERY_STATE_CANCELLED" => Ok(ExecutionStatus::Cancelled), + "QUERY_STATE_FAILED" => Ok(ExecutionStatus::Failed), + other => Err(format!("Parse Error {other}")), + } + } +} + +impl ExecutionStatus { + pub fn is_terminal(&self) -> bool { + match self { + ExecutionStatus::Complete => true, + ExecutionStatus::Cancelled => true, + ExecutionStatus::Failed => true, + ExecutionStatus::Executing => false, + ExecutionStatus::Pending => false, + } + } +} + #[derive(Deserialize, Debug)] pub struct CancellationResponse { - pub(crate) success: bool, + pub success: bool, } #[derive(Deserialize, Debug)] @@ -55,7 +93,7 @@ pub struct ExecutionTimes { pub struct GetStatusResponse { pub execution_id: String, pub query_id: u32, - pub state: String, + pub state: ExecutionStatus, #[serde(flatten)] pub times: ExecutionTimes, pub queue_position: Option, @@ -72,7 +110,7 @@ pub struct ExecutionResult { pub struct GetResultResponse { pub execution_id: String, pub query_id: u32, - pub state: String, + pub state: ExecutionStatus, // TODO - this `flatten` isn't what I had hoped for. // I want the `times` field to disappear // and all sub-fields to be brought up to this layer. @@ -86,3 +124,45 @@ impl GetResultResponse { self.result.rows } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn status_from_str() { + assert_eq!( + ExecutionStatus::from_str("invalid"), + Err(String::from("Parse Error invalid")) + ); + assert_eq!( + ExecutionStatus::from_str("QUERY_STATE_COMPLETED"), + Ok(ExecutionStatus::Complete) + ); + assert_eq!( + ExecutionStatus::from_str("QUERY_STATE_EXECUTING"), + Ok(ExecutionStatus::Executing) + ); + assert_eq!( + ExecutionStatus::from_str("QUERY_STATE_PENDING"), + Ok(ExecutionStatus::Pending) + ); + assert_eq!( + ExecutionStatus::from_str("QUERY_STATE_CANCELLED"), + Ok(ExecutionStatus::Cancelled) + ); + assert_eq!( + ExecutionStatus::from_str("QUERY_STATE_FAILED"), + Ok(ExecutionStatus::Failed) + ); + } + #[test] + fn terminal_statuses() { + assert_eq!(ExecutionStatus::Complete.is_terminal(), true); + assert_eq!(ExecutionStatus::Cancelled.is_terminal(), true); + assert_eq!(ExecutionStatus::Failed.is_terminal(), true); + + assert_eq!(ExecutionStatus::Pending.is_terminal(), false); + assert_eq!(ExecutionStatus::Executing.is_terminal(), false); + } +}