From 5c0c619a398f39c95495078508db44f43dc1eb62 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 5 Dec 2023 21:43:41 +0100 Subject: [PATCH] feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` (#8322) * support LargeList for array_has, array_has_all and array_has_any * simplify the code --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/array_expressions.rs | 143 +++++++++++------- datafusion/sqllogictest/test_files/array.slt | 111 ++++++++++++++ 2 files changed, 201 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9489a51fa385..6104566450c3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1765,82 +1765,119 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } } -/// Array_has SQL function -pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; - let element = &args[1]; +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, +} + +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; - check_datatypes("array_has", &[array.values(), element])?; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let r_values = converter.convert_columns(&[element.clone()])?; - for (row_idx, arr) in array.iter().enumerate() { - if let Some(arr) = arr { + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { let arr_values = converter.convert_columns(&[arr])?; - let res = arr_values - .iter() - .dedup() - .any(|x| x == r_values.row(row_idx)); + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -/// Array_has_any SQL function -pub fn array_has_any(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => internal_err!("array_has does not support type '{array_type:?}'."), + } +} - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - let mut res = false; - for elem in sub_arr_values.iter().dedup() { - res |= arr_values.iter().dedup().any(|x| x == elem); - if res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let array_type = args[0].data_type(); - let mut res = true; - for elem in sub_arr_values.iter().dedup() { - res &= arr_values.iter().dedup().any(|x| x == elem); - if !res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6ec2b2cb013b..d8bf441d7169 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2621,6 +2621,23 @@ select array_has(make_array(1,2), 1), ---- true true true true true false true false true false true false +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2630,6 +2647,15 @@ from array_has_table_1D; true true true false false false +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2639,6 +2665,15 @@ from array_has_table_1D_Float; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2648,6 +2683,15 @@ from array_has_table_1D_Boolean; false true true true true true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2657,6 +2701,15 @@ from array_has_table_1D_UTF8; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + query BB select array_has(column1, column2), array_has_all(column3, column4) @@ -2665,6 +2718,14 @@ from array_has_table_2D; false true true false +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + query B select array_has_all(column1, column2) from array_has_table_2D_float; @@ -2672,6 +2733,13 @@ from array_has_table_2D_float; true false +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + query B select array_has(column1, column2) from array_has_table_3D; ---- @@ -2683,6 +2751,17 @@ true false true +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -2697,6 +2776,20 @@ false true false false false false false false false false false false +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), array_has_all(make_array(1,2,3), make_array(1,4)), @@ -2715,6 +2808,24 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + query ??? select array_intersect(column1, column2), array_intersect(column3, column4),