[SPARK-26572][SQL] fix aggregate codegen result evaluation
## What changes were proposed in this pull request?

This PR is a correctness fix in `HashAggregateExec` code generation. It forces evaluation of non-deterministic result expressions before calling `consume()`, to avoid evaluating them multiple times.

This PR fixes a case where an aggregate is nested inside a broadcast join and appears on the "stream" side. The broadcast join generates its own loop, and without forcing evaluation of the `resultExpressions` of `HashAggregateExec` before that loop, these expressions can be executed multiple times, producing incorrect results.
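
As an illustration (not part of the original PR description), here is a minimal sketch of the problematic query shape, mirroring the first case of the new test; the local `SparkSession` setup is an assumption for running it outside the test suite:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("SPARK-26572 repro")   // hypothetical app name, only for this sketch
  .getOrCreate()
import spark.implicits._

// Keep the join a broadcast hash join and use a single shuffle partition,
// matching the configuration used by the new test.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", Long.MaxValue.toString)
spark.conf.set("spark.sql.shuffle.partitions", "1")

val baseTable = Seq(1, 1).toDF("idx")

// The aggregate (distinct) ends up on the streamed side of the broadcast join.
// Without the fix, the non-deterministic monotonically_increasing_id() can be
// re-evaluated inside the join's generated loop, so the two output rows may get
// different ids instead of sharing the same one.
val df = baseTable.distinct()
  .withColumn("id", monotonically_increasing_id())
  .join(baseTable, "idx")

df.show()  // expected: two rows, both with idx = 1 and the same id (0)
```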

## How was this patch tested?

A new unit test was added to `WholeStageCodegenSuite`.
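For reference (not part of the original PR description), the suite can typically be run from a Spark source checkout with `./build/sbt "sql/testOnly *WholeStageCodegenSuite"`, assuming the standard Spark sbt workflow.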

Closes apache#23731 from peter-toth/SPARK-26572.

Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
peter-toth authored and cloud-fan committed Feb 14, 2019
1 parent ac9c053 commit 2228ee5
Showing 3 changed files with 51 additions and 3 deletions.
@@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan {
     evaluateVars.toString()
   }

+  /**
+   * Returns source code to evaluate the variables for non-deterministic expressions, and clear
+   * the code of evaluated variables, to prevent them from being evaluated twice.
+   */
+  protected def evaluateNondeterministicVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      expressions: Seq[NamedExpression]): String = {
+    val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute)
+    evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs))
+  }
+
   /**
    * The subset of inputSet that should be evaluated before this plan.
    *
@@ -466,10 +466,13 @@ case class HashAggregateExec(
         val resultVars = bindReferences[Expression](
           resultExpressions,
           inputAttrs).map(_.genCode(ctx))
+        val evaluateNondeterministicResults =
+          evaluateNondeterministicVariables(output, resultVars, resultExpressions)
         s"""
         $evaluateKeyVars
         $evaluateBufferVars
         $evaluateAggResults
+        $evaluateNondeterministicResults
         ${consume(ctx, resultVars)}
         """
       } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -506,10 +509,15 @@ case class HashAggregateExec(
         // generate result based on grouping key
         ctx.INPUT_ROW = keyTerm
         ctx.currentVars = null
-        val eval = bindReferences[Expression](
+        val resultVars = bindReferences[Expression](
           resultExpressions,
           groupingAttributes).map(_.genCode(ctx))
-        consume(ctx, eval)
+        val evaluateNondeterministicResults =
+          evaluateNondeterministicVariables(output, resultVars, resultExpressions)
+        s"""
+        $evaluateNondeterministicResults
+        ${consume(ctx, resultVars)}
+        """
       }
     ctx.addNewFunction(funcName,
       s"""
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
-import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {

     checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
   }

test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val baseTable = Seq(1, 1).toDF("idx")

// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(distinctWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))

// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
// expression
val groupByWithId =
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(groupByWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0)))
}
}
}
