From d57622d54b0c69740d08bbd02f5b10a60373fea0 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Tue, 20 Feb 2024 03:22:22 +0100 Subject: [PATCH] support bool type in term aggregation (#2318) * support bool type in term aggregation * add Bool to Intermediate Key --- src/aggregation/agg_req_with_accessor.rs | 2 +- src/aggregation/agg_tests.rs | 29 +++++++-- .../bucket/histogram/date_histogram.rs | 18 +++--- src/aggregation/bucket/term_agg.rs | 64 +++++++++++++++---- src/aggregation/bucket/term_missing_agg.rs | 12 ++-- src/aggregation/intermediate_agg_result.rs | 36 ++++++++--- src/aggregation/mod.rs | 2 + 7 files changed, 123 insertions(+), 40 deletions(-) diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 2e4e3b9388..2e7a617efc 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -169,8 +169,8 @@ impl AggregationWithAccessor { ColumnType::F64, ColumnType::Str, ColumnType::DateTime, + ColumnType::Bool, // ColumnType::Bytes Unsupported - // ColumnType::Bool Unsupported // ColumnType::IpAddr Unsupported ]; diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 1fea7fe6f4..d9008becfd 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -587,6 +587,9 @@ fn test_aggregation_on_json_object() { let schema = schema_builder.build(); let index = Index::create_in_ram(schema); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); + index_writer + .add_document(doc!(json => json!({"color": "red"}))) + .unwrap(); index_writer .add_document(doc!(json => json!({"color": "red"}))) .unwrap(); @@ -614,8 +617,8 @@ fn test_aggregation_on_json_object() { &serde_json::json!({ "jsonagg": { "buckets": [ + {"doc_count": 2, "key": "red"}, {"doc_count": 1, "key": "blue"}, - {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, "sum_other_doc_count": 0 @@ -637,6 +640,9 @@ fn test_aggregation_on_nested_json_object() { index_writer .add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} }))) .unwrap(); + index_writer + .add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} }))) + .unwrap(); index_writer.commit().unwrap(); let reader = index.reader().unwrap(); let searcher = reader.searcher(); @@ -664,7 +670,7 @@ fn test_aggregation_on_nested_json_object() { &serde_json::json!({ "jsonagg1": { "buckets": [ - {"doc_count": 1, "key": "blue"}, + {"doc_count": 2, "key": "blue"}, {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, @@ -672,7 +678,7 @@ fn test_aggregation_on_nested_json_object() { }, "jsonagg2": { "buckets": [ - {"doc_count": 1, "key": "blue"}, + {"doc_count": 2, "key": "blue"}, {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, @@ -814,6 +820,12 @@ fn test_aggregation_on_json_object_mixed_types() { .unwrap(); index_writer.commit().unwrap(); // => Segment with all values text + index_writer + .add_document(doc!(json => json!({"mixed_type": "blue"}))) + .unwrap(); + index_writer + .add_document(doc!(json => json!({"mixed_type": "blue"}))) + .unwrap(); index_writer .add_document(doc!(json => json!({"mixed_type": "blue"}))) .unwrap(); @@ -825,6 +837,9 @@ fn test_aggregation_on_json_object_mixed_types() { index_writer.commit().unwrap(); // => Segment with mixed values + index_writer + .add_document(doc!(json => json!({"mixed_type": "red"}))) + .unwrap(); index_writer .add_document(doc!(json => json!({"mixed_type": "red"}))) .unwrap(); @@ -870,6 +885,8 @@ fn test_aggregation_on_json_object_mixed_types() { let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap(); + // pretty print as json + use pretty_assertions::assert_eq; assert_eq!( &aggregation_res_json, &serde_json::json!({ @@ -885,9 +902,9 @@ fn test_aggregation_on_json_object_mixed_types() { "buckets": [ { "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } }, { "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } }, - // TODO bool is also not yet handled in aggregation - { "doc_count": 1, "key": "blue", "min_price": { "value": null } }, - { "doc_count": 1, "key": "red", "min_price": { "value": null } }, + { "doc_count": 2, "key": "red", "min_price": { "value": null } }, + { "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } }, + { "doc_count": 3, "key": "blue", "min_price": { "value": null } }, ], "sum_other_doc_count": 0 } diff --git a/src/aggregation/bucket/histogram/date_histogram.rs b/src/aggregation/bucket/histogram/date_histogram.rs index b4919a4e56..cdbf50e214 100644 --- a/src/aggregation/bucket/histogram/date_histogram.rs +++ b/src/aggregation/bucket/histogram/date_histogram.rs @@ -352,8 +352,10 @@ pub mod tests { let docs = vec![ vec![r#"{ "date": "2015-01-01T12:10:30Z", "text": "aaa" }"#], vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#], + vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#], vec![r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb" }"#], vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#], + vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#], ]; let index = get_test_index_from_docs(merge_segments, &docs).unwrap(); @@ -382,7 +384,7 @@ pub mod tests { { "key_as_string" : "2015-01-01T00:00:00Z", "key" : 1420070400000.0, - "doc_count" : 4 + "doc_count" : 6 } ] } @@ -420,15 +422,15 @@ pub mod tests { { "key_as_string" : "2015-01-01T00:00:00Z", "key" : 1420070400000.0, - "doc_count" : 4, + "doc_count" : 6, "texts": { "buckets": [ { - "doc_count": 2, + "doc_count": 3, "key": "bbb" }, { - "doc_count": 1, + "doc_count": 2, "key": "ccc" }, { @@ -467,7 +469,7 @@ pub mod tests { "sales_over_time": { "buckets": [ { - "doc_count": 2, + "doc_count": 3, "key": 1420070400000.0, "key_as_string": "2015-01-01T00:00:00Z" }, @@ -492,7 +494,7 @@ pub mod tests { "key_as_string": "2015-01-05T00:00:00Z" }, { - "doc_count": 1, + "doc_count": 2, "key": 1420502400000.0, "key_as_string": "2015-01-06T00:00:00Z" } @@ -533,7 +535,7 @@ pub mod tests { "key_as_string": "2014-12-31T00:00:00Z" }, { - "doc_count": 2, + "doc_count": 3, "key": 1420070400000.0, "key_as_string": "2015-01-01T00:00:00Z" }, @@ -558,7 +560,7 @@ pub mod tests { "key_as_string": "2015-01-05T00:00:00Z" }, { - "doc_count": 1, + "doc_count": 2, "key": 1420502400000.0, "key_as_string": "2015-01-06T00:00:00Z" }, diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index b0e40f88db..2c825e6969 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -256,7 +256,7 @@ pub struct SegmentTermCollector { term_buckets: TermBuckets, req: TermsAggregationInternal, blueprint: Option>, - field_type: ColumnType, + column_type: ColumnType, accessor_idx: usize, } @@ -355,7 +355,7 @@ impl SegmentTermCollector { field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { - if field_type == ColumnType::Bytes || field_type == ColumnType::Bool { + if field_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( "terms aggregation is not supported for column type {:?}", field_type @@ -389,7 +389,7 @@ impl SegmentTermCollector { req: TermsAggregationInternal::from_req(req), term_buckets, blueprint, - field_type, + column_type: field_type, accessor_idx, }) } @@ -466,7 +466,7 @@ impl SegmentTermCollector { Ok(intermediate_entry) }; - if self.field_type == ColumnType::Str { + if self.column_type == ColumnType::Str { let term_dict = agg_with_accessor .str_dict_column .as_ref() @@ -531,28 +531,34 @@ impl SegmentTermCollector { }); } } - } else if self.field_type == ColumnType::DateTime { + } else if self.column_type == ColumnType::DateTime { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } + } else if self.column_type == ColumnType::Bool { + for (val, doc_count) in entries { + let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let val = bool::from_u64(val); + dict.insert(IntermediateKey::Bool(val), intermediate_entry); + } } else { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val = f64_from_fastfield_u64(val, &self.field_type); + let val = f64_from_fastfield_u64(val, &self.column_type); dict.insert(IntermediateKey::F64(val), intermediate_entry); } }; - Ok(IntermediateBucketResult::Terms( - IntermediateTermBucketResult { + Ok(IntermediateBucketResult::Terms { + buckets: IntermediateTermBucketResult { entries: dict, sum_other_doc_count, doc_count_error_upper_bound: term_doc_count_before_cutoff, }, - )) + }) } } @@ -1365,7 +1371,7 @@ mod tests { #[test] fn terms_aggregation_different_tokenizer_on_ff_test() -> crate::Result<()> { - let terms = vec!["Hello Hello", "Hallo Hallo"]; + let terms = vec!["Hello Hello", "Hallo Hallo", "Hallo Hallo"]; let index = get_test_index_from_terms(true, &[terms])?; @@ -1383,7 +1389,7 @@ mod tests { println!("{}", serde_json::to_string_pretty(&res).unwrap()); assert_eq!(res["my_texts"]["buckets"][0]["key"], "Hallo Hallo"); - assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2); assert_eq!(res["my_texts"]["buckets"][1]["key"], "Hello Hello"); assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1); @@ -1894,4 +1900,40 @@ mod tests { Ok(()) } + + #[test] + fn terms_aggregation_bool() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let field = schema_builder.add_bool_field("bool_field", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut writer = index.writer_with_num_threads(1, 15_000_000)?; + writer.add_document(doc!(field=>true))?; + writer.add_document(doc!(field=>false))?; + writer.add_document(doc!(field=>true))?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_bool": { + "terms": { + "field": "bool_field" + }, + } + })) + .unwrap(); + + let res = exec_request_with_query(agg_req, &index, None)?; + + assert_eq!(res["my_bool"]["buckets"][0]["key"], 1.0); + assert_eq!(res["my_bool"]["buckets"][0]["key_as_string"], "true"); + assert_eq!(res["my_bool"]["buckets"][0]["doc_count"], 2); + assert_eq!(res["my_bool"]["buckets"][1]["key"], 0.0); + assert_eq!(res["my_bool"]["buckets"][1]["key_as_string"], "false"); + assert_eq!(res["my_bool"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["my_bool"]["buckets"][2]["key"], serde_json::Value::Null); + + Ok(()) + } } diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index a863d5eb2c..bb8b295b49 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -73,11 +73,13 @@ impl SegmentAggregationCollector for TermMissingAgg { entries.insert(missing.into(), missing_entry); - let bucket = IntermediateBucketResult::Terms(IntermediateTermBucketResult { - entries, - sum_other_doc_count: 0, - doc_count_error_upper_bound: 0, - }); + let bucket = IntermediateBucketResult::Terms { + buckets: IntermediateTermBucketResult { + entries, + sum_other_doc_count: 0, + doc_count_error_upper_bound: 0, + }, + }; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 9e07d68aa9..9b2527fdea 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -41,6 +41,8 @@ pub struct IntermediateAggregationResults { /// This might seem redundant with `Key`, but the point is to have a different /// Serialize implementation. pub enum IntermediateKey { + /// Bool key + Bool(bool), /// String key Str(String), /// `f64` key @@ -59,6 +61,7 @@ impl From for Key { match value { IntermediateKey::Str(s) => Self::Str(s), IntermediateKey::F64(f) => Self::F64(f), + IntermediateKey::Bool(f) => Self::F64(f as u64 as f64), } } } @@ -71,6 +74,7 @@ impl std::hash::Hash for IntermediateKey { match self { IntermediateKey::Str(text) => text.hash(state), IntermediateKey::F64(val) => val.to_bits().hash(state), + IntermediateKey::Bool(val) => val.hash(state), } } } @@ -166,9 +170,9 @@ impl IntermediateAggregationResults { pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult { use AggregationVariants::*; match req.agg { - Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms( - Default::default(), - )), + Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms { + buckets: Default::default(), + }), Range(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( Default::default(), )), @@ -363,11 +367,14 @@ pub enum IntermediateBucketResult { Histogram { /// The column_type of the underlying `Column` is DateTime is_date_agg: bool, - /// The buckets + /// The histogram buckets buckets: Vec, }, /// Term aggregation - Terms(IntermediateTermBucketResult), + Terms { + /// The term buckets + buckets: IntermediateTermBucketResult, + }, } impl IntermediateBucketResult { @@ -444,7 +451,7 @@ impl IntermediateBucketResult { }; Ok(BucketResult::Histogram { buckets }) } - IntermediateBucketResult::Terms(terms) => terms.into_final_result( + IntermediateBucketResult::Terms { buckets: terms } => terms.into_final_result( req.agg .as_term() .expect("unexpected aggregation, expected term aggregation"), @@ -457,8 +464,12 @@ impl IntermediateBucketResult { fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> { match (self, other) { ( - IntermediateBucketResult::Terms(term_res_left), - IntermediateBucketResult::Terms(term_res_right), + IntermediateBucketResult::Terms { + buckets: term_res_left, + }, + IntermediateBucketResult::Terms { + buckets: term_res_right, + }, ) => { merge_maps(&mut term_res_left.entries, term_res_right.entries)?; term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count; @@ -542,8 +553,15 @@ impl IntermediateTermBucketResult { .into_iter() .filter(|bucket| bucket.1.doc_count as u64 >= req.min_doc_count) .map(|(key, entry)| { + let key_as_string = match key { + IntermediateKey::Bool(key) => { + let val = if key { "true" } else { "false" }; + Some(val.to_string()) + } + _ => None, + }; Ok(BucketEntry { - key_as_string: None, + key_as_string, key: key.into(), doc_count: entry.doc_count as u64, sub_aggregation: entry diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 9b64825463..fe01a9ec3c 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -281,6 +281,7 @@ pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 { ColumnType::U64 => val as f64, ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64, ColumnType::F64 => f64::from_u64(val), + ColumnType::Bool => val as f64, _ => { panic!("unexpected type {field_type:?}. This should not happen") } @@ -301,6 +302,7 @@ pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &ColumnType) -> Option< ColumnType::U64 => Some(val as u64), ColumnType::I64 | ColumnType::DateTime => Some((val as i64).to_u64()), ColumnType::F64 => Some(val.to_u64()), + ColumnType::Bool => Some(val as u64), _ => None, } }