Skip to content

Commit

Permalink
add Key::I64 and Key::U64 variants in aggregation (#2468)
Browse files Browse the repository at this point in the history
* add Key::I64 and Key::U64 variants in aggregation

Currently all `Key` numerical values are returned as f64. This causes problems in some
cases with the precision and the way f64 is serialized.

This PR adds `Key::I64` and `Key::U64` variants and uses them in the term
aggregation.

* add clarification comment
  • Loading branch information
PSeitz authored Jul 31, 2024
1 parent 75dc3eb commit 0d4e319
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 17 deletions.
25 changes: 25 additions & 0 deletions columnar/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,31 @@ impl NumericalValue {
NumericalValue::F64(_) => NumericalType::F64,
}
}

/// Normalizes the numerical value to a canonical variant, trying the
/// following types in priority order: `i64`, `u64`, `f64`.
///
/// * `U64` becomes `I64` if the value fits into an `i64`, otherwise it stays `U64`.
/// * `I64` is returned unchanged.
/// * `F64` without a fractional part becomes `I64` or `U64` if the value lies inside the
///   range of that integer type; every other value (fractional, out of range, `NaN`,
///   infinite) stays `F64`.
///
/// The result is stable: normalizing an already normalized value returns it unchanged.
pub fn normalize(self) -> Self {
    match self {
        NumericalValue::U64(val) => {
            if val <= i64::MAX as u64 {
                NumericalValue::I64(val as i64)
            } else {
                // Too large for i64. Keep the exact u64 instead of casting to f64,
                // which would silently drop the low bits.
                NumericalValue::U64(val)
            }
        }
        NumericalValue::I64(val) => NumericalValue::I64(val),
        NumericalValue::F64(val) => {
            // `i64::MAX as f64` and `u64::MAX as f64` round up to 2^63 and 2^64, which are
            // one past the largest integer of each type. The upper bounds must therefore be
            // exclusive, otherwise the saturating `as` cast would change the value.
            let is_integral = val.fract() == 0.0;
            if is_integral && val >= i64::MIN as f64 && val < i64::MAX as f64 {
                NumericalValue::I64(val as i64)
            } else if is_integral && val >= u64::MIN as f64 && val < u64::MAX as f64 {
                NumericalValue::U64(val as u64)
            } else {
                NumericalValue::F64(val)
            }
        }
    }
}
}

