Fix hash join for nested types (apache#11232)
* Fixes for issue 10749 and generalization

* Add e2e tests for joins on struct

* PR comments

* Add Struct support to the can_hash function

* Add explain query as well

* Use EXCEPT to trigger failure

* Update datafusion/sqllogictest/test_files/joins.slt

Co-authored-by: Liang-Chi Hsieh <[email protected]>

---------

Co-authored-by: Liang-Chi Hsieh <[email protected]>
2 people authored and comphead committed Jul 8, 2024
1 parent 03208dd commit df72337
Showing 3 changed files with 161 additions and 3 deletions.
1 change: 1 addition & 0 deletions datafusion/expr/src/utils.rs
@@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool {
DataType::List(_) => true,
DataType::LargeList(_) => true,
DataType::FixedSizeList(_, _) => true,
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
_ => false,
}
}
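
With this arm in place, a struct column is hashable whenever all of its fields are hashable, so it can serve as a hash-join key. A minimal usage sketch, not part of the patch, assuming `can_hash` is reachable at `datafusion_expr::utils::can_hash` and using the arrow schema types the function already operates on:

use arrow::datatypes::{DataType, Field, Fields};
use datafusion_expr::utils::can_hash;

fn main() {
    // A struct whose fields are all hashable types is itself hashable.
    let inner: Fields = vec![Field::new("id", DataType::Int32, true)].into();
    assert!(can_hash(&DataType::Struct(inner.clone())));

    // The check recurses, so a struct nested inside another struct also works.
    let outer: Fields = vec![Field::new("inner", DataType::Struct(inner), true)].into();
    assert!(can_hash(&DataType::Struct(outer)));
}
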
111 changes: 108 additions & 3 deletions datafusion/physical-plan/src/joins/hash_join.rs
@@ -1212,11 +1212,16 @@ fn eq_dyn_null(
right: &dyn Array,
null_equals_null: bool,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct function and must use a special
// Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() && null_equals_null {
return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?);
if left.data_type().is_nested() {
let op = if null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};
return Ok(compare_op_for_nested(&op, &left, &right)?);
}
match (left.data_type(), right.data_type()) {
_ if null_equals_null => not_distinct(&left, &right),
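
For nested types, the chosen operator mirrors what the flat-type branch above already does with arrow's comparison kernels: `Eq` follows SQL equality (NULL = NULL is NULL), while `IsNotDistinctFrom` treats two NULLs as equal. A small illustrative sketch, outside the patch, using `arrow_ord::cmp::{eq, not_distinct}` on flat arrays (the same `not_distinct` kernel the flat-type branch calls):

use arrow_array::{Array, Int32Array};
use arrow_ord::cmp::{eq, not_distinct};

fn main() {
    // Two columns that agree on the first row and are both NULL on the second.
    let left = Int32Array::from(vec![Some(1), None]);
    let right = Int32Array::from(vec![Some(1), None]);

    // SQL equality propagates NULLs: 1 = 1 is true, NULL = NULL is NULL.
    let equal = eq(&left, &right).unwrap();
    assert!(equal.value(0));
    assert!(equal.is_null(1));

    // IS NOT DISTINCT FROM treats the two NULLs as equal, which is the
    // behaviour requested by `null_equals_null = true`.
    let same = not_distinct(&left, &right).unwrap();
    assert!(same.value(0));
    assert!(same.value(1));
}
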
@@ -1546,6 +1551,8 @@ mod tests {

use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder};
use arrow::datatypes::{DataType, Field};
use arrow_array::StructArray;
use arrow_buffer::NullBuffer;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
ScalarValue,
Expand Down Expand Up @@ -3844,6 +3851,104 @@ mod tests {
Ok(())
}

fn build_table_struct(
struct_name: &str,
field_name_and_values: (&str, &Vec<Option<i32>>),
nulls: Option<NullBuffer>,
) -> Arc<dyn ExecutionPlan> {
let (field_name, values) = field_name_and_values;
let inner_fields = vec![Field::new(field_name, DataType::Int32, true)];
let schema = Schema::new(vec![Field::new(
struct_name,
DataType::Struct(inner_fields.clone().into()),
nulls.is_some(),
)]);

let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(StructArray::new(
inner_fields.into(),
vec![Arc::new(Int32Array::from(values.clone()))],
nulls,
))],
)
.unwrap();
let schema_ref = batch.schema();
Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap())
}

#[tokio::test]
async fn join_on_struct() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None);
let right =
build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None);
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (columns, batches) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

assert_eq!(columns, vec!["n1", "n2"]);

let expected = [
"+--------+--------+",
"| n1 | n2 |",
"+--------+--------+",
"| {a: } | {a: } |",
"| {a: 1} | {a: 1} |",
"| {a: 2} | {a: 2} |",
"+--------+--------+",
];
assert_batches_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_on_struct_with_nulls() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let right =
build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (_, batches_null_eq) = join_collect(
left.clone(),
right.clone(),
on.clone(),
&JoinType::Inner,
true,
task_ctx.clone(),
)
.await?;

let expected_null_eq = [
"+----+----+",
"| n1 | n2 |",
"+----+----+",
"| | |",
"+----+----+",
];
assert_batches_eq!(expected_null_eq, &batches_null_eq);

let (_, batches_null_neq) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

let expected_null_neq =
["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
assert_batches_eq!(expected_null_neq, &batches_null_neq);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
52 changes: 52 additions & 0 deletions datafusion/sqllogictest/test_files/joins.slt
@@ -53,6 +53,20 @@
(44, 'x', 3),
(55, 'w', 3);

statement ok
CREATE TABLE join_t3(s3 struct<id INT>)
AS VALUES
(NULL),
(struct(1)),
(struct(2));

statement ok
CREATE TABLE join_t4(s4 struct<id INT>)
AS VALUES
(NULL),
(struct(2)),
(struct(3));

# Left semi anti join

statement ok
@@ -1336,6 +1350,44 @@ physical_plan
10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)------------MemoryExec: partitions=1, partition_sizes=[1]

# Join on struct
query TT
explain select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
logical_plan
01)Inner Join: join_t3.s3 = join_t4.s4
02)--TableScan: join_t3 projection=[s3]
03)--TableScan: join_t4 projection=[s4]
physical_plan
01)CoalesceBatchesExec: target_batch_size=2
02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)]
03)----CoalesceBatchesExec: target_batch_size=2
04)------RepartitionExec: partitioning=Hash([s3@0], 2), input_partitions=2
05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
06)----------MemoryExec: partitions=1, partition_sizes=[1]
07)----CoalesceBatchesExec: target_batch_size=2
08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2
09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
10)----------MemoryExec: partitions=1, partition_sizes=[1]

query ??
select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
{id: 2} {id: 2}

# Join with struct key and nulls
# Note that INTERSECT and EXCEPT plan their joins with `null_equals_null` set to true.
query ?
SELECT * FROM join_t3
EXCEPT
SELECT * FROM join_t4
----
{id: 1}

query TT
EXPLAIN
select count(*)
