Skip to content

Commit

Permalink
Rewrite array_ndims to fix List(Null) handling (#8320)
Browse files Browse the repository at this point in the history
* done

Signed-off-by: jayzhan211 <[email protected]>

* add more test

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jayzhan211 and alamb authored Dec 1, 2023
1 parent eb8aff7 commit f5d10e5
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 53 deletions.
32 changes: 32 additions & 0 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, LargeListArray, ListArray};
use arrow_schema::DataType;
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -402,6 +403,37 @@ pub fn arrays_into_list_array(
))
}

/// Get the base type of a data type.
///
/// Example
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use datafusion_common::utils::base_type;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// assert_eq!(base_type(&data_type), DataType::Int32);
///
/// let data_type = DataType::Int32;
/// assert_eq!(base_type(&data_type), DataType::Int32);
/// ```
pub fn base_type(data_type: &DataType) -> DataType {
if let DataType::List(field) = data_type {
base_type(field.data_type())
} else {
data_type.to_owned()
}
}

/// Compute the number of dimensions in a list data type.
pub fn list_ndims(data_type: &DataType) -> u64 {
if let DataType::List(field) = data_type {
1 + list_ndims(field.data_type())
} else {
0
}
}

/// An extension trait for smart pointers. Provides an interface to get a
/// raw pointer to the data (with metadata stripped away).
///
Expand Down
76 changes: 27 additions & 49 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_common::cast::{
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
as_null_array, as_string_array,
};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
DataFusionError, Result,
Expand Down Expand Up @@ -103,6 +103,7 @@ fn compare_element_to_list(
) -> Result<BooleanArray> {
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;

// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let res = match element_array_row.data_type() {
Expand Down Expand Up @@ -176,35 +177,6 @@ fn compute_array_length(
}
}

/// Returns the dimension of the array
fn compute_array_ndims(arr: Option<ArrayRef>) -> Result<Option<u64>> {
Ok(compute_array_ndims_with_datatype(arr)?.0)
}

/// Returns the dimension and the datatype of elements of the array
fn compute_array_ndims_with_datatype(
arr: Option<ArrayRef>,
) -> Result<(Option<u64>, DataType)> {
let mut res: u64 = 1;
let mut value = match arr {
Some(arr) => arr,
None => return Ok((None, DataType::Null)),
};
if value.is_empty() {
return Ok((None, DataType::Null));
}

loop {
match value.data_type() {
DataType::List(..) => {
value = downcast_arg!(value, ListArray).value(0);
res += 1;
}
data_type => return Ok((Some(res), data_type.clone())),
}
}
}

/// Returns the length of each array dimension
fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
let mut value = match arr {
Expand Down Expand Up @@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
let args_ndim = args
.iter()
.map(|arg| compute_array_ndims(Some(arg.to_owned())))
.collect::<Result<Vec<_>>>()?
.into_iter()
.map(|x| x.unwrap_or(0))
.map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
.collect::<Vec<_>>();
let max_ndim = args_ndim.iter().max().unwrap_or(&0);

Expand Down Expand Up @@ -919,18 +888,19 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);

Ok(Arc::new(list_arr))
}

/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut new_args = vec![];
for arg in args {
let (ndim, lower_data_type) =
compute_array_ndims_with_datatype(Some(arg.clone()))?;
if ndim.is_none() || ndim == Some(1) {
return not_impl_err!("Array is not type '{lower_data_type:?}'.");
} else if !lower_data_type.equals_datatype(&DataType::Null) {
let ndim = list_ndims(arg.data_type());
let base_type = datafusion_common::utils::base_type(arg.data_type());
if ndim == 0 {
return not_impl_err!("Array is not type '{base_type:?}'.");
} else if !base_type.eq(&DataType::Null) {
new_args.push(arg.clone());
}
}
Expand Down Expand Up @@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_ndims SQL function
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
if let Some(list_array) = args[0].as_list_opt::<i32>() {
let ndims = datafusion_common::utils::list_ndims(list_array.data_type());

let result = list_array
.iter()
.map(compute_array_ndims)
.collect::<Result<UInt64Array>>()?;
let mut data = vec![];
for arr in list_array.iter() {
if arr.is_some() {
data.push(Some(ndims))
} else {
data.push(None)
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
} else {
Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef)
}
}

/// Array_has SQL function
Expand Down Expand Up @@ -2034,10 +2012,10 @@ mod tests {
.unwrap();

let expected = as_list_array(&array2d_1).unwrap();
let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
compute_array_ndims(Some(res[0].clone())).unwrap(),
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);

Expand All @@ -2047,10 +2025,10 @@ mod tests {
align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap();

let expected = as_list_array(&array3d_1).unwrap();
let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
compute_array_ndims(Some(res[0].clone())).unwrap(),
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);
}
Expand Down
42 changes: 38 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2479,10 +2479,44 @@ NULL [3] [4]
## array_ndims (aliases: `list_ndims`)

# array_ndims scalar function #1

query III
select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]]));
select
array_ndims(1),
array_ndims(null),
array_ndims([2, 3]);
----
1 2 5
0 0 1

statement ok
CREATE TABLE array_ndims_table
AS VALUES
(1, [1, 2, 3], [[7]], [[[[[10]]]]]),
(2, [4, 5], [[8]], [[[[[10]]]]]),
(null, [6], [[9]], [[[[[10]]]]]),
(3, [6], [[9]], [[[[[10]]]]])
;

query IIII
select
array_ndims(column1),
array_ndims(column2),
array_ndims(column3),
array_ndims(column4)
from array_ndims_table;
----
0 1 2 5
0 1 2 5
0 1 2 5
0 1 2 5

statement ok
drop table array_ndims_table;

query I
select array_ndims(arrow_cast([null], 'List(List(List(Int64)))'));
----
3

# array_ndims scalar function #2
query II
Expand All @@ -2494,7 +2528,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_
query II
select array_ndims(make_array()), array_ndims(make_array(make_array()))
----
NULL 2
1 2

# list_ndims scalar function #4 (function alias `array_ndims`)
query III
Expand All @@ -2505,7 +2539,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])),
query II
select array_ndims(make_array()), array_ndims(make_array(make_array()))
----
NULL 2
1 2

# array_ndims with columns
query III
Expand Down

0 comments on commit f5d10e5

Please sign in to comment.