From 4db3e702b959ad3f494e8299c3fc1258ab294112 Mon Sep 17 00:00:00 2001
From: Iaroslav Igoshev
Date: Tue, 28 Feb 2023 22:19:18 +0100
Subject: [PATCH] PERF-#5705: Preserve metadata when applying `Series.cat.codes` (#5706)

Signed-off-by: Igoshev, Iaroslav
---
 modin/core/dataframe/pandas/dataframe/dataframe.py | 14 +++++++++++++-
 .../core/storage_formats/pandas/query_compiler.py  |  2 +-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py
index 5152cb26200..30925e40e06 100644
--- a/modin/core/dataframe/pandas/dataframe/dataframe.py
+++ b/modin/core/dataframe/pandas/dataframe/dataframe.py
@@ -1793,7 +1793,7 @@ def window(
         pass
 
     @lazy_metadata_decorator(apply_axis="both")
-    def fold(self, axis, func):
+    def fold(self, axis, func, new_columns=None):
         """
         Perform a function across an entire axis.
 
@@ -1803,6 +1803,11 @@ def fold(self, axis, func):
             The axis to apply over.
         func : callable
             The function to apply.
+        new_columns : list-like, optional
+            The columns of the result.
+            Must be the same length as the columns' length of `self`.
+            The column labels of `self` may change during an operation so
+            we may want to pass the new column labels in (e.g., see `cat.codes`).
 
         Returns
         -------
@@ -1813,6 +1818,13 @@ def fold(self, axis, func):
         -----
         The data shape is not changed (length and width of the table).
         """
+        if new_columns is not None:
+            if self._columns_cache is not None:
+                assert len(self._columns_cache) == len(
+                    new_columns
+                ), "The length of `new_columns` doesn't match the columns' length of `self`"
+            self._columns_cache = new_columns
+
         new_partitions = self._partition_mgr_cls.map_axis_partitions(
             axis, self._partitions, func, keep_partitioning=True
         )
diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py
index e0de8ca781f..52c89dd61be 100644
--- a/modin/core/storage_formats/pandas/query_compiler.py
+++ b/modin/core/storage_formats/pandas/query_compiler.py
@@ -3452,7 +3452,7 @@ def func(df) -> np.ndarray:
                 ser = ser.astype("category", copy=False)
             return ser.cat.codes.to_frame(name=MODIN_UNNAMED_SERIES_LABEL)
 
-        res = self._modin_frame.apply_full_axis(
+        res = self._modin_frame.fold(
             axis=0, func=func, new_columns=[MODIN_UNNAMED_SERIES_LABEL]
         )
         return self.__constructor__(res, shape_hint="column")