diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 8a320c42..2f9e6d61 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1326,7 +1326,7 @@ def grouping(arg: Expr, distinct: bool = False) -> Expr: Returns 1 if the value of the argument is aggregated, 0 if not. """ - return Expr(f.grouping([arg.expr], distinct=distinct)) + return Expr(f.grouping(arg.expr, distinct=distinct)) def max(arg: Expr, distinct: bool = False) -> Expr: diff --git a/src/functions.rs b/src/functions.rs index ce0d2bf8..81a09255 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -108,6 +108,16 @@ pub fn corr(y: PyExpr, x: PyExpr, distinct: bool) -> PyResult { } } +#[pyfunction] +pub fn grouping(expression: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::grouping(expression.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + #[pyfunction] pub fn sum(args: PyExpr) -> PyExpr { functions_aggregate::expr_fn::sum(args.expr).into() @@ -799,7 +809,6 @@ array_fn!(flatten, array); array_fn!(range, start stop step); aggregate_function!(array_agg, ArrayAgg); -aggregate_function!(grouping, Grouping); aggregate_function!(max, Max); aggregate_function!(mean, Avg); aggregate_function!(min, Min);