From b58ec81ab06af7a267ee69b834715727dbef963d Mon Sep 17 00:00:00 2001 From: Burak Date: Wed, 9 Nov 2022 00:58:04 +0300 Subject: [PATCH] improve error messages while downcasting Int32Array (#4146) --- benchmarks/src/tpch.rs | 6 ++-- datafusion/common/src/cast.rs | 12 +++++++- .../src/avro_to_arrow/arrow_array_reader.rs | 9 ++---- .../core/src/datasource/file_format/avro.rs | 10 ++----- .../src/datasource/file_format/parquet.rs | 9 ++---- datafusion/core/tests/sql/udf.rs | 29 ++++--------------- .../simplify_expressions/expr_simplifier.rs | 12 ++------ .../physical-expr/src/expressions/case.rs | 21 ++++---------- .../physical-expr/src/expressions/literal.rs | 3 +- datafusion/physical-expr/src/physical_expr.rs | 8 ++--- .../physical-expr/src/window/lead_lag.rs | 3 +- .../physical-expr/src/window/nth_value.rs | 3 +- test-utils/Cargo.toml | 1 + test-utils/src/lib.rs | 10 ++----- 14 files changed, 51 insertions(+), 85 deletions(-) diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs index bd3b3080fc48..74acf4cc2860 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::{ - Array, ArrayRef, Decimal128Array, Float64Array, Int32Array, Int64Array, StringArray, + Array, ArrayRef, Decimal128Array, Float64Array, Int64Array, StringArray, }; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -26,7 +26,7 @@ use std::path::Path; use std::sync::Arc; use std::time::Instant; -use datafusion::common::cast::as_date32_array; +use datafusion::common::cast::{as_date32_array, as_int32_array}; use datafusion::common::ScalarValue; use datafusion::logical_expr::Cast; use datafusion::prelude::*; @@ -424,7 +424,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue { } match column.data_type() { DataType::Int32 => { - let array = column.as_any().downcast_ref::().unwrap(); + let array = as_int32_array(column).unwrap(); ScalarValue::Int32(Some(array.value(row_index))) } DataType::Int64 => { diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 2ce0ec22439b..20d11c3f737f 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -21,7 +21,7 @@ //! kernels in arrow-rs such as `as_boolean_array` do. use crate::DataFusionError; -use arrow::array::{Array, Date32Array, StructArray}; +use arrow::array::{Array, Date32Array, Int32Array, StructArray}; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> { @@ -42,3 +42,13 @@ pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionErro )) }) } + +// Downcast ArrayRef to Int32Array +pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array, DataFusionError> { + array.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected a Int32Array, got: {}", + array.data_type() + )) + }) +} diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs index 6dee6f3d25a2..c411d4e5c7c2 100644 --- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs @@ -975,8 +975,9 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::avro_to_arrow::{Reader, ReaderBuilder}; - use arrow::array::{Int32Array, Int64Array, ListArray, TimestampMicrosecondArray}; + use arrow::array::{Int64Array, ListArray, TimestampMicrosecondArray}; use arrow::datatypes::DataType; + use datafusion_common::cast::as_int32_array; use std::fs::File; fn build_reader(name: &str, batch_size: usize) -> Reader { @@ -1080,11 +1081,7 @@ mod test { num_batches += 1; let batch_schema = batch.schema(); assert_eq!(schema, batch_schema); - let a_array = batch - .column(col_id_index) - .as_any() - .downcast_ref::() - .unwrap(); + let a_array = as_int32_array(batch.column(col_id_index)).unwrap(); sum_id += (0..a_array.len()).map(|i| a_array.value(i)).sum::(); } assert_eq!(8, sum_num_rows); diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 4a87eee9359f..e93848a76f06 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -93,9 +93,9 @@ mod tests { use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ - BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, - TimestampMicrosecondArray, + BinaryArray, BooleanArray, Float32Array, Float64Array, TimestampMicrosecondArray, }; + use datafusion_common::cast::as_int32_array; use futures::StreamExt; #[tokio::test] @@ -229,11 +229,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_int32_array(batches[0].column(0))?; let mut values: Vec = vec![]; for i in 0..batches[0].num_rows() { values.push(array.value(i)); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 3de58d456a89..d0ccf08aa1d0 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -587,11 +587,12 @@ mod tests { use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, - Int32Array, StringArray, TimestampNanosecondArray, + StringArray, TimestampNanosecondArray, }; use arrow::record_batch::RecordBatch; use async_trait::async_trait; use bytes::Bytes; + use datafusion_common::cast::as_int32_array; use datafusion_common::ScalarValue; use futures::stream::BoxStream; use futures::StreamExt; @@ -975,11 +976,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_int32_array(batches[0].column(0))?; let mut values: Vec = vec![]; for i in 0..batches[0].num_rows() { values.push(array.value(i)); diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index 31c5969b8af5..f554b9424f98 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -21,6 +21,7 @@ use datafusion::{ execution::registry::FunctionRegistry, physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function}, }; +use datafusion_common::cast::as_int32_array; use datafusion_expr::{create_udaf, LogicalPlanBuilder}; /// test that casting happens on udfs. @@ -57,14 +58,8 @@ async fn scalar_udf() -> Result<()> { ctx.register_batch("t", batch)?; let myfunc = |args: &[ArrayRef]| { - let l = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let r = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); + let l = as_int32_array(&args[0])?; + let r = as_int32_array(&args[1])?; Ok(Arc::new(add(l, r)?) as ArrayRef) }; let myfunc = make_scalar_function(myfunc); @@ -113,21 +108,9 @@ async fn scalar_udf() -> Result<()> { assert_batches_eq!(expected, &result); let batch = &result[0]; - let a = batch - .column(0) - .as_any() - .downcast_ref::() - .expect("failed to cast a"); - let b = batch - .column(1) - .as_any() - .downcast_ref::() - .expect("failed to cast b"); - let sum = batch - .column(2) - .as_any() - .downcast_ref::() - .expect("failed to cast sum"); + let a = as_int32_array(batch.column(0))?; + let b = as_int32_array(batch.column(1))?; + let sum = as_int32_array(batch.column(2))?; assert_eq!(4, a.len()); assert_eq!(4, b.len()); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 383fb05c5d00..4fe284bee952 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -770,7 +770,7 @@ mod tests { datatypes::{DataType, Field, Schema}, }; use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::{DFField, ToDFSchema}; + use datafusion_common::{cast::as_int32_array, DFField, ToDFSchema}; use datafusion_expr::*; use datafusion_physical_expr::{ execution_props::ExecutionProps, functions::make_scalar_function, @@ -891,14 +891,8 @@ mod tests { let return_type = Arc::new(DataType::Int32); let fun = |args: &[ArrayRef]| { - let arg0 = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let arg1 = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); + let arg0 = as_int32_array(&args[0])?; + let arg1 = as_int32_array(&args[1])?; // 2. perform the computation let array = arg0 diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index e1da3dd2b210..613fb87d1b8d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -395,6 +395,7 @@ mod tests { use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; + use datafusion_common::cast::as_int32_array; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; @@ -417,10 +418,7 @@ mod tests { schema.as_ref(), )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = result - .as_any() - .downcast_ref::() - .expect("failed to downcast to Int32Array"); + let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -448,10 +446,7 @@ mod tests { schema.as_ref(), )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = result - .as_any() - .downcast_ref::() - .expect("failed to downcast to Int32Array"); + let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); @@ -523,10 +518,7 @@ mod tests { schema.as_ref(), )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = result - .as_any() - .downcast_ref::() - .expect("failed to downcast to Int32Array"); + let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -605,10 +597,7 @@ mod tests { schema.as_ref(), )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = result - .as_any() - .downcast_ref::() - .expect("failed to downcast to Int32Array"); + let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 7957aaef240e..c2ee3e11c02f 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -130,6 +130,7 @@ mod tests { use super::*; use arrow::array::Int32Array; use arrow::datatypes::*; + use datafusion_common::cast::as_int32_array; use datafusion_common::Result; #[test] @@ -144,7 +145,7 @@ mod tests { assert_eq!("42", format!("{}", literal_expr)); let literal_array = literal_expr.evaluate(&batch)?.into_array(batch.num_rows()); - let literal_array = literal_array.as_any().downcast_ref::().unwrap(); + let literal_array = as_int32_array(&literal_array)?; // note that the contents of the literal array are unrelated to the batch contents except for the length of the array assert_eq!(literal_array.len(), 5); // 5 rows in the batch diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 772a02097bfd..a1efae79c948 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -224,7 +224,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use datafusion_common::Result; + use datafusion_common::{cast::as_int32_array, Result}; #[test] fn scatter_int() -> Result<()> { @@ -235,7 +235,7 @@ mod tests { let expected = Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); let result = scatter(&mask, truthy.as_ref())?; - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_int32_array(&result)?; assert_eq!(&expected, result); Ok(()) @@ -250,7 +250,7 @@ mod tests { let expected = Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); let result = scatter(&mask, truthy.as_ref())?; - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_int32_array(&result)?; assert_eq!(&expected, result); Ok(()) @@ -266,7 +266,7 @@ mod tests { // output should treat nulls as though they are false let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); let result = scatter(&mask, truthy.as_ref())?; - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_int32_array(&result)?; assert_eq!(&expected, result); Ok(()) diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f7e0226aa826..c50df3c1c9b3 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -182,6 +182,7 @@ mod tests { use crate::expressions::Column; use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + use datafusion_common::cast::as_int32_array; use datafusion_common::Result; fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { @@ -191,7 +192,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?; assert_eq!(1, result.len()); - let result = result[0].as_any().downcast_ref::().unwrap(); + let result = as_int32_array(&result[0])?; assert_eq!(expected, *result); Ok(()) } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 14ce53621bde..e9988032cfee 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -173,6 +173,7 @@ mod tests { use crate::expressions::Column; use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + use datafusion_common::cast::as_int32_array; use datafusion_common::Result; fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> { @@ -194,7 +195,7 @@ mod tests { .into_iter() .collect::>>()?; let result = ScalarValue::iter_to_array(result.into_iter())?; - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) } diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 358b6a919d8e..ab168961e056 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -24,5 +24,6 @@ edition = "2021" [dependencies] arrow = { version = "26.0.0", features = ["prettyprint"] } +datafusion-common = { path = "../datafusion/common", version = "14.0.0" } env_logger = "0.9.0" rand = "0.8" diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 7f0e5ef0770d..5c3b64574a68 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -16,7 +16,8 @@ // under the License. //! Common functions used for testing -use arrow::{array::Int32Array, record_batch::RecordBatch}; +use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_int32_array; use rand::prelude::StdRng; use rand::Rng; @@ -32,12 +33,7 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { .iter() .flat_map(|batch| { assert_eq!(batch.num_columns(), 1); - batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - .iter() + as_int32_array(batch.column(0)).unwrap().iter() }) .collect() }