diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e3b8db676c98..34e007207427 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::List(_) => true, DataType::LargeList(_) => true, DataType::FixedSizeList(_, _) => true, + DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), _ => false, } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index b2f9ef560745..c6ef9936b9c5 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1212,11 +1212,16 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result<BooleanArray> { - // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special // implementation // - if left.data_type().is_nested() && null_equals_null { - return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?); + if left.data_type().is_nested() { + let op = if null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + return Ok(compare_op_for_nested(&op, &left, &right)?); } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), @@ -1546,6 +1551,8 @@ mod tests { use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field}; + use arrow_array::StructArray; + use arrow_buffer::NullBuffer; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, ScalarValue, @@ -3844,6 +3851,104 @@ mod tests { Ok(()) } + fn build_table_struct( + struct_name: &str, + field_name_and_values: (&str, &Vec<Option<i32>>), + nulls: Option<NullBuffer>, + ) -> Arc<dyn ExecutionPlan> { + let (field_name, values) = field_name_and_values; + let inner_fields = 
vec![Field::new(field_name, DataType::Int32, true)]; + let schema = Schema::new(vec![Field::new( + struct_name, + DataType::Struct(inner_fields.clone().into()), + nulls.is_some(), + )]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StructArray::new( + inner_fields.into(), + vec![Arc::new(Int32Array::from(values.clone()))], + nulls, + ))], + ) + .unwrap(); + let schema_ref = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap()) + } + + #[tokio::test] + async fn join_on_struct() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None); + let right = + build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["n1", "n2"]); + + let expected = [ + "+--------+--------+", + "| n1 | n2 |", + "+--------+--------+", + "| {a: } | {a: } |", + "| {a: 1} | {a: 1} |", + "| {a: 2} | {a: 2} |", + "+--------+--------+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_on_struct_with_nulls() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let right = + build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) 
as _, + )]; + + let (_, batches_null_eq) = join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + true, + task_ctx.clone(), + ) + .await?; + + let expected_null_eq = [ + "+----+----+", + "| n1 | n2 |", + "+----+----+", + "| | |", + "+----+----+", + ]; + assert_batches_eq!(expected_null_eq, &batches_null_eq); + + let (_, batches_null_neq) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + let expected_null_neq = + ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; + assert_batches_eq!(expected_null_neq, &batches_null_neq); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec<String> { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 3cbeea0f9222..593de07f7d26 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -53,6 +53,20 @@ AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE join_t3(s3 struct<id int>) + AS VALUES + (NULL), + (struct(1)), + (struct(2)); + +statement ok +CREATE TABLE join_t4(s4 struct<id int>) + AS VALUES + (NULL), + (struct(2)), + (struct(3)); + # Left semi anti join statement ok @@ -1336,6 +1350,44 @@ physical_plan 10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 11)------------MemoryExec: partitions=1, partition_sizes=[1] +# Join on struct +query TT +explain select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +logical_plan +01)Inner Join: join_t3.s3 = join_t4.s4 +02)--TableScan: join_t3 projection=[s3] +03)--TableScan: join_t4 projection=[s4] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([s3@0], 2), 
input_partitions=2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2 +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ?? +select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +{id: 2} {id: 2} + +# join with struct key and nulls +# Note that intersect or except applies `null_equals_null` as true for Join. +query ? +SELECT * FROM join_t3 +EXCEPT +SELECT * FROM join_t4 +---- +{id: 1} + query TT EXPLAIN select count(*)