Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: dataframe join params #912

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 74 additions & 8 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"""

from __future__ import annotations

from typing import Any, List, TYPE_CHECKING
import warnings
from typing import Any, List, TYPE_CHECKING, Literal, overload
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
from datafusion.plan import LogicalPlan, ExecutionPlan
Expand All @@ -31,7 +31,7 @@
import pandas as pd
import polars as pl
import pathlib
from typing import Callable
from typing import Callable, Sequence

from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr, SortExpr, sort_or_default
Expand Down Expand Up @@ -271,27 +271,93 @@ def distinct(self) -> DataFrame:
"""
return DataFrame(self.df.distinct())

@overload
def join(
self,
right: DataFrame,
on: str | Sequence[str],
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
*,
left_on: None = None,
right_on: None = None,
join_keys: None = None,
) -> DataFrame: ...

@overload
def join(
self,
right: DataFrame,
on: None = None,
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
*,
left_on: str | Sequence[str],
right_on: str | Sequence[str],
join_keys: tuple[list[str], list[str]] | None = None,
) -> DataFrame: ...

@overload
def join(
self,
right: DataFrame,
on: None = None,
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
*,
join_keys: tuple[list[str], list[str]],
how: str,
left_on: None = None,
right_on: None = None,
) -> DataFrame: ...

def join(
self,
right: DataFrame,
on: str | Sequence[str] | None = None,
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
*,
left_on: str | Sequence[str] | None = None,
right_on: str | Sequence[str] | None = None,
join_keys: tuple[list[str], list[str]] | None = None,
) -> DataFrame:
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.

Join keys are a pair of lists of column names in the left and right
dataframes, respectively. These lists must have the same length.
`on` has to be provided or both `left_on` and `right_on` in conjunction.

Args:
right: Other DataFrame to join with.
join_keys: Tuple of two lists of column names to join on.
on: Column names to join on in both dataframes.
how: Type of join to perform. Supported types are "inner", "left",
"right", "full", "semi", "anti".
left_on: Join column of the left dataframe.
right_on: Join column of the right dataframe.
join_keys: Tuple of two lists of column names to join on. [Deprecated]
ion-elgreco marked this conversation as resolved.
Show resolved Hide resolved

Returns:
DataFrame after join.
"""
return DataFrame(self.df.join(right.df, join_keys, how))
if join_keys is not None:
warnings.warn(
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
category=DeprecationWarning,
stacklevel=2,
)
left_on = join_keys[0]
right_on = join_keys[1]

if on:
if left_on or right_on:
raise ValueError(
"`left_on` or `right_on` should not provided with `on`"
)
left_on = on
right_on = on
elif left_on or right_on:
if left_on is None or right_on is None:
raise ValueError("`left_on` and `right_on` should both be provided.")
else:
raise ValueError(
"either `on` or `left_on` and `right_on` should be provided."
)

return DataFrame(self.df.join(right.df, how, left_on, right_on))

def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
"""Return a DataFrame with the explanation of its plan so far.
Expand Down
56 changes: 52 additions & 4 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,63 @@ def test_join():
)
df1 = ctx.create_dataframe([[batch]], "r")

df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
df.show()
df = df.sort(column("l.a"))
table = pa.Table.from_batches(df.collect())
df2 = df.join(df1, on="a", how="inner")
df2.show()
df2 = df2.sort(column("l.a"))
table = pa.Table.from_batches(df2.collect())

expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
assert table.to_pydict() == expected

df2 = df.join(df1, left_on="a", right_on="a", how="inner")
df2.show()
df2 = df2.sort(column("l.a"))
table = pa.Table.from_batches(df2.collect())

expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
assert table.to_pydict() == expected


def test_join_invalid_params():
ctx = SessionContext()

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], "l")

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2]), pa.array([8, 10])],
names=["a", "c"],
)
df1 = ctx.create_dataframe([[batch]], "r")

with pytest.deprecated_call():
df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
df2.show()
df2 = df2.sort(column("l.a"))
table = pa.Table.from_batches(df2.collect())

expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
assert table.to_pydict() == expected

with pytest.raises(
ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
):
df2 = df.join(df1, on="a", how="inner", right_on="test") # type: ignore

with pytest.raises(
ValueError, match=r"`left_on` and `right_on` should both be provided."
):
df2 = df.join(df1, left_on="a", how="inner") # type: ignore

with pytest.raises(
ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
):
df2 = df.join(df1, how="inner") # type: ignore


def test_distinct():
ctx = SessionContext()

Expand Down
9 changes: 4 additions & 5 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ impl PyDataFrame {
fn join(
&self,
right: PyDataFrame,
join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
how: &str,
left_on: Vec<PyBackedStr>,
right_on: Vec<PyBackedStr>,
) -> PyResult<Self> {
let join_type = match how {
"inner" => JoinType::Inner,
Expand All @@ -272,13 +273,11 @@ impl PyDataFrame {
}
};

let left_keys = join_keys
.0
let left_keys = left_on
.iter()
.map(|s| s.as_ref())
.collect::<Vec<&str>>();
let right_keys = join_keys
.1
let right_keys = right_on
.iter()
.map(|s| s.as_ref())
.collect::<Vec<&str>>();
Expand Down