impl From<u64> for NumericalValue {
Expand Down
37 changes: 29 additions & 8 deletions src/aggregation/agg_req_with_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ impl AggregationWithAccessor {
.map(|missing| match missing {
Key::Str(_) => ColumnType::Str,
Key::F64(_) => ColumnType::F64,
Key::I64(_) => ColumnType::I64,
Key::U64(_) => ColumnType::U64,
})
.unwrap_or(ColumnType::U64);
let column_and_types = get_all_ff_reader_or_empty(
Expand Down Expand Up @@ -232,13 +234,16 @@ impl AggregationWithAccessor {
missing.clone()
};

let missing_value_for_accessor = if let Some(missing) =
missing_value_term_agg.as_ref()
{
get_missing_val(column_type, missing, agg.agg.get_fast_field_names()[0])?
} else {
None
};
let missing_value_for_accessor =
if let Some(missing) = missing_value_term_agg.as_ref() {
get_missing_val_as_u64_lenient(
column_type,
missing,
agg.agg.get_fast_field_names()[0],
)?
} else {
None
};

let agg = AggregationWithAccessor {
segment_ordinal,
Expand Down Expand Up @@ -330,7 +335,14 @@ impl AggregationWithAccessor {
}
}

fn get_missing_val(
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
missing: &Key,
field_name: &str,
Expand All @@ -339,9 +351,18 @@ fn get_missing_val(
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may lose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
Expand Down
61 changes: 59 additions & 2 deletions src/aggregation/agg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,15 +939,72 @@ fn test_aggregation_on_json_object_mixed_types() {
},
"termagg": {
"buckets": [
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
{ "doc_count": 1, "key": 10, "min_price": { "value": 10.0 } },
{ "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } },
{ "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } },
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
{ "doc_count": 2, "key": 1, "key_as_string": "true", "min_price": { "value": null } },
],
"sum_other_doc_count": 0
}
}
)
);
}

/// The same JSON path holds numbers in two segments: the first segment mixes `10.5` and
/// `10`, the second one only contains `10`. The terms aggregation is expected to merge
/// both occurrences of `10` into a single bucket with an integer key.
#[test]
fn test_aggregation_on_json_object_mixed_numerical_segments() {
    let mut schema_builder = Schema::builder();
    let json = schema_builder.add_json_field("json", FAST);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema);
    let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();

    // => Segment with all values f64 numeric
    index_writer
        .add_document(doc!(json => json!({"mixed_price": 10.5})))
        .unwrap();
    // Gets converted to f64!
    index_writer
        .add_document(doc!(json => json!({"mixed_price": 10})))
        .unwrap();
    index_writer.commit().unwrap();

    // => Segment with all values i64 numeric
    index_writer
        .add_document(doc!(json => json!({"mixed_price": 10})))
        .unwrap();
    index_writer.commit().unwrap();

    // Plain terms aggregation on the mixed numerical field
    let agg_req_str = r#"
    {
        "termagg": {
            "terms": {
                "field": "json.mixed_price"
            }
        }
    } "#;
    let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap();
    let aggregation_collector = get_collector(agg);
    let reader = index.reader().unwrap();
    let searcher = reader.searcher();

    let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
    let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
    use pretty_assertions::assert_eq;
    assert_eq!(
        &aggregation_res_json,
        &serde_json::json!({
            "termagg": {
                "buckets": [
                    { "doc_count": 2, "key": 10},
                    { "doc_count": 1, "key": 10.5},
                ],
                "doc_count_error_upper_bound": 0,
                "sum_other_doc_count": 0
            }
        })
    );
}
82 changes: 78 additions & 4 deletions src/aggregation/bucket/term_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::io;
use std::net::Ipv6Addr;

use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64};
use columnar::{
ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalValue,
};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

Expand All @@ -19,7 +21,7 @@ use crate::aggregation::intermediate_agg_result::{
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, format_date, Key};
use crate::aggregation::{format_date, Key};
use crate::error::DataCorruption;
use crate::TantivyError;

Expand Down Expand Up @@ -497,6 +499,12 @@ impl SegmentTermCollector {
Key::F64(val) => {
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
}
Key::U64(val) => {
dict.insert(IntermediateKey::U64(*val), intermediate_entry);
}
Key::I64(val) => {
dict.insert(IntermediateKey::I64(*val), intermediate_entry);
}
}

entries.swap_remove(index);
Expand Down Expand Up @@ -583,8 +591,26 @@ impl SegmentTermCollector {
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = f64_from_fastfield_u64(val, &self.column_type);
dict.insert(IntermediateKey::F64(val), intermediate_entry);
if self.column_type == ColumnType::U64 {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
} else if self.column_type == ColumnType::I64 {
dict.insert(IntermediateKey::I64(i64::from_u64(val)), intermediate_entry);
} else {
let val = f64::from_u64(val);
let val: NumericalValue = val.into();

match val.normalize() {
NumericalValue::U64(val) => {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
}
NumericalValue::I64(val) => {
dict.insert(IntermediateKey::I64(val), intermediate_entry);
}
NumericalValue::F64(val) => {
dict.insert(IntermediateKey::F64(val), intermediate_entry);
}
}
};
}
};

Expand Down Expand Up @@ -1719,6 +1745,54 @@ mod tests {
Ok(())
}

/// Large `u64` values must come back from a terms aggregation exactly as they were
/// indexed, without being truncated.
#[test]
fn terms_aggregation_u64_value() -> crate::Result<()> {
    const SINGLE_ID: u64 = 9_223_372_036_854_775_807;
    const REPEATED_ID: u64 = 1_769_070_189_829_214_202;

    let mut schema_builder = Schema::builder();
    let id_field = schema_builder.add_u64_field("id", FAST);
    let index = Index::create_in_ram(schema_builder.build());
    {
        let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?;
        index_writer.set_merge_policy(Box::new(NoMergePolicy));
        // One document with `SINGLE_ID`, followed by two with `REPEATED_ID`.
        for id in [SINGLE_ID, REPEATED_ID, REPEATED_ID] {
            index_writer.add_document(doc!(id_field => id))?;
        }
        index_writer.commit()?;
    }

    let request = json!({
        "my_ids": {
            "terms": {
                "field": "id"
            },
        }
    });
    let agg_req: Aggregations = serde_json::from_value(request).unwrap();

    let res = exec_request_with_query(agg_req, &index, None)?;
    let buckets = &res["my_ids"]["buckets"];

    // The id occurring twice is listed first.
    assert_eq!(buckets[0]["key"], REPEATED_ID);
    assert_eq!(buckets[0]["doc_count"], 2);
    assert_eq!(buckets[1]["key"], SINGLE_ID);
    assert_eq!(buckets[1]["doc_count"], 1);
    // There is no third bucket.
    assert_eq!(buckets[2]["key"], serde_json::Value::Null);

    Ok(())
}

