From d36f4249fc6be5220697a7965ad881371ede434e Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sun, 13 Oct 2024 11:15:43 +0200
Subject: [PATCH 1/2] refactor: dataframe join params

---
Reviewer notes (this section is ignored by `git am`):
- `DataFrame.join` now wraps a bare `str` key in a one-element list before
  calling the native binding. The binding takes `Vec` parameters and PyO3
  does not extract a `Vec` from a plain `str`, so `on="a"` would otherwise
  fail at the language boundary.
- The copy of this patch under review had its line breaks collapsed and the
  Rust generic arguments stripped. `Vec<PyBackedStr>`, `PyResult<Self>` and
  `collect::<Vec<&str>>()` below are reconstructed; confirm them against the
  tree before applying.
- The `index` blob ids for python/datafusion/dataframe.py predate the change
  described in the first note.

 python/datafusion/dataframe.py | 85 +++++++++++++++++++++++++++++++++++++---
 python/tests/test_dataframe.py | 56 ++++++++++++++++++++++++++--
 src/dataframe.rs               |  9 ++--
 3 files changed, 135 insertions(+), 15 deletions(-)

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index c5ac0bb8..2494420b 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -20,8 +20,8 @@
 """
 
 from __future__ import annotations
-
-from typing import Any, List, TYPE_CHECKING
+import warnings
+from typing import Any, List, TYPE_CHECKING, Literal, overload
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
 from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -31,7 +31,7 @@
     import pandas as pd
     import polars as pl
     import pathlib
-    from typing import Callable
+    from typing import Callable, Sequence
 
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion.expr import Expr, SortExpr, sort_or_default
@@ -271,11 +271,51 @@ def distinct(self) -> DataFrame:
         """
         return DataFrame(self.df.distinct())
 
+    @overload
+    def join(
+        self,
+        right: DataFrame,
+        on: str | Sequence[str],
+        how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
+        *,
+        left_on: None = None,
+        right_on: None = None,
+        join_keys: None = None,
+    ) -> DataFrame: ...
+
+    @overload
     def join(
         self,
         right: DataFrame,
+        on: None = None,
+        how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
+        *,
+        left_on: str | Sequence[str],
+        right_on: str | Sequence[str],
+        join_keys: tuple[list[str], list[str]] | None = None,
+    ) -> DataFrame: ...
+
+    @overload
+    def join(
+        self,
+        right: DataFrame,
+        on: None = None,
+        how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
+        *,
         join_keys: tuple[list[str], list[str]],
-        how: str,
+        left_on: None = None,
+        right_on: None = None,
+    ) -> DataFrame: ...
+
+    def join(
+        self,
+        right: DataFrame,
+        on: str | Sequence[str] | None = None,
+        how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
+        *,
+        left_on: str | Sequence[str] | None = None,
+        right_on: str | Sequence[str] | None = None,
+        join_keys: tuple[list[str], list[str]] | None = None,
     ) -> DataFrame:
         """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
 
@@ -284,14 +324,47 @@ def join(
 
         Args:
             right: Other DataFrame to join with.
-            join_keys: Tuple of two lists of column names to join on.
+            on: Column names to join on in both dataframes.
             how: Type of join to perform. Supported types are "inner", "left",
                 "right", "full", "semi", "anti".
+            left_on: Join column of the left dataframe.
+            right_on: Join column of the right dataframe.
+            join_keys: Tuple of two lists of column names to join on. [Deprecated]
 
         Returns:
             DataFrame after join.
         """
