Skip to content

Commit

Permalink
support bool type in term aggregation (#2318)
Browse files Browse the repository at this point in the history
* support bool type in term aggregation

* add Bool to Intermediate Key
  • Loading branch information
PSeitz authored Feb 20, 2024
1 parent f745dbc commit d57622d
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/aggregation/agg_req_with_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ impl AggregationWithAccessor {
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
// ColumnType::Bytes Unsupported
// ColumnType::Bool Unsupported
// ColumnType::IpAddr Unsupported
];

Expand Down
29 changes: 23 additions & 6 deletions src/aggregation/agg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ fn test_aggregation_on_json_object() {
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red"})))
.unwrap();
Expand Down Expand Up @@ -614,8 +617,8 @@ fn test_aggregation_on_json_object() {
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 2, "key": "red"},
{"doc_count": 1, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
Expand All @@ -637,6 +640,9 @@ fn test_aggregation_on_nested_json_object() {
index_writer
.add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} })))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} })))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
Expand Down Expand Up @@ -664,15 +670,15 @@ fn test_aggregation_on_nested_json_object() {
&serde_json::json!({
"jsonagg1": {
"buckets": [
{"doc_count": 1, "key": "blue"},
{"doc_count": 2, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
},
"jsonagg2": {
"buckets": [
{"doc_count": 1, "key": "blue"},
{"doc_count": 2, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
Expand Down Expand Up @@ -814,6 +820,12 @@ fn test_aggregation_on_json_object_mixed_types() {
.unwrap();
index_writer.commit().unwrap();
// => Segment with all values text
index_writer
.add_document(doc!(json => json!({"mixed_type": "blue"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"mixed_type": "blue"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"mixed_type": "blue"})))
.unwrap();
Expand All @@ -825,6 +837,9 @@ fn test_aggregation_on_json_object_mixed_types() {
index_writer.commit().unwrap();

// => Segment with mixed values
index_writer
.add_document(doc!(json => json!({"mixed_type": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"mixed_type": "red"})))
.unwrap();
Expand Down Expand Up @@ -870,6 +885,8 @@ fn test_aggregation_on_json_object_mixed_types() {

let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
// pretty print as json
use pretty_assertions::assert_eq;
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
Expand All @@ -885,9 +902,9 @@ fn test_aggregation_on_json_object_mixed_types() {
"buckets": [
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
// TODO bool is also not yet handled in aggregation
{ "doc_count": 1, "key": "blue", "min_price": { "value": null } },
{ "doc_count": 1, "key": "red", "min_price": { "value": null } },
{ "doc_count": 2, "key": "red", "min_price": { "value": null } },
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
{ "doc_count": 3, "key": "blue", "min_price": { "value": null } },
],
"sum_other_doc_count": 0
}
Expand Down
18 changes: 10 additions & 8 deletions src/aggregation/bucket/histogram/date_histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,10 @@ pub mod tests {
let docs = vec![
vec![r#"{ "date": "2015-01-01T12:10:30Z", "text": "aaa" }"#],
vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#],
vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#],
vec![r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb" }"#],
vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#],
vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#],
];
let index = get_test_index_from_docs(merge_segments, &docs).unwrap();

Expand Down Expand Up @@ -382,7 +384,7 @@ pub mod tests {
{
"key_as_string" : "2015-01-01T00:00:00Z",
"key" : 1420070400000.0,
"doc_count" : 4
"doc_count" : 6
}
]
}
Expand Down Expand Up @@ -420,15 +422,15 @@ pub mod tests {
{
"key_as_string" : "2015-01-01T00:00:00Z",
"key" : 1420070400000.0,
"doc_count" : 4,
"doc_count" : 6,
"texts": {
"buckets": [
{
"doc_count": 2,
"doc_count": 3,
"key": "bbb"
},
{
"doc_count": 1,
"doc_count": 2,
"key": "ccc"
},
{
Expand Down Expand Up @@ -467,7 +469,7 @@ pub mod tests {
"sales_over_time": {
"buckets": [
{
"doc_count": 2,
"doc_count": 3,
"key": 1420070400000.0,
"key_as_string": "2015-01-01T00:00:00Z"
},
Expand All @@ -492,7 +494,7 @@ pub mod tests {
"key_as_string": "2015-01-05T00:00:00Z"
},
{
"doc_count": 1,
"doc_count": 2,
"key": 1420502400000.0,
"key_as_string": "2015-01-06T00:00:00Z"
}
Expand Down Expand Up @@ -533,7 +535,7 @@ pub mod tests {
"key_as_string": "2014-12-31T00:00:00Z"
},
{
"doc_count": 2,
"doc_count": 3,
"key": 1420070400000.0,
"key_as_string": "2015-01-01T00:00:00Z"
},
Expand All @@ -558,7 +560,7 @@ pub mod tests {
"key_as_string": "2015-01-05T00:00:00Z"
},
{
"doc_count": 1,
"doc_count": 2,
"key": 1420502400000.0,
"key_as_string": "2015-01-06T00:00:00Z"
},
Expand Down
64 changes: 53 additions & 11 deletions src/aggregation/bucket/term_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ pub struct SegmentTermCollector {
term_buckets: TermBuckets,
req: TermsAggregationInternal,
blueprint: Option<Box<dyn SegmentAggregationCollector>>,
field_type: ColumnType,
column_type: ColumnType,
accessor_idx: usize,
}

Expand Down Expand Up @@ -355,7 +355,7 @@ impl SegmentTermCollector {
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
if field_type == ColumnType::Bytes || field_type == ColumnType::Bool {
if field_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {:?}",
field_type
Expand Down Expand Up @@ -389,7 +389,7 @@ impl SegmentTermCollector {
req: TermsAggregationInternal::from_req(req),
term_buckets,
blueprint,
field_type,
column_type: field_type,
accessor_idx,
})
}
Expand Down Expand Up @@ -466,7 +466,7 @@ impl SegmentTermCollector {
Ok(intermediate_entry)
};

if self.field_type == ColumnType::Str {
if self.column_type == ColumnType::Str {
let term_dict = agg_with_accessor
.str_dict_column
.as_ref()
Expand Down Expand Up @@ -531,28 +531,34 @@ impl SegmentTermCollector {
});
}
}
} else if self.field_type == ColumnType::DateTime {
} else if self.column_type == ColumnType::DateTime {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = i64::from_u64(val);
let date = format_date(val)?;
dict.insert(IntermediateKey::Str(date), intermediate_entry);
}
} else if self.column_type == ColumnType::Bool {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = bool::from_u64(val);
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
}
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = f64_from_fastfield_u64(val, &self.field_type);
let val = f64_from_fastfield_u64(val, &self.column_type);
dict.insert(IntermediateKey::F64(val), intermediate_entry);
}
};

Ok(IntermediateBucketResult::Terms(
IntermediateTermBucketResult {
Ok(IntermediateBucketResult::Terms {
buckets: IntermediateTermBucketResult {
entries: dict,
sum_other_doc_count,
doc_count_error_upper_bound: term_doc_count_before_cutoff,
},
))
})
}
}

Expand Down Expand Up @@ -1365,7 +1371,7 @@ mod tests {

#[test]
fn terms_aggregation_different_tokenizer_on_ff_test() -> crate::Result<()> {
let terms = vec!["Hello Hello", "Hallo Hallo"];
let terms = vec!["Hello Hello", "Hallo Hallo", "Hallo Hallo"];

let index = get_test_index_from_terms(true, &[terms])?;

Expand All @@ -1383,7 +1389,7 @@ mod tests {
println!("{}", serde_json::to_string_pretty(&res).unwrap());

assert_eq!(res["my_texts"]["buckets"][0]["key"], "Hallo Hallo");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1);
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);

assert_eq!(res["my_texts"]["buckets"][1]["key"], "Hello Hello");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1);
Expand Down Expand Up @@ -1894,4 +1900,40 @@ mod tests {

Ok(())
}

#[test]
fn terms_aggregation_bool() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_bool_field("bool_field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut writer = index.writer_with_num_threads(1, 15_000_000)?;
writer.add_document(doc!(field=>true))?;
writer.add_document(doc!(field=>false))?;
writer.add_document(doc!(field=>true))?;
writer.commit()?;
}

let agg_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "bool_field"
},
}
}))
.unwrap();

let res = exec_request_with_query(agg_req, &index, None)?;

assert_eq!(res["my_bool"]["buckets"][0]["key"], 1.0);
assert_eq!(res["my_bool"]["buckets"][0]["key_as_string"], "true");
assert_eq!(res["my_bool"]["buckets"][0]["doc_count"], 2);
assert_eq!(res["my_bool"]["buckets"][1]["key"], 0.0);
assert_eq!(res["my_bool"]["buckets"][1]["key_as_string"], "false");
assert_eq!(res["my_bool"]["buckets"][1]["doc_count"], 1);
assert_eq!(res["my_bool"]["buckets"][2]["key"], serde_json::Value::Null);

Ok(())
}
}
12 changes: 7 additions & 5 deletions src/aggregation/bucket/term_missing_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ impl SegmentAggregationCollector for TermMissingAgg {

entries.insert(missing.into(), missing_entry);

let bucket = IntermediateBucketResult::Terms(IntermediateTermBucketResult {
entries,
sum_other_doc_count: 0,
doc_count_error_upper_bound: 0,
});
let bucket = IntermediateBucketResult::Terms {
buckets: IntermediateTermBucketResult {
entries,
sum_other_doc_count: 0,
doc_count_error_upper_bound: 0,
},
};

results.push(name, IntermediateAggregationResult::Bucket(bucket))?;

Expand Down
Loading

0 comments on commit d57622d

Please sign in to comment.