Skip to content

Commit

Permalink
Replace regex with nom for duration parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasPickering committed Apr 30, 2024
1 parent 9744bbe commit e404701
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 50 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ notify = {version = "^6.1.1", default-features = false, features = ["macos_fseve
open = "5.1.1"
pretty_assertions = "1.4.0"
ratatui = {version = "^0.26.0", features = ["unstable-rendered-line-info"]}
regex = {version = "1.10.3", default-features = false, features = ["perf"]}
reqwest = {version = "^0.11.20", default-features = false, features = ["rustls-tls"]}
rmp-serde = "^1.1.2"
rusqlite = {version = "^0.30.0", default-features = false, features = ["bundled", "chrono", "uuid"]}
Expand Down
4 changes: 2 additions & 2 deletions src/cli/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl BuildRequestCommand {
collection.profiles.get(profile_id).ok_or_else(|| {
anyhow!(
"No profile with ID `{profile_id}`; options are: {}",
collection.profiles.keys().join(", ")
collection.profiles.keys().format(", ")
)
})?;
}
Expand All @@ -185,7 +185,7 @@ impl BuildRequestCommand {
anyhow!(
"No recipe with ID `{}`; options are: {}",
self.recipe_id,
collection.recipes.recipe_ids().join(", ")
collection.recipes.recipe_ids().format(", ")
)
})?
.clone();
Expand Down
109 changes: 64 additions & 45 deletions src/collection/cereal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,34 @@ impl<'de> Deserialize<'de> for Template {
/// - d
/// Examples: `30s`, `5m`, `12h`, `3d`
pub mod serde_duration {
use regex::Regex;
use derive_more::Display;
use itertools::Itertools;
use nom::{
bytes::complete::take_while,
character::complete::digit1,
combinator::{all_consuming, map_res},
sequence::tuple,
IResult,
};
use serde::{de::Error, Deserialize, Deserializer, Serializer};
use std::{sync::OnceLock, time::Duration};
use std::time::Duration;
use strum::{EnumIter, EnumString, IntoEnumIterator};

const UNIT_SECOND: &str = "s";
const UNIT_MINUTE: &str = "m";
const UNIT_HOUR: &str = "h";
const UNIT_DAY: &str = "d";
#[derive(Debug, Display, EnumIter, EnumString)]
enum Unit {
#[display("s")]
#[strum(serialize = "s")]
Second,
#[display("m")]
#[strum(serialize = "m")]
Minute,
#[display("h")]
#[strum(serialize = "h")]
Hour,
#[display("d")]
#[strum(serialize = "d")]
Day,
}

pub fn serialize<S>(
duration: &Duration,
Expand All @@ -136,48 +156,43 @@ pub mod serde_duration {
{
// Always serialize as seconds, because it's easiest. Sub-second
// precision is lost
S::serialize_str(
serializer,
&format!("{}{}", duration.as_secs(), UNIT_SECOND),
)
S::serialize_str(serializer, &format!("{}s", duration.as_secs()))
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
// TODO remove regex
// unstable: use LazyLock https://github.com/rust-lang/rust/pull/121377
static REGEX: OnceLock<Regex> = OnceLock::new();
let s = String::deserialize(deserializer)?;
let regex = REGEX.get_or_init(|| Regex::new("^(\\d+)(\\w+)$").unwrap());
if let Some(captures) = regex.captures(&s) {
let quantity: u64 = captures
.get(1)
.expect("No first group")
.as_str()
.parse()
// Error should be impossible because the regex only allows ints
.map_err(|_| D::Error::custom("Invalid int"))?;
let unit = captures.get(2).expect("No second group").as_str();
let seconds = match unit {
UNIT_SECOND => quantity,
UNIT_MINUTE => quantity * 60,
UNIT_HOUR => quantity * 60 * 60,
UNIT_DAY => quantity * 60 * 60 * 24,
_ => {
return Err(D::Error::custom(format!(
"Unknown duration unit: {unit:?}; must be one of {:?}",
[UNIT_SECOND, UNIT_MINUTE, UNIT_HOUR, UNIT_DAY]
)))
}
};
Ok(Duration::from_secs(seconds))
} else {
Err(D::Error::custom(
"Invalid duration, must be \"<quantity><unit>\" (e.g. \"12d\")",
))
fn quantity(input: &str) -> IResult<&str, u64> {
map_res(digit1, str::parse)(input)
}

fn unit(input: &str) -> IResult<&str, &str> {
take_while(char::is_alphabetic)(input)
}

let input = String::deserialize(deserializer)?;
let (_, (quantity, unit)) =
all_consuming(tuple((quantity, unit)))(&input).map_err(|_| {
D::Error::custom(
"Invalid duration, must be `<quantity><unit>` (e.g. `12d`)",
)
})?;

let unit = unit.parse().map_err(|_| {
D::Error::custom(format!(
"Unknown duration unit `{unit}`; must be one of {}",
Unit::iter()
.format_with(", ", |unit, f| f(&format_args!("`{unit}`")))
))
})?;
let seconds = match unit {
Unit::Second => quantity,
Unit::Minute => quantity * 60,
Unit::Hour => quantity * 60 * 60,
Unit::Day => quantity * 60 * 60 * 24,
};
Ok(Duration::from_secs(seconds))
}

#[cfg(test)]
Expand Down Expand Up @@ -225,19 +240,23 @@ pub mod serde_duration {
#[rstest]
#[case::negative(
"-1s",
r#"Invalid duration, must be "<quantity><unit>" (e.g. "12d")"#
"Invalid duration, must be `<quantity><unit>` (e.g. `12d`)"
)]
#[case::whitespace(
" 1s ",
r#"Invalid duration, must be "<quantity><unit>" (e.g. "12d")"#
"Invalid duration, must be `<quantity><unit>` (e.g. `12d`)"
)]
#[case::trailing_whitespace(
"1s ",
"Invalid duration, must be `<quantity><unit>` (e.g. `12d`)"
)]
#[case::decimal(
"3.5s",
r#"Invalid duration, must be "<quantity><unit>" (e.g. "12d")"#
"Invalid duration, must be `<quantity><unit>` (e.g. `12d`)"
)]
#[case::invalid_unit(
"3hr",
r#"Unknown duration unit: "hr"; must be one of ["s", "m", "h", "d"]"#
"Unknown duration unit `hr`; must be one of `s`, `m`, `h`, `d`"
)]
fn test_deserialize_error(
#[case] s: &'static str,
Expand Down
2 changes: 1 addition & 1 deletion src/collection/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ impl TryFrom<String> for Method {
value.parse().map_err(|_| {
anyhow!(
"Invalid HTTP method `{value}`. Must be one of: {}",
Method::iter().map(|method| method.to_string()).join(", ")
Method::iter().map(|method| method.to_string()).format(", ")
)
})
}
Expand Down

0 comments on commit e404701

Please sign in to comment.