diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index ccb91c34247f..815eef6c8c4b 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -76,6 +76,9 @@
 use crate::arrow::datatypes::TimeUnit;
 use crate::execution::context::TaskContext;
 use crate::physical_plan::coalesce_batches::concat_batches;
 use crate::physical_plan::PhysicalExpr;
+use datafusion_expr::binary_rule::coerce_types;
+use datafusion_expr::Operator;
+use datafusion_physical_expr::expressions::try_cast;
 use log::debug;
 use std::fmt;
@@ -295,7 +298,32 @@ impl ExecutionPlan for HashJoinExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
+        // This is a hacky way to support type coercion for join expressions.
+        // Without it, execution would panic later, in build_join_indexes => equal_rows, when it tries to downcast both sides to the same primitive type.
+        // TODO Remove this after rebasing on top of commit ac2e5d15 "Support type coercion for equijoin (#4666)". It was first released in DataFusion 16.0.
+
+        // TODO Rewrite this with iterators on a modern toolchain; `impl FromIterator<(AE, BE)> for (A, B)` is not available at the moment.
+        let mut on_left = Vec::with_capacity(self.on.len());
+        let mut on_right = Vec::with_capacity(self.on.len());
+        for on in &self.on {
+            let l = Arc::new(on.0.clone());
+            let r = Arc::new(on.1.clone());
+
+            let lt = l.data_type(&self.left.schema())?;
+            let rt = r.data_type(&self.right.schema())?;
+            let res_type = coerce_types(&lt, &Operator::Eq, &rt)?;
+
+            let left_cast = try_cast(l, &self.left.schema(), res_type.clone())?;
+            let right_cast = try_cast(r, &self.right.schema(), res_type)?;
+
+            on_left.push(left_cast);
+            on_right.push(right_cast);
+        }
+
+        // Make them immutable
+        let on_left = on_left;
+        let on_right = on_right;
+
         // we only want to compute the build side once for PartitionMode::CollectLeft
         let left_data = {
             match self.mode {
@@ -414,7 +442,6 @@ impl ExecutionPlan for HashJoinExec {
         // over the right that uses this information to issue new batches.
         let right_stream = self.right.execute(partition, context.clone()).await?;
-        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
 
         let num_rows = left_data.1.num_rows();
 
         let visited_left_side = match self.join_type {
@@ -473,7 +500,7 @@ impl ExecutionPlan for HashJoinExec {
 /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
 /// assuming that the [RecordBatch] corresponds to the `index`th
 fn update_hash(
-    on: &[Column],
+    on: &[Arc<dyn PhysicalExpr>],
     batch: &RecordBatch,
     hash_map: &mut JoinHashMap,
     offset: usize,
@@ -512,9 +539,9 @@ struct HashJoinStream {
     /// Input schema
     schema: Arc<Schema>,
     /// columns from the left
-    on_left: Vec<Column>,
+    on_left: Vec<Arc<dyn PhysicalExpr>>,
     /// columns from the right used to compute the hash
-    on_right: Vec<Column>,
+    on_right: Vec<Arc<dyn PhysicalExpr>>,
     /// type of the join
     join_type: JoinType,
     /// information from the left
@@ -539,8 +566,8 @@
 impl HashJoinStream {
     fn new(
         schema: Arc<Schema>,
-        on_left: Vec<Column>,
-        on_right: Vec<Column>,
+        on_left: Vec<Arc<dyn PhysicalExpr>>,
+        on_right: Vec<Arc<dyn PhysicalExpr>>,
         join_type: JoinType,
         left_data: JoinLeftData,
         right: SendableRecordBatchStream,
@@ -624,8 +651,8 @@ fn build_batch_from_indices(
 fn build_batch(
     batch: &RecordBatch,
     left_data: &JoinLeftData,
-    on_left: &[Column],
-    on_right: &[Column],
+    on_left: &[Arc<dyn PhysicalExpr>],
+    on_right: &[Arc<dyn PhysicalExpr>],
     join_type: JoinType,
     schema: &Schema,
     column_indices: &[ColumnIndex],
@@ -691,8 +718,8 @@ fn build_join_indexes(
     left_data: &JoinLeftData,
     right: &RecordBatch,
     join_type: JoinType,
-    left_on: &[Column],
-    right_on: &[Column],
+    left_on: &[Arc<dyn PhysicalExpr>],
+    right_on: &[Arc<dyn PhysicalExpr>],
     random_state: &RandomState,
     null_equals_null: &bool,
 ) -> Result<(UInt64Array, UInt32Array)> {
@@ -2002,8 +2029,8 @@ mod tests {
             &left_data,
             &right,
             JoinType::Inner,
-            &[Column::new("a", 0)],
-            &[Column::new("a", 0)],
+            &[Arc::new(Column::new("a", 0))],
+            &[Arc::new(Column::new("a", 0))],
             &random_state,
             &false,
         )?;
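To illustrate what the coercion loop in the first hunk resolves, here is a minimal, hypothetical standalone sketch (not part of the patch) that exercises the same `coerce_types` / `try_cast` helpers the patch imports, on made-up `l_key` / `r_key` join columns with mismatched integer types. Crate paths and signatures follow the imports shown in the diff and may differ in other DataFusion releases.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::binary_rule::coerce_types;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{try_cast, Column};
use datafusion_physical_expr::PhysicalExpr;

fn main() -> Result<()> {
    // Hypothetical schemas: the left key is Int32, the right key is Int64.
    let left_schema = Schema::new(vec![Field::new("l_key", DataType::Int32, false)]);
    let right_schema = Schema::new(vec![Field::new("r_key", DataType::Int64, false)]);

    let l = Arc::new(Column::new("l_key", 0));
    let r = Arc::new(Column::new("r_key", 0));

    // Ask the binary coercion rules which type both sides of `l_key = r_key`
    // should be evaluated as; for Int32 = Int64 this resolves to Int64.
    let res_type = coerce_types(
        &l.data_type(&left_schema)?,
        &Operator::Eq,
        &r.data_type(&right_schema)?,
    )?;

    // try_cast should return the expression unchanged when its type already
    // matches, and otherwise wrap it in a cast, so after this both key
    // expressions evaluate to the coerced type.
    let left_cast = try_cast(l, &left_schema, res_type.clone())?;
    let right_cast = try_cast(r, &right_schema, res_type)?;

    println!("coerced keys: {} = {}", left_cast, right_cast);
    Ok(())
}
```

Assuming that behaviour of `try_cast`, the hack only inserts casts where the equijoin sides actually disagree on type, which is why equal-typed join keys are unaffected by this workaround.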