Skip to content

Commit

Permalink
[SPARK-35133][SQL] Explain codegen works with AQE
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

`EXPLAIN CODEGEN <query>` (and `Dataset.explain("codegen")`) prints out the generated code for each stage of the plan. The current implementation matches the `WholeStageCodegenExec` operator in the query plan and prints out its generated code (https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala#L111-L118 ). This does not work with AQE, because the whole query plan is wrapped inside `AdaptiveSparkPlanExec` and the whole-stage code-gen physical plan rule (`CollapseCodegenStages`) is not run eagerly. This introduces an unexpected behavior change for the EXPLAIN query (and `Dataset.explain`), since AQE is now enabled by default.

The change is to explain code-gen for the current executed plan of AQE.

### Why are the changes needed?

Make `EXPLAIN CODEGEN` work the same as before.

### Does this PR introduce _any_ user-facing change?

No (when compared with the latest Spark release, 3.1.1).

### How was this patch tested?

Added unit test in `ExplainSuite.scala`.

Closes #32430 from c21/explain-aqe.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
c21 authored and dongjoon-hyun committed May 7, 2021
1 parent 94bbca3 commit 42f59ca
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeFor
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQuery
Expand Down Expand Up @@ -112,6 +113,11 @@ package object debug {
plan foreach {
case s: WholeStageCodegenExec =>
codegenSubtrees += s
case p: AdaptiveSparkPlanExec =>
// Find subtrees from current executed plan of AQE.
findSubtrees(p.executedPlan)
case s: QueryStageExec =>
findSubtrees(s.plan)
case s =>
s.subqueries.foreach(findSubtrees)
}
Expand Down
27 changes: 27 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,33 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
|""".stripMargin
)
}

// Verifies the codegen explain output for an aggregate query: the "Found 0 WholeStageCodegen
// subtrees." message is expected before the query has been run, and the "Found 2 ..." message
// is expected once collect() has been called on the same DataFrame.
test("SPARK-35133: explain codegen should work with AQE") {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
withTempView("df") {
val df = spark.range(5).select(col("id").as("key"), col("id").as("value"))
df.createTempView("df")

val sqlText = "EXPLAIN CODEGEN SELECT key, MAX(value) FROM df GROUP BY key"
val expectedCodegenText = "Found 2 WholeStageCodegen subtrees."
val expectedNoCodegenText = "Found 0 WholeStageCodegen subtrees."
// SQL form of the explain: only the zero-subtree message is asserted here
// (presumably because nothing has executed the plan yet).
withNormalizedExplain(sqlText) { normalizedOutput =>
assert(normalizedOutput.contains(expectedNoCodegenText))
}

// Same aggregation through the DataFrame API; still explained before any action is run.
val aggDf = df.groupBy('key).agg(max('value))
withNormalizedExplain(aggDf, CodegenMode) { normalizedOutput =>
assert(normalizedOutput.contains(expectedNoCodegenText))
}

// trigger the final plan for AQE
aggDf.collect()
// After the action, explaining the same DataFrame must report the two codegen subtrees.
withNormalizedExplain(aggDf, CodegenMode) { normalizedOutput =>
assert(normalizedOutput.contains(expectedCodegenText))
}
}
}
}
}

// Single-field record holding an integer `id`; its usages are outside this excerpt.
case class ExplainSingleData(id: Int)
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.execution.{CodegenSupport, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
import org.apache.spark.sql.types.StructType

// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
class DebuggingSuite extends SharedSparkSession with DisableAdaptiveExecutionSuite {
abstract class DebuggingSuiteBase extends SharedSparkSession {

test("DataFrame.debug()") {
testData.debug()
Expand All @@ -43,63 +42,23 @@ class DebuggingSuite extends SharedSparkSession with DisableAdaptiveExecutionSui
}

test("debugCodegen") {
val res = codegenString(spark.range(10).groupBy(col("id") * 2).count()
.queryExecution.executedPlan)
val df = spark.range(10).groupBy(col("id") * 2).count()
df.collect()
val res = codegenString(df.queryExecution.executedPlan)
assert(res.contains("Subtree 1 / 2"))
assert(res.contains("Subtree 2 / 2"))
assert(res.contains("Object[]"))
}

test("debugCodegenStringSeq") {
val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count()
.queryExecution.executedPlan)
val df = spark.range(10).groupBy(col("id") * 2).count()
df.collect()
val res = codegenStringSeq(df.queryExecution.executedPlan)
assert(res.length == 2)
assert(res.forall{ case (subtree, code, _) =>
subtree.contains("Range") && code.contains("Object[]")})
}

test("SPARK-28537: DebugExec cannot debug broadcast related queries") {
val rightDF = spark.range(10)
val leftDF = spark.range(10)
val joinedDF = leftDF.join(rightDF, leftDF("id") === rightDF("id"))

val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
joinedDF.debug()
}

val output = captured.toString()
val hashedModeString = "HashedRelationBroadcastMode(List(input[0, bigint, false]),false)"
assert(output.replaceAll("\\[id=#\\d+\\]", "[id=#x]").contains(
s"""== BroadcastExchange $hashedModeString, [id=#x] ==
|Tuples output: 0
| id LongType: {}
|== WholeStageCodegen (1) ==
|Tuples output: 10
| id LongType: {java.lang.Long}
|== Range (0, 10, step=1, splits=2) ==
|Tuples output: 0
| id LongType: {}""".stripMargin))
}

test("SPARK-28537: DebugExec cannot debug columnar related queries") {
val df = spark.range(5)
df.persist()

val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
df.debug()
}
df.unpersist()

val output = captured.toString().replaceAll("#\\d+", "#x")
assert(output.contains(
"""== InMemoryTableScan [id#xL] ==
|Tuples output: 0
| id LongType: {}
|""".stripMargin))
}

case class DummyCodeGeneratorPlan(useInnerClass: Boolean)
extends CodegenSupport with LeafExecNode {
override def output: Seq[Attribute] = StructType.fromDDL("d int").toAttributes
Expand Down Expand Up @@ -137,3 +96,51 @@ class DebuggingSuite extends SharedSparkSession with DisableAdaptiveExecutionSui
}
}
}

// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
//
// Holds the DebugExec tests whose expected output is compared against an exact textual
// operator listing, in addition to the tests inherited from DebuggingSuiteBase.
class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSuite {

  test("SPARK-28537: DebugExec cannot debug broadcast related queries") {
    val rightDF = spark.range(10)
    val leftDF = spark.range(10)
    val joinedDF = leftDF.join(rightDF, leftDF("id") === rightDF("id"))

    // debug() prints its report to stdout, so redirect Console output into a buffer.
    val captured = new ByteArrayOutputStream()
    Console.withOut(captured) {
      joinedDF.debug()
    }

    val output = captured.toString()
    val hashedModeString = "HashedRelationBroadcastMode(List(input[0, bigint, false]),false)"
    // Replace the numeric exchange id with a fixed placeholder before comparing, so the
    // expected text does not depend on the concrete id assigned in this run.
    assert(output.replaceAll("\\[id=#\\d+\\]", "[id=#x]").contains(
      s"""== BroadcastExchange $hashedModeString, [id=#x] ==
         |Tuples output: 0
         | id LongType: {}
         |== WholeStageCodegen (1) ==
         |Tuples output: 10
         | id LongType: {java.lang.Long}
         |== Range (0, 10, step=1, splits=2) ==
         |Tuples output: 0
         | id LongType: {}""".stripMargin))
  }

  test("SPARK-28537: DebugExec cannot debug columnar related queries") {
    val df = spark.range(5)
    df.persist()

    val captured = new ByteArrayOutputStream()
    // Release the persisted data even when debug() throws, so a failure in this test
    // does not leave the DataFrame cached for whatever runs next in the shared session.
    try {
      Console.withOut(captured) {
        df.debug()
      }
    } finally {
      df.unpersist()
    }

    // Normalize every "#<digits>" expression id to "#x" before matching the expected text.
    val output = captured.toString().replaceAll("#\\d+", "#x")
    assert(output.contains(
      """== InMemoryTableScan [id#xL] ==
        |Tuples output: 0
        | id LongType: {}
        |""".stripMargin))
  }
}

// Runs the tests inherited from DebuggingSuiteBase with EnableAdaptiveExecutionSuite mixed in
// (presumably turning AQE on for them, mirroring DebuggingSuite's use of the Disable variant).
class DebuggingSuiteAE extends DebuggingSuiteBase with EnableAdaptiveExecutionSuite

0 comments on commit 42f59ca

Please sign in to comment.