diff --git a/Cargo.lock b/Cargo.lock index 8d60981bd3..db446053c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2084,6 +2084,7 @@ dependencies = [ "serde", "snafu", "test-log", + "uuid 1.10.0", ] [[package]] diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 7de191fe81..13d2c4307f 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -34,6 +34,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 2a68390066..8e6d0b005e 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ col, @@ -13,6 +13,7 @@ use daft_dsl::{ }; use itertools::Itertools; use snafu::ResultExt; +use uuid::Uuid; use crate::{ logical_ops::Project, @@ -54,14 +55,31 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> logical_plan::Result { - let (left_on, left_fields) = - resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; - let (right_on, right_fields) = + let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; + let (right_on, _) = resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?; - for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] { - let on_schema = Schema::new(on_fields).context(CreationSnafu)?; - for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) { + let (unique_left_on, unique_right_on) = + Self::rename_join_keys(left_on.clone(), right_on.clone()); + + let left_fields: Vec = unique_left_on + .iter() + .map(|e| e.to_field(&left.schema())) + .collect::>>() + .context(CreationSnafu)?; + + let right_fields: Vec = unique_right_on + .iter() + .map(|e| e.to_field(&right.schema())) + .collect::>>() + .context(CreationSnafu)?; + + for (on_exprs, on_fields) in [ + (&unique_left_on, &left_fields), + (&unique_right_on, &right_fields), + ] { + for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + // Null type check for both fields and expressions if matches!(field.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" @@ -167,6 +185,60 @@ impl Join { } } + /// Renames join keys for the given left and right expressions. This is required to + /// prevent errors when the join keys on the left and right expressions have the same key + /// name. + /// + /// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and + /// checks for pairs of column expressions that differ. If both expressions in a pair + /// are column expressions and they are not identical, it generates a unique identifier + /// and renames both expressions by appending this identifier to their original names. + /// + /// The function returns two vectors of expressions, where the renamed expressions are + /// substituted for the original expressions in the cases where renaming occurred. + /// + /// # Parameters + /// - `left_exprs`: A vector of expressions from the left side of a join. + /// - `right_exprs`: A vector of expressions from the right side of a join. + /// + /// # Returns + /// A tuple containing two vectors of expressions, one for the left side and one for the + /// right side, where expressions that needed to be renamed have been modified. + /// + /// # Example + /// ``` + /// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions); + /// ``` + /// + /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). + + fn rename_join_keys( + left_exprs: Vec>, + right_exprs: Vec>, + ) -> (Vec>, Vec>) { + left_exprs + .into_iter() + .zip(right_exprs) + .map( + |(left_expr, right_expr)| match (&*left_expr, &*right_expr) { + (Expr::Column(left_name), Expr::Column(right_name)) + if left_name == right_name => + { + (left_expr, right_expr) + } + _ => { + let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = + left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); + let renamed_right_expr = + right_expr.alias(format!("{}_{}", right_expr.name(), unique_id)); + (renamed_left_expr, renamed_right_expr) + } + }, + ) + .unzip() + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..5e79acf698 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -53,6 +53,17 @@ def test_columns_after_join(make_df): assert set(joined_df2.schema().column_names()) == set(["A", "B"]) +def test_rename_join_keys_in_dataframe(make_df): + df1 = make_df({"A": [1, 2], "B": [2, 2]}) + + df2 = make_df({"A": [1, 2]}) + joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"]) + joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"]) + + assert set(joined_df1.schema().column_names()) == set(["A", "B"]) + assert set(joined_df2.schema().column_names()) == set(["A", "B"]) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) @pytest.mark.parametrize( "join_strategy",