Skip to content

Commit

Permalink
test regr_ function wrappers
Browse files Browse the repository at this point in the history
Closes apache#778
  • Loading branch information
Michael-J-Ward committed Jul 29, 2024
1 parent e17ba64 commit d293398
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def test_case(df):
assert result.column(2) == pa.array(["Hola", "Mundo", None])


def test_regr_funcs(df):
def test_regr_funcs_sql(df):
# test case base on
# https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
ctx = SessionContext()
Expand All @@ -817,6 +817,68 @@ def test_regr_funcs(df):
assert result[0].column(8) == pa.array([0], type=pa.float64())


def test_regr_funcs_sql_2():
# test case based on `regr_*() basic tests
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
ctx = SessionContext()

# Perform the regression functions using SQL
result_sql = ctx.sql(
"select "
"regr_slope(column2, column1), "
"regr_intercept(column2, column1), "
"regr_count(column2, column1), "
"regr_r2(column2, column1), "
"regr_avgx(column2, column1), "
"regr_avgy(column2, column1), "
"regr_sxx(column2, column1), "
"regr_syy(column2, column1), "
"regr_sxy(column2, column1) "
"from (values (1,2), (2,4), (3,6))"
).collect()

# Assertions for SQL results
assert result_sql[0].column(0) == pa.array([2], type=pa.float64())
assert result_sql[0].column(1) == pa.array([0], type=pa.float64())
assert result_sql[0].column(2) == pa.array([3], type=pa.float64()) # todo: i would not expect this to be float
assert result_sql[0].column(3) == pa.array([1], type=pa.float64())
assert result_sql[0].column(4) == pa.array([2], type=pa.float64())
assert result_sql[0].column(5) == pa.array([4], type=pa.float64())
assert result_sql[0].column(6) == pa.array([2], type=pa.float64())
assert result_sql[0].column(7) == pa.array([8], type=pa.float64())
assert result_sql[0].column(8) == pa.array([4], type=pa.float64())


@pytest.mark.parametrize("func, expected", [
pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"),
pytest.param(f.regr_count, pa.array([3], type=pa.float64()), id="regr_count"), # TODO: I would expect this to return an int array
pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"),
pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"),
pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy")
])
def test_regr_funcs_df(func, expected):

# test case based on `regr_*() basic tests
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1


ctx = SessionContext()

# Create a DataFrame
data = {'column1': [1, 2, 3], 'column2': [2, 4, 6]}
df = ctx.from_pydict(data, name="test_table")

# Perform the regression function using DataFrame API
result_df = df.aggregate([], [func(f.col("column2"), f.col("column1"))]).collect()

# Assertion for DataFrame API result
assert result_df[0].column(0) == expected


def test_first_last_value(df):
df = df.aggregate(
[],
Expand Down

0 comments on commit d293398

Please sign in to comment.