Skip to content

Commit

Permalink
feat: support LargeList for array_has, array_has_all and `array_has_any` (apache#8322)
Browse files Browse the repository at this point in the history

* support LargeList for array_has, array_has_all and array_has_any

* simplify the code

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
2 people authored and appletreeisyellow committed Dec 15, 2023
1 parent 11c5bb8 commit 5c0c619
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 53 deletions.
143 changes: 90 additions & 53 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1765,82 +1765,119 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Array_has SQL function
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let array = as_list_array(&args[0])?;
let element = &args[1];
/// Selects which containment check `general_array_has_dispatch` performs.
///
/// This is a fieldless tag, so it derives `Copy` and `Eq`: it can be passed by
/// value and compared with `==` / `!=` any number of times without being moved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonType {
    /// `array_has_all`: every element of the second list occurs in the first.
    All,
    /// `array_has_any`: at least one element of the second list occurs in the first.
    Any,
    /// `array_has`: the list contains the given element.
    Single,
}

fn general_array_has_dispatch<O: OffsetSizeTrait>(
array: &ArrayRef,
sub_array: &ArrayRef,
comparison_type: ComparisonType,
) -> Result<ArrayRef> {
let array = if comparison_type == ComparisonType::Single {
let arr = as_generic_list_array::<O>(array)?;
check_datatypes("array_has", &[arr.values(), sub_array])?;
arr
} else {
check_datatypes("array_has", &[array, sub_array])?;
as_generic_list_array::<O>(array)?
};

check_datatypes("array_has", &[array.values(), element])?;
let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let r_values = converter.convert_columns(&[element.clone()])?;
for (row_idx, arr) in array.iter().enumerate() {
if let Some(arr) = arr {

let element = sub_array.clone();
let sub_array = if comparison_type != ComparisonType::Single {
as_generic_list_array::<O>(sub_array)?
} else {
array
};

for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let res = arr_values
.iter()
.dedup()
.any(|x| x == r_values.row(row_idx));
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
}

boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_any SQL function
pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_has_any", &[&args[0], &args[1]])?;
/// Array_has SQL function
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let array_type = args[0].data_type();

let array = as_list_array(&args[0])?;
let sub_array = as_list_array(&args[1])?;
let mut boolean_builder = BooleanArray::builder(array.len());
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Single)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Single)
}
_ => internal_err!("array_has does not support type '{array_type:?}'."),
}
}

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;
/// Array_has_any SQL function
pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
let array_type = args[0].data_type();

let mut res = false;
for elem in sub_arr_values.iter().dedup() {
res |= arr_values.iter().dedup().any(|x| x == elem);
if res {
break;
}
}
boolean_builder.append_value(res);
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
}
_ => internal_err!("array_has_any does not support type '{array_type:?}'."),
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_all SQL function
pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_has_all", &[&args[0], &args[1]])?;

let array = as_list_array(&args[0])?;
let sub_array = as_list_array(&args[1])?;

let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;
let array_type = args[0].data_type();

let mut res = true;
for elem in sub_arr_values.iter().dedup() {
res &= arr_values.iter().dedup().any(|x| x == elem);
if !res {
break;
}
}
boolean_builder.append_value(res);
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
}
_ => internal_err!("array_has_all does not support type '{array_type:?}'."),
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Splits string at occurrences of delimiter and returns an array of parts
Expand Down
111 changes: 111 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,23 @@ select array_has(make_array(1,2), 1),
----
true true true true true false true false true false true false

query BBBBBBBBBBBB
select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1),
array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1),
array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])),
array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])),
array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])),
array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])),
list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4),
array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3),
list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0)
;
----
true true true true true false true false true false true false

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2630,6 +2647,15 @@ from array_has_table_1D;
true true true
false false false

query BBB
select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')),
array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)'))
from array_has_table_1D;
----
true true true
false false false

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2639,6 +2665,15 @@ from array_has_table_1D_Float;
true true false
false false true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')),
array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)'))
from array_has_table_1D_Float;
----
true true false
false false true

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2648,6 +2683,15 @@ from array_has_table_1D_Boolean;
false true true
true true true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')),
array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)'))
from array_has_table_1D_Boolean;
----
false true true
true true true

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2657,6 +2701,15 @@ from array_has_table_1D_UTF8;
true true false
false false true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')),
array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)'))
from array_has_table_1D_UTF8;
----
true true false
false false true

query BB
select array_has(column1, column2),
array_has_all(column3, column4)
Expand All @@ -2665,13 +2718,28 @@ from array_has_table_2D;
false true
true false

query BB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2),
array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))'))
from array_has_table_2D;
----
false true
true false

query B
select array_has_all(column1, column2)
from array_has_table_2D_float;
----
true
false

query B
select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))'))
from array_has_table_2D_float;
----
true
false

query B
select array_has(column1, column2) from array_has_table_3D;
----
Expand All @@ -2683,6 +2751,17 @@ true
false
true

query B
select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D;
----
false
true
false
false
true
false
true

query BBBB
select array_has(column1, make_array(5, 6)),
array_has(column1, make_array(7, NULL)),
Expand All @@ -2697,6 +2776,20 @@ false true false false
false false false false
false false false false

query BBBB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)),
array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)),
array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5),
array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o')
from arrays;
----
false false false true
true false true false
true false false true
false true false false
false false false false
false false false false

query BBBBBBBBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
array_has_all(make_array(1,2,3), make_array(1,4)),
Expand All @@ -2715,6 +2808,24 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
----
true false true false false false true true false false true false true

query BBBBBBBBBBBBB
select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))'))
;
----
true false true false false false true true false false true false true

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
Expand Down

0 comments on commit 5c0c619

Please sign in to comment.