Skip to content

Commit

Permalink
feat: Streamline use of predicates connected by & with IEJoin (`join_where`) (#19552)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 1, 2024
1 parent 7c2f31e commit f38e56b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
20 changes: 20 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,26 @@ impl JoinBuilder {
opt_state |= OptFlags::FILE_CACHING;
}

// Decompose `And` conjunctions into their component expressions.
//
// A `BinaryExpr` whose operator is `Operator::And` is split into its `left` and
// `right` operands, which are themselves split in turn; any other expression is
// appended to `expanded_predicates` as-is. Components are appended in
// left-to-right order.
fn decompose_and(predicate: Expr, expanded_predicates: &mut Vec<Expr>) {
    // Explicit work stack: the top entry is always the left-most expression
    // that has not been emitted yet.
    let mut pending = vec![predicate];
    while let Some(expr) = pending.pop() {
        match expr {
            Expr::BinaryExpr {
                op: Operator::And,
                left,
                right,
            } => {
                // Push `right` first so that `left` is popped (and emitted) before it.
                pending.push((*right).clone());
                pending.push((*left).clone());
            },
            other => expanded_predicates.push(other),
        }
    }
}
// Run each incoming predicate through `decompose_and`, collecting the resulting
// component expressions in order. Capacity is reserved for twice the number of
// incoming predicates.
let mut expanded_predicates = Vec::with_capacity(predicates.len() * 2);
predicates
    .into_iter()
    .for_each(|predicate| decompose_and(predicate, &mut expanded_predicates));
// Shadow the original binding so the following code works on the expanded list.
let predicates: Vec<Expr> = expanded_predicates;

// Decompose `is_between` predicates to allow for cleaner expression of range joins
#[cfg(feature = "is_between")]
let predicates: Vec<Expr> = {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ fn resolve_join_where(
{
comparison_count += 1;
if comparison_count > 1 {
polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate, found: {:?}", expr);
polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate; found {:?}", expr);
}
}

Expand Down
18 changes: 12 additions & 6 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def test_basic_ie_join() -> None:
)

actual = east.join_where(
west, pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost")
west,
pl.col("dur") < pl.col("time"),
pl.col("rev") > pl.col("cost"),
)

expected = pl.DataFrame(
Expand Down Expand Up @@ -111,7 +113,9 @@ def test_ie_join_with_slice(offset: int, length: int) -> None:

actual = (
east.join_where(
west, pl.col("dur") < pl.col("time"), pl.col("rev") < pl.col("cost")
west,
pl.col("dur") < pl.col("time"),
pl.col("rev") < pl.col("cost"),
)
.slice(offset, length)
.collect()
Expand Down Expand Up @@ -260,7 +264,9 @@ def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None:
q = (
left.lazy()
.join_where(
right.lazy(), pl.col("group") != pl.col("group_right"), *range_constraint
right.lazy(),
pl.col("group") != pl.col("group_right"),
*range_constraint,
)
.select("id", "id_right", "group")
.sort("id")
Expand Down Expand Up @@ -405,7 +411,7 @@ def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) ->
expr0 = _inequality_expression("dur", op1, "time")
expr1 = _inequality_expression("rev", op2, "cost")

actual = east.join_where(west, expr0, expr1)
actual = east.join_where(west, expr0 & expr1)

expected = east.join(west, how="cross").filter(expr0 & expr1)
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)
Expand All @@ -423,7 +429,7 @@ def test_ie_join_with_nulls(
expr0 = _inequality_expression("dur", op1, "time")
expr1 = _inequality_expression("rev", op2, "cost")

actual = east.join_where(west, expr0, expr1)
actual = east.join_where(west, expr0 & expr1)

expected = east.join(west, how="cross").filter(expr0 & expr1)
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)
Expand Down Expand Up @@ -460,7 +466,7 @@ def test_raise_on_multiple_binary_comparisons() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="only one binary comparison allowed in each 'join_where' predicate, found: ",
match="only one binary comparison allowed in each 'join_where' predicate; found ",
):
df.join_where(
df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id"))
Expand Down

0 comments on commit f38e56b

Please sign in to comment.