-        return DataFrame(self.df.join(right.df, join_keys, how))
+        if join_keys is not None:
+            warnings.warn(
+                "`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+            left_on = join_keys[0]
+            right_on = join_keys[1]
+
+        if on:
+            if left_on or right_on:
+                raise ValueError(
+                    "`left_on` or `right_on` should not provided with `on`"
+                )
+            left_on = on
+            right_on = on
+        elif left_on or right_on:
+            if left_on is None or right_on is None:
+                raise ValueError("`left_on` and `right_on` should both be provided.")
+        else:
+            raise ValueError(
+                "either `on` or `left_on` and `right_on` should be provided."
+            )
+
+        # The native binding takes a list of names per side; wrap a bare str.
+        if isinstance(left_on, str):
+            left_on = [left_on]
+        if isinstance(right_on, str):
+            right_on = [right_on]
+
+        return DataFrame(self.df.join(right.df, how, left_on, right_on))
 
     def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
         """Return a DataFrame with the explanation of its plan so far.
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index e89c5715..535656a8 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -250,15 +250,63 @@ def test_join():
     )
     df1 = ctx.create_dataframe([[batch]], "r")
 
-    df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
-    df.show()
-    df = df.sort(column("l.a"))
-    table = pa.Table.from_batches(df.collect())
+    df2 = df.join(df1, on="a", how="inner")
+    df2.show()
+    df2 = df2.sort(column("l.a"))
+    table = pa.Table.from_batches(df2.collect())
+
+    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
+    assert table.to_pydict() == expected
+
+    df2 = df.join(df1, left_on="a", right_on="a", how="inner")
+    df2.show()
+    df2 = df2.sort(column("l.a"))
+    table = pa.Table.from_batches(df2.collect())
 
     expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
     assert table.to_pydict() == expected
 
 
+def test_join_invalid_params():
+    ctx = SessionContext()
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df = ctx.create_dataframe([[batch]], "l")
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2]), pa.array([8, 10])],
+        names=["a", "c"],
+    )
+    df1 = ctx.create_dataframe([[batch]], "r")
+
+    with pytest.deprecated_call():
+        df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
+        df2.show()
+        df2 = df2.sort(column("l.a"))
+        table = pa.Table.from_batches(df2.collect())
+
+        expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
+        assert table.to_pydict() == expected
+
+    with pytest.raises(
+        ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
+    ):
+        df2 = df.join(df1, on="a", how="inner", right_on="test")  # type: ignore
+
+    with pytest.raises(
+        ValueError, match=r"`left_on` and `right_on` should both be provided."
+    ):
+        df2 = df.join(df1, left_on="a", how="inner")  # type: ignore
+
+    with pytest.raises(
+        ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
+    ):
+        df2 = df.join(df1, how="inner")  # type: ignore
+
+
 def test_distinct():
     ctx = SessionContext()
 
diff --git a/src/dataframe.rs b/src/dataframe.rs
index e77ca842..dfffaa9e 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -254,8 +254,9 @@ impl PyDataFrame {
     fn join(
         &self,
         right: PyDataFrame,
-        join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
         how: &str,
+        left_on: Vec<PyBackedStr>,
+        right_on: Vec<PyBackedStr>,
     ) -> PyResult<Self> {
         let join_type = match how {
             "inner" => JoinType::Inner,
@@ -272,13 +273,11 @@ impl PyDataFrame {
             }
         };
 
-        let left_keys = join_keys
-            .0
+        let left_keys = left_on
             .iter()
             .map(|s| s.as_ref())
             .collect::<Vec<&str>>();
-        let right_keys = join_keys
-            .1
+        let right_keys = right_on
             .iter()
             .map(|s| s.as_ref())
             .collect::<Vec<&str>>();

From bb83a74234e8b20c8c8062f4df732d00be80121c Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sun, 13 Oct 2024 16:41:42 +0200
Subject: [PATCH 2/2] chore: add description for on params

---
 python/datafusion/dataframe.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 2494420b..ae509e09 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -319,8 +319,7 @@ def join(
     ) -> DataFrame:
         """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
 
-        Join keys are a pair of lists of column names in the left and right
-        dataframes, respectively. These lists must have the same length.
+        `on` has to be provided or both `left_on` and `right_on` in conjunction.
 
         Args:
             right: Other DataFrame to join with.