Skip to content

Commit

Permalink
use table stats for descrbe method count result
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Mar 3, 2023
1 parent db4610e commit 81456f5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 33 deletions.
42 changes: 24 additions & 18 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray};
use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use async_trait::async_trait;
Expand Down Expand Up @@ -329,10 +329,10 @@ impl DataFrame {
let supported_describe_functions =
vec!["count", "null_count", "mean", "std", "min", "max", "median"];

let fields_iter = self.schema().fields().iter();
let original_schema_fields = self.schema().fields().iter();

//define describe column
let mut describe_schemas = fields_iter
let mut describe_schemas = original_schema_fields
.clone()
.map(|field| {
if field.data_type().is_numeric() {
Expand All @@ -344,24 +344,30 @@ impl DataFrame {
.collect::<Vec<_>>();
describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));

//count aggregation
let cnt = self.clone().count().await?;

//collect recordBatch
let describe_record_batch = vec![
// count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
vec![RecordBatch::try_new(
Arc::new(Schema::new(
original_schema_fields
.clone()
.map(|f| count(col(f.name())).alias(f.name()))
.map(|field| Field::new(field.name(), DataType::Float64, true))
.collect::<Vec<_>>(),
)?
.collect()
.await?,
)),
(0..original_schema_fields.len())
.map(|_n| {
Arc::new(Float64Array::from_slice(vec![cnt as f64])) as ArrayRef
})
.collect::<Vec<_>>(),
)?],
// null_count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.map(|f| count(is_null(col(f.name()))).alias(f.name()))
.collect::<Vec<_>>(),
Expand All @@ -372,7 +378,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(col(f.name())).alias(f.name()))
Expand All @@ -384,7 +390,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| stddev(col(f.name())).alias(f.name()))
Expand All @@ -396,7 +402,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
Expand All @@ -410,7 +416,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
Expand All @@ -424,7 +430,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| median(col(f.name())).alias(f.name()))
Expand All @@ -435,7 +441,7 @@ impl DataFrame {
];

let mut array_ref_vec: Vec<ArrayRef> = vec![];
for field in fields_iter {
for field in original_schema_fields {
let mut array_datas = vec![];
for record_batch in describe_record_batch.iter() {
let column = record_batch.get(0).unwrap().column_by_name(field.name());
Expand Down
30 changes: 15 additions & 15 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ async fn describe() -> Result<()> {
let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();

let filename = &format!("{testdata}/alltypes_plain.parquet");

let df = ctx
.read_parquet(filename, ParquetReadOptions::default())
.read_parquet(
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
.await?;

let describe_record_batch = df.describe().await.unwrap().collect().await.unwrap();
#[rustfmt::skip]
let expected = vec![
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |",
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"| count | 8.0 | 8 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8 | 8 | 8 |",
"| null_count | 8.0 | 8 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8.0 | 8 | 8 | 8 |",
"| mean | 3.5 | null | 0.5 | 0.5 | 0.5 | 5.0 | 0.550000011920929 | 5.05 | null | null | null |",
"| std | 2.4494897427831783 | null | 0.5345224838248488 | 0.5345224838248488 | 0.5345224838248488 | 5.3452248382484875 | 0.5879747449513427 | 5.398677086630973 | null | null | null |",
"| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | null | null | 2009-01-01T00:00:00 |",
"| max | 7.0 | null | 1.0 | 1.0 | 1.0 | 10.0 | 1.100000023841858 | 10.1 | null | null | 2009-04-01T00:01:00 |",
"| median | 3.0 | null | 0.0 | 0.0 | 0.0 | 5.0 | 0.550000011920929 | 5.05 | null | null | null |",
"+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
"| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | year | month |",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
"| count | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 |",
"| null_count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |",
"| mean | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999964237213 | 45.45000000000001 | null | null | null | 2009.5 | 6.526027397260274 |",
"| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |",
"| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |",
"| max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 |",
"| median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 |",
"+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
];
assert_batches_eq!(expected, &describe_record_batch);

Expand Down

0 comments on commit 81456f5

Please sign in to comment.