Skip to content

Commit

Permalink
pass sort parameter of DF.stack
Browse files Browse the repository at this point in the history
  • Loading branch information
noloerino committed Sep 12, 2024
1 parent 3357709 commit 88afe4a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
7 changes: 5 additions & 2 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,7 +1880,7 @@ def searchsorted(self, **kwargs): # noqa: PR02
# END Abstract map partitions operations

@doc_utils.add_refer_to("DataFrame.stack")
def stack(self, level, dropna):
def stack(self, level, dropna, sort):
"""
Stack the prescribed level(s) from columns to index.
Expand All @@ -1894,7 +1894,10 @@ def stack(self, level, dropna):
BaseQueryCompiler
"""
return DataFrameDefault.register(pandas.DataFrame.stack)(
self, level=level, dropna=dropna
self,
level=level,
dropna=dropna,
sort=sort,
)

# Abstract map partitions across select indices
Expand Down
6 changes: 4 additions & 2 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,7 @@ def get_unique_level_values(index):
result = result.reindex(0, new_index)
return result

def stack(self, level, dropna):
def stack(self, level, dropna, sort):
if not isinstance(self.columns, pandas.MultiIndex) or (
isinstance(self.columns, pandas.MultiIndex)
and is_list_like(level)
Expand All @@ -1916,7 +1916,9 @@ def stack(self, level, dropna):

new_modin_frame = self._modin_frame.apply_full_axis(
1,
lambda df: pandas.DataFrame(df.stack(level=level, dropna=dropna)),
lambda df: pandas.DataFrame(
df.stack(level=level, dropna=dropna, sort=sort)
),
new_columns=new_columns,
)
return self.__constructor__(new_modin_frame)
Expand Down
4 changes: 2 additions & 2 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,11 +2113,11 @@ def stack(
is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels
):
return self._reduce_dimension(
query_compiler=self._query_compiler.stack(level, dropna)
query_compiler=self._query_compiler.stack(level, dropna, sort)
)
else:
return self.__constructor__(
query_compiler=self._query_compiler.stack(level, dropna)
query_compiler=self._query_compiler.stack(level, dropna, sort)
)

def sub(
Expand Down
10 changes: 10 additions & 0 deletions modin/tests/pandas/dataframe/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,16 @@ def test_stack(data, is_multi_idx, is_multi_col):
df_equals(modin_df.stack(level=[0, 1, 2]), pandas_df.stack(level=[0, 1, 2]))


@pytest.mark.parametrize("sort", [True, False])
def test_stack_sort(sort):
# Example frame slightly modified from pandas docs to be unsorted
cols = pd.MultiIndex.from_tuples([("weight", "pounds"), ("weight", "kg")])
modin_df, pandas_df = create_test_dfs(
[[1, 2], [2, 4]], index=["cat", "dog"], columns=cols
)
df_equals(modin_df.stack(sort=sort), pandas_df.stack(sort=sort))


@pytest.mark.parametrize("data", test_data_values, ids=test_data_keys)
@pytest.mark.parametrize("axis1", [0, 1])
@pytest.mark.parametrize("axis2", [0, 1])
Expand Down

0 comments on commit 88afe4a

Please sign in to comment.