diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py
index 8107d51b..da712790 100644
--- a/python/datafusion/tests/test_dataframe.py
+++ b/python/datafusion/tests/test_dataframe.py
@@ -278,63 +278,50 @@ def test_distinct():
     assert df_a.collect() == df_b.collect()
 
 
-def test_window_functions(df):
+test_data_window_functions = [
+    ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]),
+    ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]),
+    ("dense_rank", f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2] ),
+    ("percent_rank", f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), [0.5, 0, 0.5]),
+    ("cume_dist", f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), [0.3333333333333333, 0.6666666666666666, 1.0]),
+    ("ntile", f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), [1, 1, 2]),
+    ("next", f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), [5, 6, None]),
+    ("previous", f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), [None, 4, 5]),
+    pytest.param(
+        "first_value",
+        f.window(
+            "first_value",
+            [column("a")],
+            order_by=[f.order_by(column("b"))]
+        ),
+        [1, 1, 1],
+        marks=pytest.mark.xfail,
+    ),
+    pytest.param(
+        "last_value",
+        f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
+        [4, 5, 6],
+        marks=pytest.mark.xfail,
+    ),
+    pytest.param(
+        "2nd_value",
+        f.window(
+            "nth_value",
+            [column("b"), literal(2)],
+            order_by=[f.order_by(column("b"))],
+        ),
+        [None, 5, 5],
+    ),
+]
+
+
+@pytest.mark.parametrize("name,expr,result", test_data_window_functions)
+def test_window_functions(df, name, expr, result):
     df = df.select(
         column("a"),
         column("b"),
         column("c"),
-        f.alias(
-            f.window("row_number", [], order_by=[f.order_by(column("c"))]),
-            "row",
-        ),
-        f.alias(
-            f.window("rank", [], order_by=[f.order_by(column("c"))]),
-            "rank",
-        ),
-        f.alias(
-            f.window("dense_rank", [], order_by=[f.order_by(column("c"))]),
-            "dense_rank",
-        ),
-        f.alias(
-            f.window("percent_rank", [], order_by=[f.order_by(column("c"))]),
-            "percent_rank",
-        ),
-        f.alias(
-            f.window("cume_dist", [], order_by=[f.order_by(column("b"))]),
-            "cume_dist",
-        ),
-        f.alias(
-            f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]),
-            "ntile",
-        ),
-        f.alias(
-            f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]),
-            "previous",
-        ),
-        f.alias(
-            f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]),
-            "next",
-        ),
-        f.alias(
-            f.window(
-                "first_value",
-                [column("a")],
-                order_by=[f.order_by(column("b"))],
-            ),
-            "first_value",
-        ),
-        f.alias(
-            f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
-            "last_value",
-        ),
-        f.alias(
-            f.window(
-                "nth_value",
-                [column("b"), literal(2)],
-                order_by=[f.order_by(column("b"))],
-            ),
-            "2nd_value",
-        ),
+        f.alias(expr, name)
     )
 
     table = pa.Table.from_batches(df.collect())
@@ -343,18 +330,9 @@ def test_window_functions(df):
         "a": [1, 2, 3],
         "b": [4, 5, 6],
         "c": [8, 5, 8],
-        "row": [2, 1, 3],
-        "rank": [2, 1, 2],
-        "dense_rank": [2, 1, 2],
-        "percent_rank": [0.5, 0, 0.5],
-        "cume_dist": [0.3333333333333333, 0.6666666666666666, 1.0],
-        "ntile": [1, 1, 2],
-        "next": [5, 6, None],
-        "previous": [None, 4, 5],
-        "first_value": [1, 1, 1],
-        "last_value": [4, 5, 6],
-        "2nd_value": [None, 5, 5],
+        name: result
     }
+
     assert table.sort_by("a").to_pydict() == expected
 
 