Support negative preceding/following for ROW-based window functions #9229

Merged: 17 commits, Sep 26, 2023
Changes from 12 commits

integration_tests/src/main/python/window_function_test.py (85 additions, 0 deletions)

@@ -1450,6 +1450,91 @@ def test_to_date_with_window_functions():
"""
)


@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_dates,
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
_grpkey_longs_with_nullable_larger_decimals
], ids=idfn)
@pytest.mark.parametrize('window_spec', ["3 PRECEDING AND -1 FOLLOWING",
"-2 PRECEDING AND 4 FOLLOWING",
"UNBOUNDED PRECEDING AND -1 FOLLOWING",
"-1 PRECEDING AND UNBOUNDED FOLLOWING",
"10 PRECEDING AND -1 FOLLOWING",
"5 PRECEDING AND -2 FOLLOWING"], ids=idfn)
def test_window_aggs_for_negative_rows_partitioned(data_gen, batch_size, window_spec):
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'SELECT '
        '  SUM(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ASC ROWS BETWEEN {window}) AS sum_c_asc, '
        '  MAX(c) OVER '
        '    (PARTITION BY a ORDER BY b DESC, c DESC ROWS BETWEEN {window}) AS max_c_desc, '
        '  MIN(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS min_c_asc, '
        '  COUNT(1) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_1, '
        '  COUNT(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_c, '
        '  AVG(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS avg_c, '
        '  COLLECT_LIST(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS list_c, '
        '  SORT_ARRAY(COLLECT_SET(c) OVER '
        '    (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window})) AS sorted_set_c '
        'FROM window_agg_table '.format(window=window_spec),
        conf=conf)
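
To make the frame arithmetic above concrete: for row `i`, `ROWS BETWEEN p PRECEDING AND f FOLLOWING` spans rows `[i - p, i + f]`, so a negative `f` ends the frame *before* the current row, and the frame can be empty. A minimal sketch of these semantics, not taken from the PR (all names below are illustrative):

```scala
// Minimal sketch of ROWS-frame semantics with negative bounds (illustrative only,
// not the plugin's implementation). For row i, the frame spans indices
// [i - preceding, i + following]; an empty frame yields None (SQL NULL for SUM).
object NegativeRowsFrameDemo {
  def frameSum(v: Seq[Int], preceding: Int, following: Int): Seq[Option[Int]] =
    v.indices.map { i =>
      val lo = math.max(0, i - preceding)
      val hi = math.min(v.length - 1, i + following)
      if (lo > hi) None else Some(v.slice(lo, hi + 1).sum)
    }

  def main(args: Array[String]): Unit = {
    val v = Seq(0, 1, 2, 3, 4, 5)
    // ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING: each frame ends one row
    // before the current row, so row 0's frame is empty.
    println(frameSum(v, preceding = 3, following = -1))
    // -> List(None, Some(0), Some(1), Some(3), Some(6), Some(9))
  }
}
```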


@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_dates,
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
# TODO: Sorting DECIMAL(23,10) is broken on CPU in Spark 3.2.1.
_grpkey_longs_with_nullable_larger_decimals,
], ids=idfn)
def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size):
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}

    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'SELECT '
        '  SUM(c) OVER '
        '    (ORDER BY b,c,a ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING) AS sum_c_asc, '
        '  MAX(c) OVER '
        '    (ORDER BY b DESC, c DESC, a DESC ROWS BETWEEN -2 PRECEDING AND 4 FOLLOWING) AS max_c_desc, '
        '  MIN(c) OVER '
        '    (ORDER BY b,c,a ROWS BETWEEN UNBOUNDED PRECEDING AND -1 FOLLOWING) AS min_c_asc, '
        '  COUNT(1) OVER '
        '    (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS count_1, '
        '  COUNT(c) OVER '
        '    (ORDER BY b,c,a ROWS BETWEEN 10 PRECEDING AND -1 FOLLOWING) AS count_c, '
        '  AVG(c) OVER '
        '    (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS avg_c, '
        '  COLLECT_LIST(c) OVER '
        '    (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING) AS list_c, '
        '  SORT_ARRAY(COLLECT_SET(c) OVER '
        '    (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING)) AS set_c '
        'FROM window_agg_table ',
        conf=conf)


def test_lru_cache_datagen():
    # log cache info at the end of integration tests, not related to window functions
    info = gen_df_help.cache_info()

@@ -689,12 +689,13 @@ object GroupedAggregations {
 private def getWindowOptions(
     orderSpec: Seq[SortOrder],
     orderPositions: Seq[Int],
-    frame: GpuSpecifiedWindowFrame): WindowOptions = {
+    frame: GpuSpecifiedWindowFrame,
+    minPeriods: Int): WindowOptions = {
   frame.frameType match {
     case RowFrame =>
       withResource(getRowBasedLower(frame)) { lower =>
         withResource(getRowBasedUpper(frame)) { upper =>
-          val builder = WindowOptions.builder().minPeriods(1)
+          val builder = WindowOptions.builder().minPeriods(minPeriods)
           if (isUnbounded(frame.lower)) builder.unboundedPreceding() else builder.preceding(lower)
           if (isUnbounded(frame.upper)) builder.unboundedFollowing() else builder.following(upper)
           builder.build
@@ -718,7 +719,7 @@
     withResource(asScalarRangeBoundary(orderType, lower)) { preceding =>
       withResource(asScalarRangeBoundary(orderType, upper)) { following =>
         val windowOptionBuilder = WindowOptions.builder()
-          .minPeriods(1)
+          .minPeriods(1) // Does not currently support custom minPeriods.
           .orderByColumnIndex(orderByIndex)
 
         if (preceding.isEmpty) {
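
The `minPeriods` value threaded through here maps to cuDF's `WindowOptions.minPeriods`: a row's aggregation result is null unless at least that many rows fall inside its frame. A hedged sketch of that contract, with an illustrative helper that is not cuDF's API:

```scala
// Illustrative sketch of the minPeriods contract (helper name and shape are
// assumptions, not cudf's API): a row's result is null unless its frame holds
// at least `minPeriods` rows. COUNT overrides this to 0 so an empty frame
// produces 0 rather than null.
def rollingCount(v: Seq[Int], preceding: Int, following: Int,
    minPeriods: Int): Seq[Option[Long]] =
  v.indices.map { i =>
    val lo = math.max(0, i - preceding)
    val hi = math.min(v.length - 1, i + following)
    val size = math.max(0, hi - lo + 1)
    if (size >= minPeriods) Some(size.toLong) else None
  }
```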
@@ -929,13 +930,18 @@ class GroupedAggregations {
   if (frameSpec.frameType == frameType) {
     // For now I am going to assume that we don't need to combine calls across frame specs
     // because it would just not help that much
-    val result = withResource(
-        getWindowOptions(boundOrderSpec, orderByPositions, frameSpec)) { windowOpts =>
-      val allAggs = functions.map {
-        case (winFunc, _) => winFunc.aggOverWindow(inputCb, windowOpts)
-      }.toSeq
-      withResource(GpuColumnVector.from(inputCb)) { initProjTab =>
-        aggIt(initProjTab.groupBy(partByPositions: _*), allAggs)
-      }
-    }
+    val result = {
+      val allWindowOpts = functions.map { f =>
+        getWindowOptions(boundOrderSpec, orderByPositions, frameSpec,
+          f._1.windowFunc.getMinPeriods)
+      }
+      withResource(allWindowOpts.toSeq) { allWindowOpts =>
+        val allAggs = allWindowOpts.zip(functions).map { case (windowOpt, f) =>
+          f._1.aggOverWindow(inputCb, windowOpt)
+        }
+        withResource(GpuColumnVector.from(inputCb)) { initProjTab =>
+          aggIt(initProjTab.groupBy(partByPositions: _*), allAggs)
+        }
+      }
+    }
     withResource(result) { result =>
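
Because each window function can now request its own `minPeriods`, one `WindowOptions` is built per function and the whole sequence is closed together. The sketch below is a simplified stand-in for the `Seq` overload of the plugin's `withResource` helper, written here only to show the resource-management shape; the real helper lives in spark-rapids' Arm utilities:

```scala
// Simplified stand-in (an assumption, not the plugin's actual helper): run
// `block` over a sequence of AutoCloseable resources, then close every one of
// them afterwards, even if `block` throws.
def withResource[T <: AutoCloseable, V](resources: Seq[T])(block: Seq[T] => V): V =
  try {
    block(resources)
  } finally {
    resources.foreach { r => try r.close() catch { case _: Throwable => () } }
  }
```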

@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, Max, Min, Sum}
 import org.apache.spark.sql.rapids.{AddOverflowChecks, GpuAggregateExpression, GpuCount, GpuCreateNamedStruct, GpuDivide, GpuSubtract}
 import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
 import org.apache.spark.sql.types._
@@ -81,13 +82,25 @@ abstract class GpuWindowExpressionMetaBase(
       case _: Lead | _: Lag => // ignored we are good
       case _ =>
         // need to be sure that the lower/upper are acceptable
-        if (lower > 0) {
-          willNotWorkOnGpu(s"lower-bounds ahead of current row is not supported. " +
-            s"Found $lower")
-        }
-        if (upper < 0) {
-          willNotWorkOnGpu(s"upper-bounds behind the current row is not supported. " +
-            s"Found $upper")
-        }
+        // Negative bounds are allowed, so long as lower does not exceed upper.
+        if (upper < lower) {
+          willNotWorkOnGpu("upper-bounds must equal or exceed the lower bounds. " +
+            s"Found lower=$lower, upper=$upper")
+        }
+        // Also check for negative offsets.
+        if (upper < 0 || lower < 0) {
+          windowFunction.asInstanceOf[AggregateExpression].aggregateFunction match {
+            case _: Average => // Supported
+            case _: CollectList => // Supported
+            case _: CollectSet => // Supported
+            case _: Count => // Supported
+            case _: Max => // Supported
+            case _: Min => // Supported
+            case _: Sum => // Supported
+            case f: AggregateFunction =>
+              willNotWorkOnGpu("negative row bounds unsupported for specified " +
+                s"aggregation: ${f.prettyName}")
+          }
+        }
       }
     case RangeFrame =>
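
Condensing the new acceptance rule: a ROWS frame is rejected for GPU execution when lower > upper, and frames with any negative offset are additionally restricted to the aggregations whitelisted above. An illustrative summary, not the plugin's code (the function and set names are assumptions):

```scala
// Illustrative condensation of the new acceptance rule (not the plugin's code).
// Negative offsets are allowed only for a fixed set of aggregations.
def gpuSupportsRowsFrame(lower: Int, upper: Int, aggName: String): Boolean = {
  val negativeBoundsOk =
    Set("avg", "collect_list", "collect_set", "count", "max", "min", "sum")
  lower <= upper &&
    (lower >= 0 && upper >= 0 || negativeBoundsOk(aggName.toLowerCase))
}
```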
@@ -649,7 +662,15 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary)

 // This is here for now just to tag an expression as being a GpuWindowFunction and match
 // Spark. This may expand in the future if other types of window functions show up.
-trait GpuWindowFunction extends GpuUnevaluable with ShimExpression
+trait GpuWindowFunction extends GpuUnevaluable with ShimExpression {
+  /**
+   * Get the "min-periods" value, i.e. the minimum number of rows that must lie
+   * within the window frame for the function to return a non-null value.
+   * Frames with fewer rows yield null.
+   * @return Non-negative value for min-periods.
+   */
+  def getMinPeriods: Int = 1
+}

/**
* This is a special window function that simply replaces itself with one or more

@@ -1576,6 +1576,13 @@ case class GpuCount(children: Seq[Expression],

   override def newUnboundedToUnboundedFixer: BatchedUnboundedToUnboundedWindowFixer =
     new CountUnboundedToUnboundedFixer(failOnError)
+
+  // minPeriods should be 0.
+  // Consider the following rows: v = [0, 1, 2, 3, 4, 5].
+  // A `COUNT` window aggregation over (2 PRECEDING, -1 FOLLOWING) should yield 0,
+  // not null, for the first row, whose frame is empty.
+  override def getMinPeriods: Int = 0
 }
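
Worked through with the illustrative `rollingCount` helper sketched earlier, here is the difference `minPeriods` makes for this `COUNT` example:

```scala
// Expected results for COUNT over ROWS BETWEEN 2 PRECEDING AND -1 FOLLOWING,
// using the illustrative rollingCount helper sketched above (an assumption,
// not the plugin's code). Row 0's frame is empty.
rollingCount(Seq(0, 1, 2, 3, 4, 5), preceding = 2, following = -1, minPeriods = 0)
// -> List(Some(0), Some(1), Some(2), Some(2), Some(2), Some(2))
rollingCount(Seq(0, 1, 2, 3, 4, 5), preceding = 2, following = -1, minPeriods = 1)
// -> List(None, Some(1), Some(2), Some(2), Some(2), Some(2))
```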

object GpuAverage {
@@ -1961,6 +1968,12 @@ case class GpuCollectList(
   override def windowAggregation(
       inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn =
     RollingAggregation.collectList().onColumn(inputs.head._2)
+
+  // minPeriods should be 0.
+  // Consider the following rows: v = [0, 1, 2, 3, 4, 5].
+  // A `COLLECT_LIST` window aggregation over (2 PRECEDING, -1 FOLLOWING) should
+  // yield an empty array [], not null, for the first row.
+  override def getMinPeriods: Int = 0
 }

/**
@@ -1995,6 +2008,12 @@
     RollingAggregation.collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL,
       NaNEquality.ALL_EQUAL).onColumn(inputs.head._2)
   }
+
+  // minPeriods should be 0.
+  // Consider the following rows: v = [0, 1, 2, 3, 4, 5].
+  // A `COLLECT_SET` window aggregation over (2 PRECEDING, -1 FOLLOWING) should
+  // yield an empty array [], not null, for the first row.
+  override def getMinPeriods: Int = 0
 }
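
The same reasoning applies to both collect aggregations: with `minPeriods = 0`, an empty frame produces an empty array rather than null. A hedged sketch of the expected values (the helper is illustrative, not the plugin's code):

```scala
// Illustrative sketch: with minPeriods = 0, COLLECT_LIST over an empty frame
// yields an empty array instead of null.
def rollingCollect(v: Seq[Int], preceding: Int, following: Int): Seq[Seq[Int]] =
  v.indices.map { i =>
    val lo = math.max(0, i - preceding)
    val hi = math.min(v.length, i + following + 1)
    if (hi <= lo) Seq.empty else v.slice(lo, hi)
  }

// rollingCollect(Seq(0, 1, 2, 3, 4, 5), preceding = 2, following = -1)
// -> List(List(), List(0), List(0, 1), List(1, 2), List(2, 3), List(3, 4))
```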

trait CpuToGpuAggregateBufferConverter {