diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 59c70cc78932..8ba7ab855186 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -2162,6 +2162,26 @@ impl JoinBuilder { opt_state |= OptFlags::FILE_CACHING; } + // Decompose `And` conjunctions into their component expressions + fn decompose_and(predicate: Expr, expanded_predicates: &mut Vec) { + if let Expr::BinaryExpr { + op: Operator::And, + left, + right, + } = predicate + { + decompose_and((*left).clone(), expanded_predicates); + decompose_and((*right).clone(), expanded_predicates); + } else { + expanded_predicates.push(predicate); + } + } + let mut expanded_predicates = Vec::with_capacity(predicates.len() * 2); + for predicate in predicates { + decompose_and(predicate, &mut expanded_predicates); + } + let predicates: Vec = expanded_predicates; + // Decompose `is_between` predicates to allow for cleaner expression of range joins #[cfg(feature = "is_between")] let predicates: Vec = { diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 7684062de23f..a81c36bef1f8 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -172,7 +172,7 @@ fn resolve_join_where( { comparison_count += 1; if comparison_count > 1 { - polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate, found: {:?}", expr); + polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate; found {:?}", expr); } } diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index fd9fca28f72d..89eec6bfd224 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -69,7 +69,9 @@ def test_basic_ie_join() -> None: ) actual = east.join_where( - west, pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost") + west, + pl.col("dur") < pl.col("time"), + pl.col("rev") > pl.col("cost"), ) expected = pl.DataFrame( @@ -111,7 +113,9 @@ def test_ie_join_with_slice(offset: int, length: int) -> None: actual = ( east.join_where( - west, pl.col("dur") < pl.col("time"), pl.col("rev") < pl.col("cost") + west, + pl.col("dur") < pl.col("time"), + pl.col("rev") < pl.col("cost"), ) .slice(offset, length) .collect() @@ -260,7 +264,9 @@ def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None: q = ( left.lazy() .join_where( - right.lazy(), pl.col("group") != pl.col("group_right"), *range_constraint + right.lazy(), + pl.col("group") != pl.col("group_right"), + *range_constraint, ) .select("id", "id_right", "group") .sort("id") @@ -405,7 +411,7 @@ def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) -> expr0 = _inequality_expression("dur", op1, "time") expr1 = _inequality_expression("rev", op2, "cost") - actual = east.join_where(west, expr0, expr1) + actual = east.join_where(west, expr0 & expr1) expected = east.join(west, how="cross").filter(expr0 & expr1) assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) @@ -423,7 +429,7 @@ def test_ie_join_with_nulls( expr0 = _inequality_expression("dur", op1, "time") expr1 = _inequality_expression("rev", op2, "cost") - actual = east.join_where(west, expr0, expr1) + actual = east.join_where(west, expr0 & expr1) expected = east.join(west, how="cross").filter(expr0 & expr1) assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) @@ -460,7 +466,7 @@ def test_raise_on_multiple_binary_comparisons() -> None: df = pl.DataFrame({"id": [1, 2]}) with pytest.raises( pl.exceptions.InvalidOperationError, - match="only one binary comparison allowed in each 'join_where' predicate, found: ", + match="only one binary comparison allowed in each 'join_where' predicate; found ", ): df.join_where( df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id"))