Skip to content

Commit

Permalink
agg: support to deserialize f64 from string (#2311)
Browse files Browse the repository at this point in the history
* agg: support to deserialize f64 from string

* remove visit_string

* disallow NaN
  • Loading branch information
PSeitz authored Mar 5, 2024
1 parent 40aa4ab commit 7e41d31
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/aggregation/bucket/histogram/date_histogram.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};

use super::{HistogramAggregation, HistogramBounds};
use crate::aggregation::AggregationError;
use crate::aggregation::*;

/// DateHistogramAggregation is similar to `HistogramAggregation`, but it can only be used with date
/// type.
Expand Down
4 changes: 3 additions & 1 deletion src/aggregation/bucket/histogram/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::aggregation::intermediate_agg_result::{
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, format_date};
use crate::aggregation::*;
use crate::TantivyError;

/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
Expand Down Expand Up @@ -73,6 +73,7 @@ pub struct HistogramAggregation {
pub field: String,
/// The interval to chunk your data range. Each bucket spans a value range of [0..interval).
/// Must be a positive value.
#[serde(deserialize_with = "deserialize_f64")]
pub interval: f64,
/// The interval implicitly defines an absolute grid of buckets `[interval * k, interval * (k +
/// 1))`.
Expand All @@ -85,6 +86,7 @@ pub struct HistogramAggregation {
/// fall into the buckets with the key 0 and 10.
/// With offset 5 and interval 10, they would both fall into the bucket with the key 5 and the
/// range [5..15)
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub offset: Option<f64>,
/// The minimum number of documents in a bucket to be returned. Defaults to 0.
pub min_doc_count: Option<u64>,
Expand Down
16 changes: 11 additions & 5 deletions src/aggregation/bucket/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ use crate::aggregation::intermediate_agg_result::{
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::{
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey,
};
use crate::aggregation::*;
use crate::TantivyError;

/// Provide user-defined buckets to aggregate on.
Expand Down Expand Up @@ -72,11 +70,19 @@ pub struct RangeAggregationRange {
pub key: Option<String>,
/// The from range value, which is inclusive in the range.
/// `None` means the interval is open-ended.
#[serde(skip_serializing_if = "Option::is_none", default)]
#[serde(
skip_serializing_if = "Option::is_none",
default,
deserialize_with = "deserialize_option_f64"
)]
pub from: Option<f64>,
/// The to range value, which is not inclusive in the range.
/// `None` means the interval is open-ended.
#[serde(skip_serializing_if = "Option::is_none", default)]
#[serde(
skip_serializing_if = "Option::is_none",
default,
deserialize_with = "deserialize_option_f64"
)]
pub to: Option<f64>,
}

Expand Down
73 changes: 71 additions & 2 deletions src/aggregation/metric/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{IntermediateStats, SegmentStatsCollector};
use super::*;
use crate::aggregation::*;

/// A single-value metric aggregation that computes the average of numeric values that are
/// extracted from the aggregated documents.
Expand All @@ -24,7 +25,7 @@ pub struct AverageAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down Expand Up @@ -65,3 +66,71 @@ impl IntermediateAverage {
self.stats.finalize().avg
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `missing` is accepted both as a numeric string and as a JSON integer, and always ends up
    /// as `Some(10.0)` for the inputs used here.
    #[test]
    fn deserialization_with_missing_test1() {
        // String value containing a decimal point.
        let with_dot = r#"{
"field": "score",
"missing": "10.0"
}"#;
        let parsed = serde_json::from_str::<AverageAggregation>(with_dot).unwrap();
        assert_eq!(parsed.field, "score");
        assert_eq!(parsed.missing, Some(10.0));

        // String value without a decimal point.
        let without_dot = r#"{
"field": "score",
"missing": "10"
}"#;
        let parsed = serde_json::from_str::<AverageAggregation>(without_dot).unwrap();
        assert_eq!(parsed.field, "score");
        assert_eq!(parsed.missing, Some(10.0));

        // Built from a `serde_json::Value` holding a `u64`.
        let request = json!({
            "field": "score_f64",
            "missing": 10u64,
        });
        let parsed = serde_json::from_value::<AverageAggregation>(request).unwrap();
        assert_eq!(parsed.missing, Some(10.0));

        // Built from a `serde_json::Value` holding a `u32`.
        let request = json!({
            "field": "score_f64",
            "missing": 10u32,
        });
        let parsed = serde_json::from_value::<AverageAggregation>(request).unwrap();
        assert_eq!(parsed.missing, Some(10.0));

        // Built from a `serde_json::Value` holding an `i8`.
        let request = json!({
            "field": "score_f64",
            "missing": 10i8,
        });
        let parsed = serde_json::from_value::<AverageAggregation>(request).unwrap();
        assert_eq!(parsed.missing, Some(10.0));
    }

    /// Strings that are not usable numbers must be rejected with a descriptive error.
    #[test]
    fn deserialization_with_missing_test_fail() {
        // A string that is not a number at all.
        let not_a_number = r#"{
"field": "score",
"missing": "a"
}"#;
        let outcome = serde_json::from_str::<AverageAggregation>(not_a_number);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Failed to parse f64 from string: \"a\""));

        // Disallow NaN
        let nan_input = r#"{
