Skip to content

Commit

Permalink
FEAT: Add support for df.apply (modin-project#7)
Browse files Browse the repository at this point in the history
* FEAT: Add support for df.apply

* Get output in client
  • Loading branch information
pyrito authored and vnlitvinov committed Mar 16, 2023
1 parent 4273722 commit 620f8eb
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 20 deletions.
6 changes: 5 additions & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2466,7 +2466,9 @@ def drop(self, index=None, columns=None, errors: str = "raise"):
# UDF (apply and agg) methods
# There is a wide range of behaviors that are supported, so a lot of the
# logic can get a bit convoluted.
def apply(self, func, axis, raw=False, result_type=None, *args, **kwargs):
def apply(
self, func, axis, raw=False, result_type=None, output_meta=None, *args, **kwargs
):
"""
Apply passed function across given axis.
Expand All @@ -2488,6 +2490,8 @@ def apply(self, func, axis, raw=False, result_type=None, *args, **kwargs):
- "reduce": keep result into a single cell (opposite of "expand").
- "broadcast": broadcast result to original data shape (overwrite the existing column/row with the function result).
- None: use "expand" strategy if Series is returned, "reduce" otherwise.
output_meta : pandas.Series, pandas.DataFrame, or scalar, default: None
Output of applying `func` to one partition of the dataframe.
*args : iterable
Positional arguments to pass to `func`.
**kwargs : dict
Expand Down
2 changes: 1 addition & 1 deletion modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,7 +2681,7 @@ def explode(self, column):
# UDF (apply and agg) methods
# There is a wide range of behaviors that are supported, so a lot of the
# logic can get a bit convoluted.
def apply(self, func, axis, *args, **kwargs):
def apply(self, func, axis, *args, output_meta=None, **kwargs):
# if any of args contain modin object, we should
# convert it to pandas
args = try_cast_to_pandas(args)
Expand Down
1 change: 1 addition & 0 deletions modin/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@

def _update_engine(publisher: Parameter):
from modin.config import Engine, StorageFormat, CpuCount

Engine.NOINIT_ENGINES.add("Client")
from modin.config.envvars import IsExperimental
from modin.config.pubsub import ValueSource
Expand Down
13 changes: 12 additions & 1 deletion modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ def apply(
"""
Apply a function along an axis of the `BasePandasDataset`.
"""
import cloudpickle

def error_raiser(msg, exception):
"""Convert passed exception to the same type as pandas do and raise it."""
Expand All @@ -900,14 +901,24 @@ def error_raiser(msg, exception):
FutureWarning,
stacklevel=2,
)
query_compiler = self._query_compiler.apply(

output_pandas_df = self._to_pandas().apply(
func,
axis,
args=args,
raw=raw,
result_type=result_type,
**kwds,
)
query_compiler = self._query_compiler.apply(
cloudpickle.dumps(func),
axis,
args=args,
raw=raw,
result_type=result_type,
output_meta=output_pandas_df,
**kwds,
)
return query_compiler

def asfreq(
Expand Down
4 changes: 3 additions & 1 deletion modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,9 @@ def corrwith(
Compute pairwise correlation.
"""
return self.__constructor__(
query_compiler=self._query_compiler.corrwith(other, axis, drop, method, numeric_only)
query_compiler=self._query_compiler.corrwith(
other, axis, drop, method, numeric_only
)
)

def cov(
Expand Down
22 changes: 10 additions & 12 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def value_counts(
# equivalent to df.value_counts([<by>, <other...>]).sort_index()
# it returns a MultiIndex Series which needs to be converted to
# pandas for sort_index.
#
#
# Semantic Exceptions:
# normalize does not work; it will return the normalized results
# across the entire dataframe, not within the sub levels
# DataFrame(as_index=False) does not work. The default is True
# calling this function will always result in a Series rather
# calling this function will always result in a Series rather
# than a DataFrame
#
if is_list_like(self._by):
Expand All @@ -255,16 +255,13 @@ def value_counts(
for c in self._columns.values.tolist():
if c not in subset:
subset.append(c)
return (
self._df.value_counts(
subset=subset,
normalize=normalize,
sort=sort,
ascending=ascending,
dropna=dropna,
)
.sort_index(level=0, sort_remaining=False)
)
return self._df.value_counts(
subset=subset,
normalize=normalize,
sort=sort,
ascending=ascending,
dropna=dropna,
).sort_index(level=0, sort_remaining=False)

def mean(self, numeric_only=None):
return self._check_index(
Expand All @@ -288,6 +285,7 @@ def plot(self): # pragma: no cover

def ohlc(self):
from .dataframe import DataFrame

if isinstance(self._df, DataFrame):
raise NotImplementedError("groupby.ohlc() not supported for dataframes!")
else:
Expand Down
8 changes: 6 additions & 2 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,9 @@ def factorize(
Encode the object as an enumerated type or categorical variable.
"""
return self.__constructor__(
query_compiler=self._query_compiler.factorize(sort, na_sentinel, use_na_sentinel)
query_compiler=self._query_compiler.factorize(
sort, na_sentinel, use_na_sentinel
)
)

def fillna(
Expand Down Expand Up @@ -1983,7 +1985,9 @@ def truncate(
"""
Truncate a Series before and after some index value.
"""
return self.__constructor__(self.__query_compiler__.truncate(before, after, axis, copy))
return self.__constructor__(
self.__query_compiler__.truncate(before, after, axis, copy)
)

def unique(self): # noqa: RT01, D200
"""
Expand Down
5 changes: 3 additions & 2 deletions modin/pandas/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,10 @@ def __init__(self, dataframe, min_periods=1, center=None, axis=0, method="single
def aggregate(self, *args, **kwargs):
return self._dataframe.__constructor__(
query_compiler=self._query_compiler.expanding_aggregate(
self.axis, self.expanding_args, *args, **kwargs)
self.axis, self.expanding_args, *args, **kwargs
)
)

def sum(self, *args, **kwargs):
return self._dataframe.__constructor__(
query_compiler=self._query_compiler.expanding_sum(
Expand Down Expand Up @@ -351,4 +353,3 @@ def sem(self, *args, **kwargs):
self.axis, self.expanding_args, *args, **kwargs
)
)

0 comments on commit 620f8eb

Please sign in to comment.