#[test]
fn terms_aggregation_missing1() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
Expand Down
1 change: 0 additions & 1 deletion src/aggregation/bucket/term_missing_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ impl SegmentAggregationCollector for TermMissingAgg {
)?;
missing_entry.sub_aggregation = res;
}

entries.insert(missing.into(), missing_entry);

let bucket = IntermediateBucketResult::Terms {
Expand Down
12 changes: 11 additions & 1 deletion src/aggregation/intermediate_agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,18 @@ pub enum IntermediateKey {
Str(String),
/// `f64` key
F64(f64),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
}
impl From<Key> for IntermediateKey {
fn from(value: Key) -> Self {
match value {
Key::Str(s) => Self::Str(s),
Key::F64(f) => Self::F64(f),
Key::U64(f) => Self::U64(f),
Key::I64(f) => Self::I64(f),
}
}
}
Expand All @@ -73,7 +79,9 @@ impl From<IntermediateKey> for Key {
}
}
IntermediateKey::F64(f) => Self::F64(f),
IntermediateKey::Bool(f) => Self::F64(f as u64 as f64),
IntermediateKey::Bool(f) => Self::U64(f as u64),
IntermediateKey::U64(f) => Self::U64(f),
IntermediateKey::I64(f) => Self::I64(f),
}
}
}
Expand All @@ -86,6 +94,8 @@ impl std::hash::Hash for IntermediateKey {
match self {
IntermediateKey::Str(text) => text.hash(state),
IntermediateKey::F64(val) => val.to_bits().hash(state),
IntermediateKey::U64(val) => val.hash(state),
IntermediateKey::I64(val) => val.hash(state),
IntermediateKey::Bool(val) => val.hash(state),
IntermediateKey::IpAddr(val) => val.hash(state),
}
Expand Down
9 changes: 8 additions & 1 deletion src/aggregation/metric/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,11 @@ impl SegmentCardinalityCollector {
Ok(())
})?;
if has_missing {
// Replace missing with the actual value provided
let missing_key = self
.missing
.as_ref()
.expect("Found placeholder term_ord but `missing` is None");
.expect("Found sentinel value u64::MAX for term_ord but `missing` is not set");
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
Expand All @@ -191,6 +192,12 @@ impl SegmentCardinalityCollector {
let val = f64_to_u64(*val);
self.cardinality.sketch.insert_any(&val);
}
Key::U64(val) => {
self.cardinality.sketch.insert_any(&val);
}
Key::I64(val) => {
self.cardinality.sketch.insert_any(&val);
}
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,16 @@ pub type SerializedKey = String;

#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)]
/// The key to identify a bucket.
///
/// The variant order is important: with `#[serde(untagged)]` serde tries the variants in
/// declaration order when deserializing, so a number is read as `I64` first, then as
/// `U64`, and only falls back to `F64` if neither integer variant matches.
#[serde(untagged)]
pub enum Key {
/// String key
Str(String),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
/// `f64` key
F64(f64),
}
Expand All @@ -350,6 +356,8 @@ impl std::hash::Hash for Key {
match self {
Key::Str(text) => text.hash(state),
Key::F64(val) => val.to_bits().hash(state),
Key::U64(val) => val.hash(state),
Key::I64(val) => val.hash(state),
}
}
}
Expand All @@ -369,6 +377,8 @@ impl Display for Key {
match self {
Key::Str(val) => f.write_str(val),
Key::F64(val) => f.write_str(&val.to_string()),
Key::U64(val) => f.write_str(&val.to_string()),
Key::I64(val) => f.write_str(&val.to_string()),
}
}
}
Expand Down

0 comments on commit 0d4e319

Please sign in to comment.