Support negative preceding/following for ROW-based window functions #9229

Merged: 17 commits, Sep 26, 2023
Changes from 3 commits
integration_tests/src/main/python/window_function_test.py (44 changes: 23 additions & 21 deletions)
@@ -1462,15 +1462,13 @@ def test_to_date_with_window_functions():
_grpkey_longs_with_nullable_decimals,
_grpkey_longs_with_nullable_larger_decimals
], ids=idfn)
-@pytest.mark.parametrize('window_spec', [
-    "3 PRECEDING AND -1 FOLLOWING",
-    "-2 PRECEDING AND 4 FOLLOWING",
-    "UNBOUNDED PRECEDING AND -1 FOLLOWING",
-    "-1 PRECEDING AND UNBOUNDED FOLLOWING",
-    "10 PRECEDING AND -1 FOLLOWING",
-    "5 PRECEDING AND -2 FOLLOWING"
-], ids=idfn)
-def test_window_aggs_for_negative_rows_partitioned_2(data_gen, batch_size, window_spec):
+@pytest.mark.parametrize('window_spec', ["3 PRECEDING AND -1 FOLLOWING",
+                                         "-2 PRECEDING AND 4 FOLLOWING",
+                                         "UNBOUNDED PRECEDING AND -1 FOLLOWING",
+                                         "-1 PRECEDING AND UNBOUNDED FOLLOWING",
+                                         "10 PRECEDING AND -1 FOLLOWING",
+                                         "5 PRECEDING AND -2 FOLLOWING"], ids=idfn)
+def test_window_aggs_for_negative_rows_partitioned(data_gen, batch_size, window_spec):
conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
'spark.rapids.sql.castFloatToDecimal.enabled': True}
assert_gpu_and_cpu_are_equal_sql(
@@ -1481,16 +1479,18 @@ def test_window_aggs_for_negative_rows_partitioned_2(data_gen, batch_size, window_spec):
' (PARTITION BY a ORDER BY b,c ASC ROWS BETWEEN {window}) AS sum_c_asc, '
' MAX(c) OVER '
' (PARTITION BY a ORDER BY b DESC, c DESC ROWS BETWEEN {window}) AS max_c_desc, '
-        ' MIN(c) over '
+        ' MIN(c) OVER '
         ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS min_c_asc, '
-        ' COUNT(1) over '
+        ' COUNT(1) OVER '
         ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_1, '
-        ' COUNT(c) over '
+        ' COUNT(c) OVER '
         ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_c, '
-        ' AVG(c) over '
+        ' AVG(c) OVER '
         ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS avg_c, '
-        ' COLLECT_LIST(c) over '
-        ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS list_c '
+        ' COLLECT_LIST(c) OVER '
+        ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS list_c, '
+        ' SORT_ARRAY(COLLECT_LIST(c) OVER '
+        ' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window})) AS sorted_set_c '
'FROM window_agg_table '.format(window=window_spec),
conf=conf)
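For readers new to negative bounds: `ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING` denotes a frame that ends one row before the current row, i.e. it is equivalent to `3 PRECEDING AND 1 PRECEDING`. The SORT_ARRAY wrapper above is presumably there because the ordering of collected arrays is not guaranteed, so the tests sort them before comparing CPU and GPU results. A minimal sketch of the frame arithmetic (plain Python, illustrative only, not part of the PR):

    # Frame for row i under "p PRECEDING AND f FOLLOWING" (f may be negative),
    # clipped to the partition [0, n). Hypothetical helper, not plugin code.
    def frame(i, n, preceding, following):
        lo = max(0, i - preceding)
        hi = min(n - 1, i + following)
        return list(range(lo, hi + 1)) if lo <= hi else []

    vals = [0, 1, 2, 3, 4, 5]
    # "3 PRECEDING AND -1 FOLLOWING": the first row gets an empty frame.
    print([frame(i, len(vals), 3, -1) for i in range(len(vals))])
    # [[], [0], [0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]]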

@@ -1519,16 +1519,18 @@ def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size):
' (ORDER BY b,c,a ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING) AS sum_c_asc, '
' MAX(c) OVER '
' (ORDER BY b DESC, c DESC, a DESC ROWS BETWEEN -2 PRECEDING AND 4 FOLLOWING) AS max_c_desc, '
-        ' min(c) over '
+        ' min(c) OVER '
         ' (ORDER BY b,c,a ROWS BETWEEN UNBOUNDED PRECEDING AND -1 FOLLOWING) AS min_c_asc, '
-        ' COUNT(1) over '
+        ' COUNT(1) OVER '
         ' (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS count_1, '
-        ' COUNT(c) over '
+        ' COUNT(c) OVER '
         ' (ORDER BY b,c,a ROWS BETWEEN 10 PRECEDING AND -1 FOLLOWING) AS count_c, '
-        ' AVG(c) over '
+        ' AVG(c) OVER '
         ' (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS avg_c, '
-        ' COLLECT_LIST(c) over '
-        ' (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING) AS list_c '
+        ' COLLECT_LIST(c) OVER '
+        ' (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING) AS list_c, '
+        ' SORT_ARRAY(COLLECT_SET(c) OVER '
+        ' (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING)) AS set_c '
'FROM window_agg_table ',
conf=conf)

@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.rapids.{GpuAggregateExpression, GpuCollectList, GpuCollectSet, GpuCount}
+import org.apache.spark.sql.rapids.GpuAggregateExpression
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.unsafe.types.CalendarInterval
@@ -754,14 +754,6 @@ object GroupedAggregations {
}
}

-  private def getMinPeriodsFor(boundGpuWindowFunction: BoundGpuWindowFunction): Int =
-    boundGpuWindowFunction.windowFunc match {
-      case GpuCount(_, _) => 0
-      case GpuCollectList(_, _, _) => 0
-      case GpuCollectSet(_, _, _) => 0
-      case _ => 1
-    }
-
private def isUnbounded(boundary: Expression): Boolean = boundary match {
case special: GpuSpecialFrameBoundary => special.isUnbounded
case _ => false
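The deleted getMinPeriodsFor centralized per-function knowledge in one match; the rest of this PR moves that decision onto the window functions themselves as a getMinPeriods method (see the GpuWindowFunction trait below), so each aggregate declares its own value. A rough Python analogue of the before/after shape, with hypothetical names:

    class WindowFunction:
        # "After": polymorphic default; aggregates override locally.
        def get_min_periods(self) -> int:
            return 1

    class Count(WindowFunction):
        def get_min_periods(self) -> int:
            return 0

    # "Before": central type-based dispatch that must be edited
    # every time a new aggregate needs a non-default value.
    def get_min_periods_for(func: WindowFunction) -> int:
        return 0 if isinstance(func, Count) else 1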
@@ -939,22 +931,19 @@ class GroupedAggregations {
// For now I am going to assume that we don't need to combine calls across frame specs
// because it would just not help that much
val result = {
-      val functionsSeq = functions.toIndexedSeq
-      val allWindowOpts = {for (i <- 0 until functionsSeq.size) yield
+      val allWindowOpts = functions.map { f =>
         getWindowOptions(boundOrderSpec, orderByPositions, frameSpec,
-          getMinPeriodsFor(functionsSeq(i)._1))
+          f._1.windowFunc.getMinPeriods)
       }
-      withResource(allWindowOpts.toIndexedSeq) { allWindowOpts =>
-        val allAggs = {
-          for (i <- 0 until functionsSeq.size) yield
-            functionsSeq(i)._1.aggOverWindow(inputCb, allWindowOpts(i))
-        }.toIndexedSeq
+      withResource(allWindowOpts.toSeq) { allWindowOpts =>
+        val allAggs = allWindowOpts.zip(functions).map { case (windowOpt, f) =>
+          f._1.aggOverWindow(inputCb, windowOpt)
+        }
withResource(GpuColumnVector.from(inputCb)) { initProjTab =>
aggIt(initProjTab.groupBy(partByPositions: _*), allAggs)
}
}
}

withResource(result) { result =>
functions.zipWithIndex.foreach {
case ((func, outputIndexes), resultIndex) =>
@@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, Count, Max, Min, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, Max, Min, Sum}
import org.apache.spark.sql.rapids.{AddOverflowChecks, GpuAggregateExpression, GpuCount, GpuCreateNamedStruct, GpuDivide, GpuSubtract}
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._
@@ -92,6 +92,7 @@ abstract class GpuWindowExpressionMetaBase(
windowFunction.asInstanceOf[AggregateExpression].aggregateFunction match {
case _: Average => // Supported
case _: CollectList => // Supported
+      case _: CollectSet => // Supported
case _: Count => // Supported
case _: Max => // Supported
case _: Min => // Supported
@@ -661,7 +662,15 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary)

// This is here for now just to tag an expression as being a GpuWindowFunction and match
// Spark. This may expand in the future if other types of window functions show up.
-trait GpuWindowFunction extends GpuUnevaluable with ShimExpression
+trait GpuWindowFunction extends GpuUnevaluable with ShimExpression {
+  /**
+   * Get the "min-periods" value, i.e. the minimum number of rows that must
+   * fall in the window frame for the function to return a non-null value.
+   * Frames with fewer rows return null instead.
+   * @return Non-negative value for min-periods.
+   */
+  def getMinPeriods: Int = 1
+}
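This mirrors the min_periods parameter of rolling windows in pandas (and the underlying cuDF aggregation): a frame holding fewer than min_periods rows produces null. A hedged pandas illustration of the default 1 versus the 0 that COUNT/COLLECT_LIST/COLLECT_SET need (analogy only, not plugin code):

    import pandas as pd

    s = pd.Series([0, 1, 2, 3, 4, 5])
    # min_periods=3: frames with fewer than 3 rows yield null (NaN).
    print(s.rolling(window=3, min_periods=3).sum())    # NaN, NaN, 3.0, 6.0, 9.0, 12.0
    # min_periods=0: even an under-filled frame produces a value,
    # which is the behavior COUNT / COLLECT_LIST / COLLECT_SET rely on.
    print(s.rolling(window=3, min_periods=0).count())  # 1.0, 2.0, 3.0, 3.0, 3.0, 3.0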

/**
* This is a special window function that simply replaces itself with one or more
@@ -1576,6 +1576,13 @@ case class GpuCount(children: Seq[Expression],

override def newUnboundedToUnboundedFixer: BatchedUnboundedToUnboundedWindowFixer =
new CountUnboundedToUnboundedFixer(failOnError)

+  // minPeriods should be 0.
+  // Consider the following rows:
+  //   v = [ 0, 1, 2, 3, 4, 5 ]
+  // A `COUNT` window aggregation over (2, -1) should yield 0, not null,
+  // for the first row.
+  override def getMinPeriods: Int = 0
}
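A quick check of the comment above, using a throwaway helper (hypothetical, illustration only) that applies min-periods to a ROW frame of (2 PRECEDING, -1 FOLLOWING):

    def count_over(vals, preceding, following, min_periods):
        out = []
        n = len(vals)
        for i in range(n):
            lo = max(0, i - preceding)
            hi = min(n - 1, i + following)
            rows = max(0, hi - lo + 1)  # rows actually in the frame
            out.append(rows if rows >= min_periods else None)
        return out

    v = [0, 1, 2, 3, 4, 5]
    print(count_over(v, 2, -1, min_periods=1))  # [None, 1, 2, 2, 2, 2]
    print(count_over(v, 2, -1, min_periods=0))  # [0, 1, 2, 2, 2, 2]  <- desired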

object GpuAverage {
@@ -1961,6 +1968,12 @@ case class GpuCollectList(
override def windowAggregation(
inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn =
RollingAggregation.collectList().onColumn(inputs.head._2)

+  // minPeriods should be 0.
+  // Consider the following rows: v = [ 0, 1, 2, 3, 4, 5 ]
+  // A `COLLECT_LIST` window aggregation over (2, -1) should yield an empty array [],
+  // not null, for the first row.
+  override def getMinPeriods: Int = 0
}

/**
@@ -1995,6 +2008,12 @@ case class GpuCollectSet(
RollingAggregation.collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL,
NaNEquality.ALL_EQUAL).onColumn(inputs.head._2)
}

+  // minPeriods should be 0.
+  // Consider the following rows: v = [ 0, 1, 2, 3, 4, 5 ]
+  // A `COLLECT_SET` window aggregation over (2, -1) should yield an empty array [],
+  // not null, for the first row.
+  override def getMinPeriods: Int = 0
}

trait CpuToGpuAggregateBufferConverter {