Skip to content

Commit

Permalink
test(python): Add benchmark tests for join_where with inequalities (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreeve authored Sep 9, 2024
1 parent d3a14de commit aa3b2c3
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions py-polars/tests/benchmark/test_join_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Benchmark tests for join_where with inequality conditions."""

from __future__ import annotations

import numpy as np
import pytest

import polars as pl

# Module-level pytest mark: applies the `benchmark` marker to every test in this file.
pytestmark = pytest.mark.benchmark()


def test_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:
    """Run join_where with two strict inequality predicates (`<` and `>`).

    Joins the east table to the west table on `dur < time` and `rev > cost`
    through the lazy API and checks that at least one row is produced.
    """
    left, right = east_west
    strict_predicates = [
        pl.col("dur") < pl.col("time"),
        pl.col("rev") > pl.col("cost"),
    ]
    # The predicates are handed over as a single list argument.
    joined = left.lazy().join_where(right.lazy(), strict_predicates).collect()

    assert len(joined) > 0


def test_non_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:
    """Run join_where with two non-strict inequality predicates (`<=` and `>=`).

    Joins the east table to the west table on `dur <= time` and `rev >= cost`
    through the lazy API and checks that at least one row is produced.
    """
    left, right = east_west
    inclusive_predicates = [
        pl.col("dur") <= pl.col("time"),
        pl.col("rev") >= pl.col("cost"),
    ]
    # The predicates are handed over as a single list argument.
    joined = left.lazy().join_where(right.lazy(), inclusive_predicates).collect()

    assert len(joined) > 0


@pytest.fixture(scope="module")
def east_west() -> tuple[pl.DataFrame, pl.DataFrame]:
    """Build the `(east, west)` pair of DataFrames used by the join_where tests.

    The east table has 50_000 rows with columns `id`, `dur`, `rev`, `cores`;
    the west table has 5_000 rows with columns `t_id`, `time`, `cost`, `cores`.
    All random values come from one generator seeded with 42.

    Revenue and cost are both 0.123 times the duration/time column, truncated
    to int32. The west cost additionally receives standard-normal noise before
    truncation, so that some west rows have a lower cost than east rows with
    the same or a smaller duration.
    """
    n_east, n_west = 50_000, 5_000
    rate = 0.123
    gen = np.random.default_rng(42)

    # NOTE: the order of the draws from `gen` below determines the generated
    # data; keep it stable (dur, time, noise, east cores, west cores).
    dur = gen.integers(1_000, 50_000, n_east)
    rev = (dur * rate).astype(np.int32)

    time = gen.integers(1_000, 50_000, n_west)
    noisy_cost = time * rate + gen.normal(0.0, 1.0, n_west)
    cost = noisy_cost.astype(np.int32)

    east_cores = gen.integers(1, 10, n_east)
    west_cores = gen.integers(1, 10, n_west)

    east_columns = {
        "id": np.arange(0, n_east),
        "dur": dur,
        "rev": rev,
        "cores": east_cores,
    }
    west_columns = {
        "t_id": np.arange(0, n_west),
        "time": time,
        "cost": cost,
        "cores": west_cores,
    }

    return pl.DataFrame(east_columns), pl.DataFrame(west_columns)

0 comments on commit aa3b2c3

Please sign in to comment.