"field": "score",
"missing": "NaN"
}"#;
        let outcome = serde_json::from_str::<AverageAggregation>(nan_input);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("NaN"));
    }
}
5 changes: 3 additions & 2 deletions src/aggregation/metric/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{IntermediateStats, SegmentStatsCollector};
use super::*;
use crate::aggregation::*;

/// A single-value metric aggregation that counts the number of values that are
/// extracted from the aggregated documents.
Expand All @@ -24,7 +25,7 @@ pub struct CountAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down
5 changes: 3 additions & 2 deletions src/aggregation/metric/max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{IntermediateStats, SegmentStatsCollector};
use super::*;
use crate::aggregation::*;

/// A single-value metric aggregation that computes the maximum of numeric values that are
/// extracted from the aggregated documents.
Expand All @@ -24,7 +25,7 @@ pub struct MaxAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down
5 changes: 3 additions & 2 deletions src/aggregation/metric/min.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{IntermediateStats, SegmentStatsCollector};
use super::*;
use crate::aggregation::*;

/// A single-value metric aggregation that computes the minimum of numeric values that are
/// extracted from the aggregated documents.
Expand All @@ -24,7 +25,7 @@ pub struct MinAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down
1 change: 1 addition & 0 deletions src/aggregation/metric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod percentiles;
mod stats;
mod sum;
mod top_hits;

pub use average::*;
pub use count::*;
pub use max::*;
Expand Down
8 changes: 6 additions & 2 deletions src/aggregation/metric/percentiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, AggregationError};
use crate::aggregation::*;
use crate::{DocId, TantivyError};

/// # Percentiles
Expand Down Expand Up @@ -84,7 +84,11 @@ pub struct PercentilesAggregationReq {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(skip_serializing_if = "Option::is_none", default)]
#[serde(
skip_serializing_if = "Option::is_none",
default,
deserialize_with = "deserialize_option_f64"
)]
pub missing: Option<f64>,
}
fn default_percentiles() -> &'static [f64] {
Expand Down
28 changes: 26 additions & 2 deletions src/aggregation/metric/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64};
use crate::aggregation::*;
use crate::{DocId, TantivyError};

/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
Expand All @@ -33,7 +33,7 @@ pub struct StatsAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down Expand Up @@ -580,6 +580,30 @@ mod tests {
})
);

// From string
let agg_req: Aggregations = serde_json::from_value(json!({
"my_stats": {
"stats": {
"field": "json.partially_empty",
"missing": "0.0"
},
}
}))
.unwrap();

let res = exec_request_with_query(agg_req, &index, None)?;

assert_eq!(
res["my_stats"],
json!({
"avg": 2.5,
"count": 4,
"max": 10.0,
"min": 0.0,
"sum": 10.0
})
);

Ok(())
}

Expand Down
5 changes: 3 additions & 2 deletions src/aggregation/metric/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{IntermediateStats, SegmentStatsCollector};
use super::*;
use crate::aggregation::*;

/// A single-value metric aggregation that sums up numeric values that are
/// extracted from the aggregated documents.
Expand All @@ -24,7 +25,7 @@ pub struct SumAggregation {
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
}

Expand Down
Loading

0 comments on commit 7e41d31

Please sign in to comment.