From 620f8eb5266472bd31a2d08c45b52d2d093e6fc4 Mon Sep 17 00:00:00 2001 From: Karthik Velayutham Date: Mon, 6 Mar 2023 16:41:35 -0600 Subject: [PATCH] FEAT: Add support for `df.apply` (#7) * FEAT: Add support for df.apply * Get output in client --- .../storage_formats/base/query_compiler.py | 6 ++++- .../storage_formats/pandas/query_compiler.py | 2 +- modin/pandas/__init__.py | 1 + modin/pandas/base.py | 13 ++++++++++- modin/pandas/dataframe.py | 4 +++- modin/pandas/groupby.py | 22 +++++++++---------- modin/pandas/series.py | 8 +++++-- modin/pandas/window.py | 5 +++-- 8 files changed, 41 insertions(+), 20 deletions(-) diff --git a/modin/core/storage_formats/base/query_compiler.py b/modin/core/storage_formats/base/query_compiler.py index fe234cc0e14..d6473fe6069 100644 --- a/modin/core/storage_formats/base/query_compiler.py +++ b/modin/core/storage_formats/base/query_compiler.py @@ -2466,7 +2466,9 @@ def drop(self, index=None, columns=None, errors: str = "raise"): # UDF (apply and agg) methods # There is a wide range of behaviors that are supported, so a lot of the # logic can get a bit convoluted. - def apply(self, func, axis, raw=False, result_type=None, *args, **kwargs): + def apply( + self, func, axis, raw=False, result_type=None, output_meta=None, *args, **kwargs + ): """ Apply passed function across given axis. @@ -2488,6 +2490,8 @@ def apply(self, func, axis, raw=False, result_type=None, *args, **kwargs): - "reduce": keep result into a single cell (opposite of "expand"). - "broadcast": broadcast result to original data shape (overwrite the existing column/row with the function result). - None: use "expand" strategy if Series is returned, "reduce" otherwise. + output_meta : pandas.Series, pandas.DataFrame, or scalar + Output of apply on one partition of df. *args : iterable Positional arguments to pass to `func`. **kwargs : dict diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py index 9128eaef371..e7863012821 100644 --- a/modin/core/storage_formats/pandas/query_compiler.py +++ b/modin/core/storage_formats/pandas/query_compiler.py @@ -2681,7 +2681,7 @@ def explode(self, column): # UDF (apply and agg) methods # There is a wide range of behaviors that are supported, so a lot of the # logic can get a bit convoluted. - def apply(self, func, axis, *args, **kwargs): + def apply(self, func, axis, *args, output_meta=None, **kwargs): # if any of args contain modin object, we should # convert it to pandas args = try_cast_to_pandas(args) diff --git a/modin/pandas/__init__.py b/modin/pandas/__init__.py index 46afc709361..ade3045b9c6 100644 --- a/modin/pandas/__init__.py +++ b/modin/pandas/__init__.py @@ -98,6 +98,7 @@ def _update_engine(publisher: Parameter): from modin.config import Engine, StorageFormat, CpuCount + Engine.NOINIT_ENGINES.add("Client") from modin.config.envvars import IsExperimental from modin.config.pubsub import ValueSource diff --git a/modin/pandas/base.py b/modin/pandas/base.py index 7d18ba14fc3..7281a2be3bf 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -874,6 +874,7 @@ def apply( """ Apply a function along an axis of the `BasePandasDataset`. """ + import cloudpickle def error_raiser(msg, exception): """Convert passed exception to the same type as pandas do and raise it.""" @@ -900,7 +901,8 @@ def error_raiser(msg, exception): FutureWarning, stacklevel=2, ) - query_compiler = self._query_compiler.apply( + + output_pandas_df = self._to_pandas().apply( func, axis, args=args, @@ -908,6 +910,15 @@ def error_raiser(msg, exception): result_type=result_type, **kwds, ) + query_compiler = self._query_compiler.apply( + cloudpickle.dumps(func), + axis, + args=args, + raw=raw, + result_type=result_type, + output_meta=output_pandas_df, + **kwds, + ) return query_compiler def asfreq( diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index 2973ba29932..5166b1fb907 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -735,7 +735,9 @@ def corrwith( Compute pairwise correlation. """ return self.__constructor__( - query_compiler=self._query_compiler.corrwith(other, axis, drop, method, numeric_only) + query_compiler=self._query_compiler.corrwith( + other, axis, drop, method, numeric_only + ) ) def cov( diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py index 2734090d82f..bb5dcf7f654 100644 --- a/modin/pandas/groupby.py +++ b/modin/pandas/groupby.py @@ -240,12 +240,12 @@ def value_counts( # equivalent to df.value_counts([, ]).sort_index() # it returns a MultiIndex Series which needs to be converted to # pandas for sort_index. - # + # # Semantic Exceptions: # normalize does not work; it will return the normalized results # across the entire dataframe, not within the sub levels # DataFrame(as_index=False) does not work. The default is True - # calling this function will always result in a Series rather + # calling this function will always result in a Series rather # than a DataFrame # if is_list_like(self._by): @@ -255,16 +255,13 @@ def value_counts( for c in self._columns.values.tolist(): if c not in subset: subset.append(c) - return ( - self._df.value_counts( - subset=subset, - normalize=normalize, - sort=sort, - ascending=ascending, - dropna=dropna, - ) - .sort_index(level=0, sort_remaining=False) - ) + return self._df.value_counts( + subset=subset, + normalize=normalize, + sort=sort, + ascending=ascending, + dropna=dropna, + ).sort_index(level=0, sort_remaining=False) def mean(self, numeric_only=None): return self._check_index( @@ -288,6 +285,7 @@ def plot(self): # pragma: no cover def ohlc(self): from .dataframe import DataFrame + if isinstance(self._df, DataFrame): raise NotImplementedError("groupby.ohlc() not supported for dataframes!") else: diff --git a/modin/pandas/series.py b/modin/pandas/series.py index 1b30350aec8..63520bf08ac 100644 --- a/modin/pandas/series.py +++ b/modin/pandas/series.py @@ -995,7 +995,9 @@ def factorize( Encode the object as an enumerated type or categorical variable. """ return self.__constructor__( - query_compiler=self._query_compiler.factorize(sort, na_sentinel, use_na_sentinel) + query_compiler=self._query_compiler.factorize( + sort, na_sentinel, use_na_sentinel + ) ) def fillna( @@ -1983,7 +1985,9 @@ def truncate( """ Truncate a Series before and after some index value. """ - return self.__constructor__(self.__query_compiler__.truncate(before, after, axis, copy)) + return self.__constructor__( + self.__query_compiler__.truncate(before, after, axis, copy) + ) def unique(self): # noqa: RT01, D200 """ diff --git a/modin/pandas/window.py b/modin/pandas/window.py index e9a78134f4e..7efde5d5caa 100644 --- a/modin/pandas/window.py +++ b/modin/pandas/window.py @@ -294,8 +294,10 @@ def __init__(self, dataframe, min_periods=1, center=None, axis=0, method="single def aggregate(self, *args, **kwargs): return self._dataframe.__constructor__( query_compiler=self._query_compiler.expanding_aggregate( - self.axis, self.expanding_args, *args, **kwargs) + self.axis, self.expanding_args, *args, **kwargs + ) ) + def sum(self, *args, **kwargs): return self._dataframe.__constructor__( query_compiler=self._query_compiler.expanding_sum( @@ -351,4 +353,3 @@ def sem(self, *args, **kwargs): self.axis, self.expanding_args, *args, **kwargs ) ) -