diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index fee9198dff425..22fdde139d281 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -727,7 +727,11 @@ def __hash__(self): "pyspark.pandas.tests.test_dataframe_conversion", "pyspark.pandas.tests.test_dataframe_spark_io", "pyspark.pandas.tests.test_default_index", - "pyspark.pandas.tests.test_expanding", + "pyspark.pandas.tests.window.test_expanding", + "pyspark.pandas.tests.window.test_expanding_adv", + "pyspark.pandas.tests.window.test_expanding_error", + "pyspark.pandas.tests.window.test_groupby_expanding", + "pyspark.pandas.tests.window.test_groupby_expanding_adv", "pyspark.pandas.tests.test_extension", "pyspark.pandas.tests.window.test_ewm_error", "pyspark.pandas.tests.window.test_ewm_mean", @@ -1135,7 +1139,11 @@ def __hash__(self): "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling", "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv", "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count", - "pyspark.pandas.tests.connect.test_parity_expanding", + "pyspark.pandas.tests.connect.window.test_parity_expanding", + "pyspark.pandas.tests.connect.window.test_parity_expanding_adv", + "pyspark.pandas.tests.connect.window.test_parity_expanding_error", + "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding", + "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv", "pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling", "pyspark.pandas.tests.connect.computation.test_parity_missing_data", "pyspark.pandas.tests.connect.groupby.test_parity_index", diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py similarity index 79% rename from python/pyspark/pandas/tests/connect/test_parity_expanding.py rename to python/pyspark/pandas/tests/connect/window/test_parity_expanding.py index 7f8b1a3cac2f3..ac83a1c3b34c1 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_expanding import ExpandingMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase + ExpandingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_expanding import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py new file mode 100644 index 0000000000000..0baec678beded --- /dev/null +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.pandas.tests.window.test_expanding_adv import ExpandingAdvMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class ExpandingAdvParityTests( + ExpandingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.window.test_parity_expanding_adv import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py new file mode 100644 index 0000000000000..a8531a02799c0 --- /dev/null +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.pandas.tests.window.test_expanding_error import ExpandingErrorMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class ExpandingErrorParityTests( + ExpandingErrorMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.window.test_parity_expanding_error import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py new file mode 100644 index 0000000000000..356bc5298264c --- /dev/null +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.pandas.tests.window.test_groupby_expanding import GroupByExpandingMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class GroupByExpandingParityTests( + GroupByExpandingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding import * # noqa + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py new file mode 100644 index 0000000000000..b743e335b154e --- /dev/null +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.pandas.tests.window.test_groupby_expanding_adv import GroupByExpandingAdvMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class GroupByExpandingAdvParityTests( + GroupByExpandingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv import * # noqa + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_expanding.py b/python/pyspark/pandas/tests/window/test_expanding.py new file mode 100644 index 0000000000000..ebe54ff217197 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_expanding.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class ExpandingTestingFuncMixin: + def _test_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) + pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + + # Multiindex + pser = pd.Series( + [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) + ) + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) + + pdf = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + + +class ExpandingMixin(ExpandingTestingFuncMixin): + def test_expanding_repr(self): + self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding [min_periods=5]") + + def test_expanding_count(self): + self._test_expanding_func("count") + + def test_expanding_min(self): + self._test_expanding_func("min") + + def test_expanding_max(self): + self._test_expanding_func("max") + + def test_expanding_mean(self): + self._test_expanding_func("mean") + + def test_expanding_sum(self): + self._test_expanding_func("sum") + + +class ExpandingTests( + ExpandingMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_expanding import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_expanding_adv.py b/python/pyspark/pandas/tests/window/test_expanding_adv.py new file mode 100644 index 0000000000000..e537f1ecfbc05 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_expanding_adv.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.window.test_expanding import ExpandingTestingFuncMixin + + +class ExpandingAdvMixin(ExpandingTestingFuncMixin): + def test_expanding_quantile(self): + self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) + + def test_expanding_std(self): + self._test_expanding_func("std") + + def test_expanding_var(self): + self._test_expanding_func("var") + + def test_expanding_skew(self): + self._test_expanding_func("skew") + + def test_expanding_kurt(self): + self._test_expanding_func("kurt") + + +class ExpandingAdvTests( + ExpandingAdvMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_expanding_adv import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_expanding_error.py b/python/pyspark/pandas/tests/window/test_expanding_error.py new file mode 100644 index 0000000000000..fa888f5f1696d --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_expanding_error.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pyspark.pandas as ps +from pyspark.pandas.window import Expanding +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class ExpandingErrorMixin: + def test_expanding_error(self): + with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): + ps.range(10).expanding(-1) + + with self.assertRaisesRegex( + TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" + ): + Expanding(1, 2) + + +class ExpandingErrorTests( + ExpandingErrorMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_expanding_error import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/window/test_groupby_expanding.py similarity index 56% rename from python/pyspark/pandas/tests/test_expanding.py rename to python/pyspark/pandas/tests/window/test_groupby_expanding.py index 5166f8132665b..44fecd7e58eb9 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/window/test_groupby_expanding.py @@ -19,85 +19,10 @@ import pandas as pd import pyspark.pandas as ps -from pyspark.pandas.window import Expanding from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class ExpandingTestsMixin: - def _test_expanding_func(self, ps_func, pd_func=None): - if not pd_func: - pd_func = ps_func - if isinstance(pd_func, str): - pd_func = self.convert_str_to_lambda(pd_func) - if isinstance(ps_func, str): - ps_func = self.convert_str_to_lambda(ps_func) - pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) - - # Multiindex - pser = pd.Series( - [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) - ) - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) - - pdf = pd.DataFrame( - {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) - ) - psdf = ps.from_pandas(pdf) - self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) - self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) - - # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - pdf.columns = columns - psdf.columns = columns - self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) - - def test_expanding_error(self): - with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): - ps.range(10).expanding(-1) - - with self.assertRaisesRegex( - TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" - ): - Expanding(1, 2) - - def test_expanding_repr(self): - self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding [min_periods=5]") - - def test_expanding_count(self): - self._test_expanding_func("count") - - def test_expanding_min(self): - self._test_expanding_func("min") - - def test_expanding_max(self): - self._test_expanding_func("max") - - def test_expanding_mean(self): - self._test_expanding_func("mean") - - def test_expanding_quantile(self): - self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) - - def test_expanding_sum(self): - self._test_expanding_func("sum") - - def test_expanding_std(self): - self._test_expanding_func("std") - - def test_expanding_var(self): - self._test_expanding_func("var") - - def test_expanding_skew(self): - self._test_expanding_func("skew") - - def test_expanding_kurt(self): - self._test_expanding_func("kurt") - +class GroupByExpandingTestingFuncMixin: def _test_groupby_expanding_func(self, ps_func, pd_func=None): if not pd_func: pd_func = ps_func @@ -172,6 +97,8 @@ def _test_groupby_expanding_func(self, ps_func, pd_func=None): pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), ) + +class GroupByExpandingMixin(GroupByExpandingTestingFuncMixin): def test_groupby_expanding_count(self): self._test_groupby_expanding_func("count") @@ -184,34 +111,20 @@ def test_groupby_expanding_max(self): def test_groupby_expanding_mean(self): self._test_groupby_expanding_func("mean") - def test_groupby_expanding_quantile(self): - self._test_groupby_expanding_func( - lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") - ) - def test_groupby_expanding_sum(self): self._test_groupby_expanding_func("sum") - def test_groupby_expanding_std(self): - self._test_groupby_expanding_func("std") - - def test_groupby_expanding_var(self): - self._test_groupby_expanding_func("var") - - def test_groupby_expanding_skew(self): - self._test_groupby_expanding_func("skew") - - def test_groupby_expanding_kurt(self): - self._test_groupby_expanding_func("kurt") - -class ExpandingTests(ExpandingTestsMixin, PandasOnSparkTestCase, TestUtils): +class GroupByExpandingTests( + GroupByExpandingMixin, + PandasOnSparkTestCase, +): pass if __name__ == "__main__": import unittest - from pyspark.pandas.tests.test_expanding import * # noqa: F401 + from pyspark.pandas.tests.window.test_groupby_expanding import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py new file mode 100644 index 0000000000000..22cb03dc0ff32 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.tests.window.test_groupby_expanding import GroupByExpandingTestingFuncMixin + + +class GroupByExpandingAdvMixin(GroupByExpandingTestingFuncMixin): + def test_groupby_expanding_quantile(self): + self._test_groupby_expanding_func( + lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") + ) + + def test_groupby_expanding_std(self): + self._test_groupby_expanding_func("std") + + def test_groupby_expanding_var(self): + self._test_groupby_expanding_func("var") + + def test_groupby_expanding_skew(self): + self._test_groupby_expanding_func("skew") + + def test_groupby_expanding_kurt(self): + self._test_groupby_expanding_func("kurt") + + +class GroupByExpandingAdvTests( + GroupByExpandingAdvMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_groupby_expanding_adv import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2)