Skip to content

Commit

Permalink
improve error messages while downcasting uint and boolean array (#4261)
Browse files Browse the repository at this point in the history
  • Loading branch information
retikulum authored Nov 20, 2022
1 parent 880e6fc commit 712b9fd
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 130 deletions.
37 changes: 35 additions & 2 deletions datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

use crate::DataFusionError;
use arrow::array::{
Array, Date32Array, Decimal128Array, Float32Array, Float64Array, Int32Array,
Int64Array, StringArray, StructArray,
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array,
};

// Downcast ArrayRef to Date32Array
Expand Down Expand Up @@ -116,3 +116,36 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionErro
))
})
}

// Downcast ArrayRef to UInt32Array
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> {
array.as_any().downcast_ref::<UInt32Array>().ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected a UInt32Array, got: {}",
array.data_type()
))
})
}

// Downcast ArrayRef to UInt64Array
pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> {
array.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected a UInt64Array, got: {}",
array.data_type()
))
})
}

// Downcast ArrayRef to BooleanArray
pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected a BooleanArray, got: {}",
array.data_type()
))
})
}
24 changes: 12 additions & 12 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2657,7 +2657,7 @@ mod tests {
use arrow::compute::kernels;
use arrow::datatypes::ArrowPrimitiveType;

use crate::cast::as_string_array;
use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
use crate::from_slice::FromSlice;

use super::*;
Expand Down Expand Up @@ -2792,35 +2792,37 @@ mod tests {
}

#[test]
fn scalar_value_to_array_u64() {
fn scalar_value_to_array_u64() -> Result<()> {
let value = ScalarValue::UInt64(Some(13u64));
let array = value.to_array();
let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
let array = as_uint64_array(&array)?;
assert_eq!(array.len(), 1);
assert!(!array.is_null(0));
assert_eq!(array.value(0), 13);

let value = ScalarValue::UInt64(None);
let array = value.to_array();
let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
let array = as_uint64_array(&array)?;
assert_eq!(array.len(), 1);
assert!(array.is_null(0));
Ok(())
}

#[test]
fn scalar_value_to_array_u32() {
fn scalar_value_to_array_u32() -> Result<()> {
let value = ScalarValue::UInt32(Some(13u32));
let array = value.to_array();
let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
let array = as_uint32_array(&array)?;
assert_eq!(array.len(), 1);
assert!(!array.is_null(0));
assert_eq!(array.value(0), 13);

let value = ScalarValue::UInt32(None);
let array = value.to_array();
let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
let array = as_uint32_array(&array)?;
assert_eq!(array.len(), 1);
assert!(array.is_null(0));
Ok(())
}

#[test]
Expand All @@ -2838,7 +2840,7 @@ mod tests {
}

#[test]
fn scalar_list_to_array() {
fn scalar_list_to_array() -> Result<()> {
let list_array_ref = ScalarValue::List(
Some(vec![
ScalarValue::UInt64(Some(100)),
Expand All @@ -2854,14 +2856,12 @@ mod tests {
assert_eq!(list_array.values().len(), 3);

let prim_array_ref = list_array.value(0);
let prim_array = prim_array_ref
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
let prim_array = as_uint64_array(&prim_array_ref)?;
assert_eq!(prim_array.len(), 3);
assert_eq!(prim_array.value(0), 100);
assert!(prim_array.is_null(1));
assert_eq!(prim_array.value(2), 101);
Ok(())
}

/// Creates array directly and via ScalarValue and ensures they are the same
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/src/datasource/file_format/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ mod tests {
use crate::datasource::file_format::test_util::scan_format;
use crate::physical_plan::collect;
use crate::prelude::{SessionConfig, SessionContext};
use arrow::array::{BinaryArray, BooleanArray, TimestampMicrosecondArray};
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
use arrow::array::{BinaryArray, TimestampMicrosecondArray};
use datafusion_common::cast::{
as_boolean_array, as_float32_array, as_float64_array, as_int32_array,
};
use futures::StreamExt;

#[tokio::test]
Expand Down Expand Up @@ -197,11 +199,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap();
let array = as_boolean_array(batches[0].column(0))?;
let mut values: Vec<bool> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,12 +586,14 @@ mod tests {
use crate::physical_plan::metrics::MetricValue;
use crate::prelude::{SessionConfig, SessionContext};
use arrow::array::{
Array, ArrayRef, BinaryArray, BooleanArray, StringArray, TimestampNanosecondArray,
Array, ArrayRef, BinaryArray, StringArray, TimestampNanosecondArray,
};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use bytes::Bytes;
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
use datafusion_common::cast::{
as_boolean_array, as_float32_array, as_float64_array, as_int32_array,
};
use datafusion_common::ScalarValue;
use futures::stream::BoxStream;
use futures::StreamExt;
Expand Down Expand Up @@ -945,11 +947,7 @@ mod tests {
assert_eq!(1, batches[0].num_columns());
assert_eq!(8, batches[0].num_rows());

let array = batches[0]
.column(0)
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap();
let array = as_boolean_array(batches[0].column(0))?;
let mut values: Vec<bool> = vec![];
for i in 0..batches[0].num_rows() {
values.push(array.value(i));
Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use arrow::{
array::{
Array, ArrayBuilder, ArrayRef, Date64Array, Date64Builder, StringBuilder,
UInt64Array, UInt64Builder,
UInt64Builder,
},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
Expand All @@ -38,7 +38,10 @@ use crate::{

use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use datafusion_common::{cast::as_string_array, Column, DataFusionError};
use datafusion_common::{
cast::{as_string_array, as_uint64_array},
Column, DataFusionError,
};
use datafusion_expr::{
expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
Expr, Volatility,
Expand Down Expand Up @@ -300,11 +303,7 @@ fn batches_to_paths(batches: &[RecordBatch]) -> Result<Vec<PartitionedFile>> {
.iter()
.flat_map(|batch| {
let key_array = as_string_array(batch.column(0)).unwrap();
let length_array = batch
.column(1)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
let length_array = as_uint64_array(batch.column(1)).unwrap();
let modified_array = batch
.column(2)
.as_any()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow::array::{Array, BooleanArray};
use arrow::datatypes::{DataType, Schema};
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use std::collections::BTreeSet;
Expand Down Expand Up @@ -134,17 +135,12 @@ impl ArrowPredicate for DatafusionArrowPredicate {
.map(|v| v.into_array(batch.num_rows()))
{
Ok(array) => {
if let Some(mask) = array.as_any().downcast_ref::<BooleanArray>() {
let bool_arr = BooleanArray::from(mask.data().clone());
let num_filtered = bool_arr.len() - bool_arr.true_count();
self.rows_filtered.add(num_filtered);
timer.stop();
Ok(bool_arr)
} else {
Err(ArrowError::ComputeError(
"Unexpected result of predicate evaluation, expected BooleanArray".to_owned(),
))
}
let mask = as_boolean_array(&array)?;
let bool_arr = BooleanArray::from(mask.data().clone());
let num_filtered = bool_arr.len() - bool_arr.true_count();
self.rows_filtered.add(num_filtered);
timer.stop();
Ok(bool_arr)
}
Err(e) => Err(ArrowError::ComputeError(format!(
"Error evaluating filter predicate: {:?}",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use std::{time::Instant, vec};

use futures::{ready, Stream, StreamExt, TryStreamExt};

use arrow::array::{as_boolean_array, new_null_array, Array};
use arrow::array::{new_null_array, Array};
use arrow::datatypes::{ArrowNativeType, DataType};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
Expand All @@ -52,7 +52,7 @@ use arrow::array::{
UInt8Array,
};

use datafusion_common::cast::as_string_array;
use datafusion_common::cast::{as_boolean_array, as_string_array};

use hashbrown::raw::RawTable;

Expand Down Expand Up @@ -1027,7 +1027,7 @@ fn apply_join_filter(
.expression()
.evaluate(&intermediate_batch)?
.into_array(intermediate_batch.num_rows());
let mask = as_boolean_array(&filter_result);
let mask = as_boolean_array(&filter_result)?;

let left_filtered = PrimitiveArray::<UInt64Type>::from(
compute::filter(&left_indices, mask)?.data().clone(),
Expand All @@ -1050,7 +1050,7 @@ fn apply_join_filter(
.expression()
.evaluate_selection(&intermediate_batch, &has_match)?
.into_array(intermediate_batch.num_rows());
let mask = as_boolean_array(&filter_result);
let mask = as_boolean_array(&filter_result)?;

let mut left_rebuilt = UInt64Builder::with_capacity(0);
let mut right_rebuilt = UInt32Builder::with_capacity(0);
Expand Down
28 changes: 9 additions & 19 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ use arrow::record_batch::RecordBatch;

use crate::physical_expr::down_cast_any_ref;
use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr};
use datafusion_common::cast::as_decimal128_array;
use datafusion_common::cast::{as_boolean_array, as_decimal128_array};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::type_coercion::binary::binary_operator_data_type;
Expand Down Expand Up @@ -472,14 +472,8 @@ macro_rules! binary_array_op {
/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
let ll = as_boolean_array($LEFT).expect("boolean_op failed to downcast array");
let rr = as_boolean_array($RIGHT).expect("boolean_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}
Expand Down Expand Up @@ -1003,7 +997,7 @@ impl BinaryExpr {
Operator::Modulo => binary_primitive_array_op!(left, right, modulus),
Operator::And => {
if left_data_type == &DataType::Boolean {
boolean_op!(left, right, and_kleene)
boolean_op!(&left, &right, and_kleene)
} else {
Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
Expand All @@ -1015,7 +1009,7 @@ impl BinaryExpr {
}
Operator::Or => {
if left_data_type == &DataType::Boolean {
boolean_op!(left, right, or_kleene)
boolean_op!(&left, &right, or_kleene)
} else {
Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
Expand Down Expand Up @@ -1110,10 +1104,8 @@ mod tests {
assert_eq!(result.len(), 5);

let expected = vec![false, false, true, true, true];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
let result =
as_boolean_array(&result).expect("failed to downcast to BooleanArray");
for (i, &expected_item) in expected.iter().enumerate().take(5) {
assert_eq!(result.value(i), expected_item);
}
Expand Down Expand Up @@ -1156,10 +1148,8 @@ mod tests {
assert_eq!(result.len(), 5);

let expected = vec![true, true, false, true, false];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
let result =
as_boolean_array(&result).expect("failed to downcast to BooleanArray");
for (i, &expected_item) in expected.iter().enumerate().take(5) {
assert_eq!(result.value(i), expected_item);
}
Expand Down
7 changes: 2 additions & 5 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result};
use datafusion_expr::ColumnarValue;

use itertools::Itertools;
Expand Down Expand Up @@ -195,10 +195,7 @@ impl CaseExpr {
_ => when_value,
};
let when_value = when_value.into_array(batch.num_rows());
let when_value = when_value
.as_ref()
.as_any()
.downcast_ref::<BooleanArray>()
let when_value = as_boolean_array(&when_value)
.expect("WHEN expression did not return a BooleanArray");

let then_value = self.when_then_expr[i]
Expand Down
Loading

0 comments on commit 712b9fd

Please sign in to comment.