Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ScalarValue handling of NULL values for ListArray #7969

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 100 additions & 25 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,10 +1312,11 @@ impl ScalarValue {
Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
scalars.into_iter().map(|x| match x {
ScalarValue::List(arr) => {
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);
if list_arr.is_null(0) {
None
} else {
let list_arr = as_list_array(&arr);
let primitive_arr =
list_arr.values().as_primitive::<$ARRAY_TY>();
Some(
Expand All @@ -1339,12 +1340,14 @@ impl ScalarValue {
for scalar in scalars.into_iter() {
match scalar {
ScalarValue::List(arr) => {
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);

if list_arr.is_null(0) {
builder.append(false);
continue;
}

let list_arr = as_list_array(&arr);
let string_arr = $STRING_ARRAY(list_arr.values());

for v in string_arr.iter() {
Expand Down Expand Up @@ -1699,15 +1702,16 @@ impl ScalarValue {

for scalar in scalars {
if let ScalarValue::List(arr) = scalar {
// i.e. NullArray(1)
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);

if list_arr.is_null(0) {
// Repeat previous offset index
offsets.push(0);

// Element is null
valid.append(false);
} else {
let list_arr = as_list_array(&arr);
let arr = list_arr.values().to_owned();
offsets.push(arr.len());
elements.push(arr);
Expand Down Expand Up @@ -2234,28 +2238,20 @@ impl ScalarValue {
}
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
DataType::List(nested_type) => {
DataType::List(_) => {
let list_array = as_list_array(array);
let arr = match list_array.is_null(index) {
true => new_null_array(nested_type.data_type(), 0),
Copy link
Member Author

@viirya viirya Oct 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition is redundant and not reached. try_from_array handles NULL value at the beginning.

false => {
let nested_array = list_array.value(index);
Arc::new(wrap_into_list_array(nested_array))
}
};
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
}
// TODO: There is no test for FixedSizeList now, add it later
DataType::FixedSizeList(nested_type, _len) => {
Comment on lines 2249 to -2250
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add some tests for FixedSizeList. I can do it in other follow up.

DataType::FixedSizeList(_, _) => {
let list_array = as_fixed_size_list_array(array)?;
let arr = match list_array.is_null(index) {
true => new_null_array(nested_type.data_type(), 0),
false => {
let nested_array = list_array.value(index);
Arc::new(wrap_into_list_array(nested_array))
}
};
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
}
Expand Down Expand Up @@ -2944,8 +2940,15 @@ impl TryFrom<&DataType> for ScalarValue {
index_type.clone(),
Box::new(value_type.as_ref().try_into()?),
),
DataType::List(_) => ScalarValue::List(new_null_array(&DataType::Null, 0)),

// `ScalaValue::List` contains single element `ListArray`.
DataType::List(field) => ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
))),
1,
)),
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
DataType::Null => ScalarValue::Null,
_ => {
Expand Down Expand Up @@ -3885,6 +3888,78 @@ mod tests {
);
}

#[test]
fn scalar_try_from_array_list_array_null() {
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
]);

let non_null_list_scalar = ScalarValue::try_from_array(&list, 0).unwrap();
Copy link
Member Author

@viirya viirya Oct 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, for a NULL value from ListArray, the generated ScalarValue::List value's datatype is DataType::Null. That's is not correct.

let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap();

let data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));

assert_eq!(non_null_list_scalar.data_type(), data_type.clone());
assert_eq!(null_list_scalar.data_type(), data_type);
}

#[test]
fn scalar_try_from_list() {
let data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let data_type = &data_type;
let scalar: ScalarValue = data_type.try_into().unwrap();

let expected = ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
1,
));

assert_eq!(expected, scalar)
}

#[test]
fn scalar_try_from_list_of_list() {
let data_type = DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
)));
let data_type = &data_type;
let scalar: ScalarValue = data_type.try_into().unwrap();

let expected = ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
))),
1,
));

assert_eq!(expected, scalar)
}

#[test]
fn scalar_try_from_not_equal_list_nested_list() {
let list_data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let data_type = &list_data_type;
let list_scalar: ScalarValue = data_type.try_into().unwrap();

let nested_list_data_type = DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
)));
let data_type = &nested_list_data_type;
let nested_list_scalar: ScalarValue = data_type.try_into().unwrap();

assert_ne!(list_scalar, nested_list_scalar);
Comment on lines +3946 to +3960
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously the ScalarValues for list and nested list NULL values are equal, that is incorrect.

}

#[test]
fn scalar_try_from_dict_datatype() {
let data_type =
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10)
;

query TTT
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays;
----
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

# arrays table
query ???
select column1, column2, column3 from arrays;
Expand Down