From 2a5a92e3f6bcffef34c9e58ac294912b5097c1ac Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 25 Aug 2024 09:20:57 -0400 Subject: [PATCH 1/3] Run ruff format in CI --- .github/workflows/build.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0b974f2d..88ef8202 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,9 @@ jobs: pip install ruff # Update output format to enable automatic inline annotations. - name: Run Ruff - run: ruff check --output-format=github python/ + run: | + ruff check --output-format=github python/ + ruff format python/ generate-license: runs-on: ubuntu-latest From 57bc8b9816c1134f19f081371495e3d954f486e3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 25 Aug 2024 09:28:20 -0400 Subject: [PATCH 2/3] Add --check parameter --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 88ef8202..a4f8b2da 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,7 +40,7 @@ jobs: - name: Run Ruff run: | ruff check --output-format=github python/ - ruff format python/ + ruff format --check python/ generate-license: runs-on: ubuntu-latest From 1dcb7b7f8e0f9e6539267a5f375807b81a5a96e0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 25 Aug 2024 09:32:06 -0400 Subject: [PATCH 3/3] Apply ruff format --- python/datafusion/functions.py | 9 +- python/datafusion/tests/test_aggregation.py | 105 ++++++++------ python/datafusion/tests/test_dataframe.py | 90 ++++++------ python/datafusion/tests/test_expr.py | 14 +- python/datafusion/tests/test_functions.py | 150 +++++++++++++------- python/datafusion/tests/test_sql.py | 32 +++-- 6 files changed, 248 insertions(+), 152 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 2d3d87ee..59a1974f 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1479,12 +1479,17 @@ def approx_percentile_cont( """Returns the value that is approximately at a given percentile of ``expr``.""" if num_centroids is None: return Expr( - f.approx_percentile_cont(expression.expr, percentile.expr, distinct=distinct, num_centroids=None) + f.approx_percentile_cont( + expression.expr, percentile.expr, distinct=distinct, num_centroids=None + ) ) return Expr( f.approx_percentile_cont( - expression.expr, percentile.expr, distinct=distinct, num_centroids=num_centroids.expr + expression.expr, + percentile.expr, + distinct=distinct, + num_centroids=num_centroids.expr, ) ) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index 03485da4..ab653c40 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -39,6 +39,7 @@ def df(): ) return ctx.create_dataframe([[batch]]) + @pytest.fixture def df_aggregate_100(): ctx = SessionContext() @@ -46,32 +47,46 @@ def df_aggregate_100(): return ctx.table("aggregate_test_data") -@pytest.mark.parametrize("agg_expr, calc_expected", [ - (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), - (f.corr(column("a"), column("b")), lambda a, b, c, d: np.array(np.corrcoef(a, b)[0][1])), - (f.count(column("a")), lambda a, b, c, d: pa.array([len(a)])), - # Sample (co)variance -> ddof=1 - # Population (co)variance -> ddof=0 - (f.covar(column("a"), column("b")), lambda a, b, c, d: np.array(np.cov(a, b, ddof=1)[0][1])), - (f.covar_pop(column("a"), column("c")), lambda a, b, c, d: np.array(np.cov(a, c, ddof=0)[0][1])), - (f.covar_samp(column("b"), column("c")), lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1])), - # f.grouping(col_a), # No physical plan implemented yet - (f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))), - (f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))), - (f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))), - (f.min(column("a")), lambda a, b, c, d: np.array(np.min(a))), - (f.sum(column("b")), lambda a, b, c, d: np.array(np.sum(b.to_pylist()))), - # Sample stdev -> ddof=1 - # Population stdev -> ddof=0 - (f.stddev(column("a")), lambda a, b, c, d: np.array(np.std(a, ddof=1))), - (f.stddev_pop(column("b")), lambda a, b, c, d: np.array(np.std(b, ddof=0))), - (f.stddev_samp(column("c")), lambda a, b, c, d: np.array(np.std(c, ddof=1))), - (f.var(column("a")), lambda a, b, c, d: np.array(np.var(a, ddof=1))), - (f.var_pop(column("b")), lambda a, b, c, d: np.array(np.var(b, ddof=0))), - (f.var_samp(column("c")), lambda a, b, c, d: np.array(np.var(c, ddof=1))), -]) +@pytest.mark.parametrize( + "agg_expr, calc_expected", + [ + (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), + ( + f.corr(column("a"), column("b")), + lambda a, b, c, d: np.array(np.corrcoef(a, b)[0][1]), + ), + (f.count(column("a")), lambda a, b, c, d: pa.array([len(a)])), + # Sample (co)variance -> ddof=1 + # Population (co)variance -> ddof=0 + ( + f.covar(column("a"), column("b")), + lambda a, b, c, d: np.array(np.cov(a, b, ddof=1)[0][1]), + ), + ( + f.covar_pop(column("a"), column("c")), + lambda a, b, c, d: np.array(np.cov(a, c, ddof=0)[0][1]), + ), + ( + f.covar_samp(column("b"), column("c")), + lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1]), + ), + # f.grouping(col_a), # No physical plan implemented yet + (f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))), + (f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))), + (f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))), + (f.min(column("a")), lambda a, b, c, d: np.array(np.min(a))), + (f.sum(column("b")), lambda a, b, c, d: np.array(np.sum(b.to_pylist()))), + # Sample stdev -> ddof=1 + # Population stdev -> ddof=0 + (f.stddev(column("a")), lambda a, b, c, d: np.array(np.std(a, ddof=1))), + (f.stddev_pop(column("b")), lambda a, b, c, d: np.array(np.std(b, ddof=0))), + (f.stddev_samp(column("c")), lambda a, b, c, d: np.array(np.std(c, ddof=1))), + (f.var(column("a")), lambda a, b, c, d: np.array(np.var(a, ddof=1))), + (f.var_pop(column("b")), lambda a, b, c, d: np.array(np.var(b, ddof=0))), + (f.var_samp(column("c")), lambda a, b, c, d: np.array(np.var(c, ddof=1))), + ], +) def test_aggregation_stats(df, agg_expr, calc_expected): - agg_df = df.aggregate([], [agg_expr]) result = agg_df.collect()[0] values_a, values_b, values_c, values_d = df.collect()[0] @@ -79,16 +94,19 @@ def test_aggregation_stats(df, agg_expr, calc_expected): np.testing.assert_array_almost_equal(result.column(0), expected) -@pytest.mark.parametrize("agg_expr, expected", [ - (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64())), - (f.approx_median(column("b")), pa.array([4])), - (f.approx_percentile_cont(column("b"), lit(0.5)), pa.array([4])), - ( - f.approx_percentile_cont_with_weight(column("b"), lit(0.6), lit(0.5)), - pa.array([6], type=pa.float64()) - ), - (f.array_agg(column("b")), pa.array([[4, 4, 6]])), -]) +@pytest.mark.parametrize( + "agg_expr, expected", + [ + (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64())), + (f.approx_median(column("b")), pa.array([4])), + (f.approx_percentile_cont(column("b"), lit(0.5)), pa.array([4])), + ( + f.approx_percentile_cont_with_weight(column("b"), lit(0.6), lit(0.5)), + pa.array([6], type=pa.float64()), + ), + (f.array_agg(column("b")), pa.array([[4, 4, 6]])), + ], +) def test_aggregation(df, agg_expr, expected): agg_df = df.aggregate([], [agg_expr]) result = agg_df.collect()[0] @@ -98,20 +116,21 @@ def test_aggregation(df, agg_expr, expected): def test_aggregate_100(df_aggregate_100): # https://github.com/apache/datafusion/blob/bddb6415a50746d2803dd908d19c3758952d74f9/datafusion/sqllogictest/test_files/aggregate.slt#L1490-L1498 - result = df_aggregate_100.aggregate( - [ - column("c1") - ], - [ - f.approx_percentile_cont(column("c3"), lit(0.95), lit(200)).alias("c3") - ] - ).sort(column("c1").sort(ascending=True)).collect() + result = ( + df_aggregate_100.aggregate( + [column("c1")], + [f.approx_percentile_cont(column("c3"), lit(0.95), lit(200)).alias("c3")], + ) + .sort(column("c1").sort(ascending=True)) + .collect() + ) assert len(result) == 1 result = result[0] assert result.column("c1") == pa.array(["a", "b", "c", "d", "e"]) assert result.column("c3") == pa.array([73, 68, 122, 124, 115]) + def test_bit_add_or_xor(df): df = df.aggregate( [], diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 6444d932..e5e0c9c8 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -279,57 +279,67 @@ def test_distinct(): data_test_window_functions = [ - ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]), - ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]), - ("dense_rank", f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2] ), - ("percent_rank", f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), [0.5, 0, 0.5]), - ("cume_dist", f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), [0.3333333333333333, 0.6666666666666666, 1.0]), - ("ntile", f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), [1, 1, 2]), - ("next", f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), [5, 6, None]), - ("previous", f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), [None, 4, 5]), - pytest.param( - "first_value", - f.window( + ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]), + ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]), + ( + "dense_rank", + f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), + [2, 1, 2], + ), + ( + "percent_rank", + f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), + [0.5, 0, 0.5], + ), + ( + "cume_dist", + f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), + [0.3333333333333333, 0.6666666666666666, 1.0], + ), + ( + "ntile", + f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), + [1, 1, 2], + ), + ( + "next", + f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), + [5, 6, None], + ), + ( + "previous", + f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), + [None, 4, 5], + ), + pytest.param( "first_value", - [column("a")], - order_by=[f.order_by(column("b"))] + f.window("first_value", [column("a")], order_by=[f.order_by(column("b"))]), + [1, 1, 1], + ), + pytest.param( + "last_value", + f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]), + [4, 5, 6], ), - [1, 1, 1], - ), - pytest.param( - "last_value", - f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]), - [4, 5, 6], - ), - pytest.param( - "2nd_value", - f.window( - "nth_value", - [column("b"), literal(2)], - order_by=[f.order_by(column("b"))], + pytest.param( + "2nd_value", + f.window( + "nth_value", + [column("b"), literal(2)], + order_by=[f.order_by(column("b"))], + ), + [None, 5, 5], ), - [None, 5, 5], - ), ] @pytest.mark.parametrize("name,expr,result", data_test_window_functions) def test_window_functions(df, name, expr, result): - df = df.select( - column("a"), - column("b"), - column("c"), - f.alias(expr, name) - ) + df = df.select(column("a"), column("b"), column("c"), f.alias(expr, name)) table = pa.Table.from_batches(df.collect()) - expected = { - "a": [1, 2, 3], - "b": [4, 5, 6], - "c": [8, 5, 8], - name: result - } + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8], name: result} assert table.sort_by("a").to_pydict() == expected diff --git a/python/datafusion/tests/test_expr.py b/python/datafusion/tests/test_expr.py index 1a41120a..9071108c 100644 --- a/python/datafusion/tests/test_expr.py +++ b/python/datafusion/tests/test_expr.py @@ -146,24 +146,26 @@ def test_expr_to_variant(): from datafusion import SessionContext from datafusion.expr import Filter - def traverse_logical_plan(plan): cur_node = plan.to_variant() if isinstance(cur_node, Filter): return cur_node.predicate().to_variant() - if hasattr(plan, 'inputs'): + if hasattr(plan, "inputs"): for input_plan in plan.inputs(): res = traverse_logical_plan(input_plan) if res is not None: return res ctx = SessionContext() - data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']} - ctx.from_pydict(data, name='table1') + data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]} + ctx.from_pydict(data, name="table1") query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')" logical_plan = ctx.sql(query).optimized_logical_plan() variant = traverse_logical_plan(logical_plan) assert variant is not None - assert variant.expr().to_variant().qualified_name() == 'table1.name' - assert str(variant.list()) == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' + assert variant.expr().to_variant().qualified_name() == "table1.name" + assert ( + str(variant.list()) + == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' + ) assert not variant.negated() diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index b8ad9c0d..732136ea 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -567,45 +567,86 @@ def test_array_function_obj_tests(stmt, py_expr): assert a == b -@pytest.mark.parametrize("function, expected_result", [ - (f.ascii(column("a")), pa.array([72, 87, 33], type=pa.int32())), # H = 72; W = 87; ! = 33 - (f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())), - (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])), - (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())), - (f.chr(literal(68)), pa.array(["D", "D", "D"])), - (f.concat_ws("-", column("a"), literal("test")), pa.array(["Hello-test", "World-test", "!-test"])), - (f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])), - (f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])), - (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])), - (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())), - (f.lower(column("a")), pa.array(["hello", "world", "!"])), - (f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])), - (f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])), - (f.md5(column("a")), pa.array([ - "8b1a9953c4611296a827abf8c47804d7", - "f5a7924e621e84c9280a9a27e1bcb7f6", - "9033e0e305f247c0c3c80d0c7848c8b3", - ])), - (f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())), - (f.repeat(column("a"), literal(2)), pa.array(["HelloHello", "WorldWorld", "!!"])), - (f.replace(column("a"), literal("l"), literal("?")), pa.array(["He??o", "Wor?d", "!"])), - (f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])), - (f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])), - (f.rpad(column("a"), literal(8)), pa.array(["Hello ", "World ", "! "])), - (f.rtrim(column("c")), pa.array(["hello", " world", " !"])), - (f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"])), - (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])), - (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())), - (f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])), - (f.translate(column("a"), literal("or"), literal("ld")), pa.array(["Helll", "Wldld", "!"])), - (f.trim(column("c")), pa.array(["hello", "world", "!"])), - (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])), - (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])), - (f.overlay(column("a"), literal("--"), literal(2)), pa.array(["H--lo", "W--ld", "--"])), - (f.regexp_like(column("a"), literal("(ell|orl)")), pa.array([True, True, False])), - (f.regexp_match(column("a"), literal("(ell|orl)")), pa.array([["ell"], ["orl"], None])), - (f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")), pa.array(["H-o", "W-d", "!"])), -]) +@pytest.mark.parametrize( + "function, expected_result", + [ + ( + f.ascii(column("a")), + pa.array([72, 87, 33], type=pa.int32()), + ), # H = 72; W = 87; ! = 33 + (f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())), + (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])), + (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())), + (f.chr(literal(68)), pa.array(["D", "D", "D"])), + ( + f.concat_ws("-", column("a"), literal("test")), + pa.array(["Hello-test", "World-test", "!-test"]), + ), + (f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])), + (f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])), + (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])), + (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())), + (f.lower(column("a")), pa.array(["hello", "world", "!"])), + (f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])), + (f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])), + ( + f.md5(column("a")), + pa.array( + [ + "8b1a9953c4611296a827abf8c47804d7", + "f5a7924e621e84c9280a9a27e1bcb7f6", + "9033e0e305f247c0c3c80d0c7848c8b3", + ] + ), + ), + (f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())), + ( + f.repeat(column("a"), literal(2)), + pa.array(["HelloHello", "WorldWorld", "!!"]), + ), + ( + f.replace(column("a"), literal("l"), literal("?")), + pa.array(["He??o", "Wor?d", "!"]), + ), + (f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])), + (f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])), + ( + f.rpad(column("a"), literal(8)), + pa.array(["Hello ", "World ", "! "]), + ), + (f.rtrim(column("c")), pa.array(["hello", " world", " !"])), + ( + f.split_part(column("a"), literal("l"), literal(1)), + pa.array(["He", "Wor", "!"]), + ), + (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])), + (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())), + (f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])), + ( + f.translate(column("a"), literal("or"), literal("ld")), + pa.array(["Helll", "Wldld", "!"]), + ), + (f.trim(column("c")), pa.array(["hello", "world", "!"])), + (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])), + (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])), + ( + f.overlay(column("a"), literal("--"), literal(2)), + pa.array(["H--lo", "W--ld", "--"]), + ), + ( + f.regexp_like(column("a"), literal("(ell|orl)")), + pa.array([True, True, False]), + ), + ( + f.regexp_match(column("a"), literal("(ell|orl)")), + pa.array([["ell"], ["orl"], None]), + ), + ( + f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")), + pa.array(["H-o", "W-d", "!"]), + ), + ], +) def test_string_functions(df, function, expected_result): df = df.select(function) result = df.collect() @@ -849,27 +890,30 @@ def test_regr_funcs_sql_2(): assert result_sql[0].column(8) == pa.array([4], type=pa.float64()) -@pytest.mark.parametrize("func, expected", [ - pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"), - pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"), - pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"), - pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"), - pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"), - pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"), - pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"), - pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"), - pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy") -]) +@pytest.mark.parametrize( + "func, expected", + [ + pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"), + pytest.param( + f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept" + ), + pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"), + pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"), + pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"), + pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"), + pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"), + pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"), + pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy"), + ], +) def test_regr_funcs_df(func, expected): - # test case based on `regr_*() basic tests # https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1 - ctx = SessionContext() # Create a DataFrame - data = {'column1': [1, 2, 3], 'column2': [2, 4, 6]} + data = {"column1": [1, 2, 3], "column2": [2, 4, 6]} df = ctx.from_pydict(data, name="test_table") # Perform the regression function using DataFrame API diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index 1505fb1e..e41d0100 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -381,15 +381,31 @@ def test_udf( id="binary4", marks=pytest.mark.xfail, ), - pytest.param(helpers.data_datetime("s"), id="datetime_s", marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("ms"), id="datetime_ms", marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("us"), id="datetime_us", marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("ns"), id="datetime_ns", marks=pytest.mark.xfail), + pytest.param( + helpers.data_datetime("s"), id="datetime_s", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_datetime("ms"), id="datetime_ms", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_datetime("us"), id="datetime_us", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_datetime("ns"), id="datetime_ns", marks=pytest.mark.xfail + ), # Not writtable to parquet - pytest.param(helpers.data_timedelta("s"), id="timedelta_s", marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("ms"), id="timedelta_ms", marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("us"), id="timedelta_us", marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("ns"), id="timedelta_ns", marks=pytest.mark.xfail), + pytest.param( + helpers.data_timedelta("s"), id="timedelta_s", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_timedelta("ms"), id="timedelta_ms", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_timedelta("us"), id="timedelta_us", marks=pytest.mark.xfail + ), + pytest.param( + helpers.data_timedelta("ns"), id="timedelta_ns", marks=pytest.mark.xfail + ), ], ) def test_simple_select(ctx, tmp_path, arr):