Skip to content

Commit

Permalink
[SPARK-46391][PS][TESTS] Reorganize ExpandingParityTests
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Reorganize `ExpandingParityTests`

### Why are the changes needed?
to make the test more consistent with pandas

### Does this PR introduce _any_ user-facing change?
no, test-only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#44332 from zhengruifeng/ps_test_expanding.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Dec 14, 2023
1 parent ab9ca96 commit 2893cd3
Show file tree
Hide file tree
Showing 11 changed files with 449 additions and 101 deletions.
12 changes: 10 additions & 2 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,11 @@ def __hash__(self):
"pyspark.pandas.tests.test_dataframe_conversion",
"pyspark.pandas.tests.test_dataframe_spark_io",
"pyspark.pandas.tests.test_default_index",
"pyspark.pandas.tests.test_expanding",
"pyspark.pandas.tests.window.test_expanding",
"pyspark.pandas.tests.window.test_expanding_adv",
"pyspark.pandas.tests.window.test_expanding_error",
"pyspark.pandas.tests.window.test_groupby_expanding",
"pyspark.pandas.tests.window.test_groupby_expanding_adv",
"pyspark.pandas.tests.test_extension",
"pyspark.pandas.tests.window.test_ewm_error",
"pyspark.pandas.tests.window.test_ewm_mean",
Expand Down Expand Up @@ -1135,7 +1139,11 @@ def __hash__(self):
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling",
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv",
"pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count",
"pyspark.pandas.tests.connect.test_parity_expanding",
"pyspark.pandas.tests.connect.window.test_parity_expanding",
"pyspark.pandas.tests.connect.window.test_parity_expanding_adv",
"pyspark.pandas.tests.connect.window.test_parity_expanding_error",
"pyspark.pandas.tests.connect.window.test_parity_groupby_expanding",
"pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
"pyspark.pandas.tests.connect.computation.test_parity_missing_data",
"pyspark.pandas.tests.connect.groupby.test_parity_index",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@
#
import unittest

from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin
from pyspark.pandas.tests.window.test_expanding import ExpandingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
from pyspark.testing.pandasutils import PandasOnSparkTestUtils


class ExpandingParityTests(
ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
ExpandingMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401
from pyspark.pandas.tests.connect.window.test_parity_expanding import * # noqa: F401

try:
import xmlrunner # type: ignore[import]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.pandas.tests.window.test_expanding_adv import ExpandingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils


class ExpandingAdvParityTests(
ExpandingAdvMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.window.test_parity_expanding_adv import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.pandas.tests.window.test_expanding_error import ExpandingErrorMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils


class ExpandingErrorParityTests(
ExpandingErrorMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.window.test_parity_expanding_error import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.pandas.tests.window.test_groupby_expanding import GroupByExpandingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils


class GroupByExpandingParityTests(
GroupByExpandingMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding import * # noqa

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.pandas.tests.window.test_groupby_expanding_adv import GroupByExpandingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils


class GroupByExpandingAdvParityTests(
GroupByExpandingAdvMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv import * # noqa

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
96 changes: 96 additions & 0 deletions python/pyspark/pandas/tests/window/test_expanding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pandas as pd

import pyspark.pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase


class ExpandingTestingFuncMixin:
def _test_expanding_func(self, ps_func, pd_func=None):
if not pd_func:
pd_func = ps_func
if isinstance(pd_func, str):
pd_func = self.convert_str_to_lambda(pd_func)
if isinstance(ps_func, str):
ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True)
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True)

# Multiindex
pser = pd.Series(
[1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")])
)
psser = ps.from_pandas(pser)
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)))

pdf = pd.DataFrame(
{"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
)
psdf = ps.from_pandas(pdf)
self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum())

# Multiindex column
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
pdf.columns = columns
psdf.columns = columns
self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))


class ExpandingMixin(ExpandingTestingFuncMixin):
def test_expanding_repr(self):
self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding [min_periods=5]")

def test_expanding_count(self):
self._test_expanding_func("count")

def test_expanding_min(self):
self._test_expanding_func("min")

def test_expanding_max(self):
self._test_expanding_func("max")

def test_expanding_mean(self):
self._test_expanding_func("mean")

def test_expanding_sum(self):
self._test_expanding_func("sum")


class ExpandingTests(
ExpandingMixin,
PandasOnSparkTestCase,
):
pass


if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.window.test_expanding import * # noqa: F401

try:
import xmlrunner

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
56 changes: 56 additions & 0 deletions python/pyspark/pandas/tests/window/test_expanding_adv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.pandas.tests.window.test_expanding import ExpandingTestingFuncMixin


class ExpandingAdvMixin(ExpandingTestingFuncMixin):
def test_expanding_quantile(self):
self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower"))

def test_expanding_std(self):
self._test_expanding_func("std")

def test_expanding_var(self):
self._test_expanding_func("var")

def test_expanding_skew(self):
self._test_expanding_func("skew")

def test_expanding_kurt(self):
self._test_expanding_func("kurt")


class ExpandingAdvTests(
ExpandingAdvMixin,
PandasOnSparkTestCase,
):
pass


if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.window.test_expanding_adv import * # noqa: F401

try:
import xmlrunner

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
Loading

0 comments on commit 2893cd3

Please sign in to comment.