Skip to content

Commit

Permalink
Crude hack to introduce type coercion for hash join keys
Browse files Browse the repository at this point in the history
Remove this after rebasing on top of commit ac2e5d1 "Support type coercion for equijoin (apache#4666)", which was first released in DataFusion 16.0.
ARROW-11838: fix offset buffer in golden file (#60)
  • Loading branch information
mcheshkov committed Sep 6, 2024
1 parent 400fa0d commit dcf3e4a
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions datafusion/core/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ use crate::arrow::datatypes::TimeUnit;
use crate::execution::context::TaskContext;
use crate::physical_plan::coalesce_batches::concat_batches;
use crate::physical_plan::PhysicalExpr;
use datafusion_expr::binary_rule::coerce_types;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::try_cast;
use log::debug;
use std::fmt;

Expand Down Expand Up @@ -295,7 +298,32 @@ impl ExecutionPlan for HashJoinExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
// This is a hacky way to support type coercion for join expressions
// Without this, execution would panic later in build_join_indexes => equal_rows, which tries to downcast both sides to the same primitive type
// TODO: Remove this after rebasing on top of commit ac2e5d15 "Support type coercion for equijoin (#4666)", which was first released in DataFusion 16.0

// TODO: Rewrite this with iterators once on a modern toolchain; `impl FromIterator<(AE, BE)> for (A, B)` is not available at the moment
let mut on_left = Vec::with_capacity(self.on.len());
let mut on_right = Vec::with_capacity(self.on.len());
for on in &self.on {
let l = Arc::new(on.0.clone());
let r = Arc::new(on.1.clone());

let lt = l.data_type(&self.left.schema())?;
let rt = r.data_type(&self.right.schema())?;
let res_type = coerce_types(&lt, &Operator::Eq, &rt)?;

let left_cast = try_cast(l, &self.left.schema(), res_type.clone())?;
let right_cast = try_cast(r, &self.right.schema(), res_type)?;

on_left.push(left_cast);
on_right.push(right_cast);
}

// Make them immutable
let on_left = on_left;
let on_right = on_right;

// we only want to compute the build side once for PartitionMode::CollectLeft
let left_data = {
match self.mode {
Expand Down Expand Up @@ -414,7 +442,6 @@ impl ExecutionPlan for HashJoinExec {
// over the right that uses this information to issue new batches.

let right_stream = self.right.execute(partition, context.clone()).await?;
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();

let num_rows = left_data.1.num_rows();
let visited_left_side = match self.join_type {
Expand Down Expand Up @@ -473,7 +500,7 @@ impl ExecutionPlan for HashJoinExec {
/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
/// assuming that the [RecordBatch] corresponds to the `index`th
fn update_hash(
on: &[Column],
on: &[Arc<dyn PhysicalExpr>],
batch: &RecordBatch,
hash_map: &mut JoinHashMap,
offset: usize,
Expand Down Expand Up @@ -512,9 +539,9 @@ struct HashJoinStream {
/// Input schema
schema: Arc<Schema>,
/// columns from the left
on_left: Vec<Column>,
on_left: Vec<Arc<dyn PhysicalExpr>>,
/// columns from the right used to compute the hash
on_right: Vec<Column>,
on_right: Vec<Arc<dyn PhysicalExpr>>,
/// type of the join
join_type: JoinType,
/// information from the left
Expand All @@ -539,8 +566,8 @@ struct HashJoinStream {
impl HashJoinStream {
fn new(
schema: Arc<Schema>,
on_left: Vec<Column>,
on_right: Vec<Column>,
on_left: Vec<Arc<dyn PhysicalExpr>>,
on_right: Vec<Arc<dyn PhysicalExpr>>,
join_type: JoinType,
left_data: JoinLeftData,
right: SendableRecordBatchStream,
Expand Down Expand Up @@ -624,8 +651,8 @@ fn build_batch_from_indices(
fn build_batch(
batch: &RecordBatch,
left_data: &JoinLeftData,
on_left: &[Column],
on_right: &[Column],
on_left: &[Arc<dyn PhysicalExpr>],
on_right: &[Arc<dyn PhysicalExpr>],
join_type: JoinType,
schema: &Schema,
column_indices: &[ColumnIndex],
Expand Down Expand Up @@ -691,8 +718,8 @@ fn build_join_indexes(
left_data: &JoinLeftData,
right: &RecordBatch,
join_type: JoinType,
left_on: &[Column],
right_on: &[Column],
left_on: &[Arc<dyn PhysicalExpr>],
right_on: &[Arc<dyn PhysicalExpr>],
random_state: &RandomState,
null_equals_null: &bool,
) -> Result<(UInt64Array, UInt32Array)> {
Expand Down Expand Up @@ -2002,8 +2029,8 @@ mod tests {
&left_data,
&right,
JoinType::Inner,
&[Column::new("a", 0)],
&[Column::new("a", 0)],
&[Arc::new(Column::new("a", 0))],
&[Arc::new(Column::new("a", 0))],
&random_state,
&false,
)?;
Expand Down

0 comments on commit dcf3e4a

Please sign in to comment.