Skip to content

Commit

Permalink
Apply a patch while #5444 is not fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Mar 4, 2023
1 parent db4610e commit bb290e9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 35 deletions.
58 changes: 38 additions & 20 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray};
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use async_trait::async_trait;
Expand Down Expand Up @@ -329,10 +329,10 @@ impl DataFrame {
let supported_describe_functions =
vec!["count", "null_count", "mean", "std", "min", "max", "median"];

let fields_iter = self.schema().fields().iter();
let original_schema_fields = self.schema().fields().iter();

// define the columns of the describe output
let mut describe_schemas = fields_iter
let mut describe_schemas = original_schema_fields
.clone()
.map(|field| {
if field.data_type().is_numeric() {
Expand All @@ -344,24 +344,42 @@ impl DataFrame {
.collect::<Vec<_>>();
describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));

//count aggregation
let cnt = self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.map(|f| count(col(f.name())))
.collect::<Vec<_>>(),
)?;
// The AggregateStatistics optimization rewrites the physical plan for the
// count function and ignores the aliases applied to it, as shown in
// https://github.com/apache/arrow-datafusion/issues/5444.
// As a workaround, the count columns are re-aliased to the original column
// names below. This logic should be removed when #5444 is fixed.
let cnt = cnt
.clone()
.select(
cnt.schema()
.fields()
.iter()
.zip(original_schema_fields.clone())
.map(|(count_field, orgin_field)| {
col(count_field.name()).alias(orgin_field.name())
})
.collect::<Vec<_>>(),
)
.unwrap();
// end of the #5444 workaround: the re-aliasing above should be removed when #5444 is fixed

// collect the record batches of each aggregation
let describe_record_batch = vec![
// count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
.clone()
.map(|f| count(col(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
)?
.collect()
.await?,
cnt.collect().await.unwrap(),
// null_count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.map(|f| count(is_null(col(f.name()))).alias(f.name()))
.collect::<Vec<_>>(),
Expand All @@ -372,7 +390,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(col(f.name())).alias(f.name()))
Expand All @@ -384,7 +402,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| stddev(col(f.name())).alias(f.name()))
Expand All @@ -396,7 +414,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
Expand All @@ -410,7 +428,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
Expand All @@ -424,7 +442,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| median(col(f.name())).alias(f.name()))
Expand All @@ -435,7 +453,7 @@ impl DataFrame {
];

let mut array_ref_vec: Vec<ArrayRef> = vec![];
for field in fields_iter {
for field in original_schema_fields {
let mut array_datas = vec![];
for record_batch in describe_record_batch.iter() {
let column = record_batch.get(0).unwrap().column_by_name(field.name());
Expand Down
30 changes: 15 additions & 15 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ async fn describe() -> Result<()> {
let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();

let filename = &format!("{testdata}/alltypes_plain.parquet");

let df = ctx
.read_parquet(filename, ParquetReadOptions::default())
.read_parquet(
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
.await?;

let describe_record_batch = df.describe().await.unwrap().collect().await.unwrap();
#[rustfmt::skip]
let expected = vec![
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |",
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"| count | 8.0 | 8 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8 | 8 | 8 |",
"| null_count | 8.0 | 8 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8 | 8 | 8 |",
"| mean | 3.5 | null | 0.5 | 0.5 | 0.5 | 5.0 | 0.550000011920929 | 5.05 | null | null | null |",
"| std | 2.4494897427831783 | null | 0.5345224838248488 | 0.5345224838248488 | 0.5345224838248488 | 5.3452248382484875 | 0.5879747449513427 | 5.398677086630973 | null | null | null |",
"| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | null | null | 2009-01-01T00:00:00 |",
"| max | 7.0 | null | 1.0 | 1.0 | 1.0 | 10.0 | 1.100000023841858 | 10.1 | null | null | 2009-04-01T00:01:00 |",
"| median | 3.0 | null | 0.0 | 0.0 | 0.0 | 5.0 | 0.550000011920929 | 5.05 | null | null | null |",
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
"| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | year | month |",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
"| count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |",
"| null_count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |",
"| mean | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999964237213 | 45.45000000000001 | null | null | null | 2009.5 | 6.526027397260274 |",
"| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |",
"| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |",
"| max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 |",
"| median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 |",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
];
assert_batches_eq!(expected, &describe_record_batch);

Expand Down

0 comments on commit bb290e9

Please sign in to comment.