From 1348ba72d9562714ad020483ff2e067c05114957 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 8 Aug 2016 15:24:07 -0700
Subject: [PATCH 1/3] Reuse subquery

---
 .../sql/catalyst/expressions/subquery.scala   |   7 +
 .../spark/sql/catalyst/trees/TreeNode.scala   |   4 +-
 .../spark/sql/execution/QueryExecution.scala  |   3 +-
 .../spark/sql/execution/SparkPlan.scala       |  33 ++--
 .../execution/basicPhysicalOperators.scala    |  63 +++++++-
 .../apache/spark/sql/execution/subquery.scala | 143 ++++++++++++++++--
 .../sql/execution/ui/SparkPlanGraph.scala     |   8 +-
 7 files changed, 212 insertions(+), 49 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 08cb6c0134e3a..ac44f08897cbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -102,6 +102,13 @@ case class PredicateSubquery(
   override def nullable: Boolean = nullAware
   override def plan: LogicalPlan = SubqueryAlias(toString, query)
   override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+  override def semanticEquals(o: Expression): Boolean = o match {
+    case p: PredicateSubquery =>
+      query.sameResult(p.query) && nullAware == p.nullAware &&
+        children.length == p.children.length &&
+        children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
+    case _ => false
+  }
   override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 8bce404735785..24a2dc9d3b35f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -538,9 +538,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (innerChildren.nonEmpty) {
       innerChildren.init.foreach(_.generateTreeString(
-        depth + 2, lastChildren :+ false :+ false, builder, verbose))
+        depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
       innerChildren.last.generateTreeString(
-        depth + 2, lastChildren :+ false :+ true, builder, verbose)
+        depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
     }

     if (children.nonEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5b9af26dfc4f8..d4845637be049 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -101,7 +101,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
-    ReuseExchange(sparkSession.sessionState.conf))
+    ReuseExchange(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf))

   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 045ccc7bd6eae..9dcb6aefedf10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -142,21 +142,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
    */
   @transient
-  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
+  private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

   /**
    * Finds scalar subquery expressions in this plan node and starts evaluating them.
-   * The list of subqueries are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
-    val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
-    allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
-      val futureResult = Future {
-        // Each subquery should return only one row (and one column). We take two here and throws
-        // an exception later if the number of rows is greater than one.
-        e.executedPlan.executeTake(2)
-      }(SparkPlan.subqueryExecutionContext)
-      subqueryResults += e -> futureResult
+    val allSubqueries = expressions.flatMap(_.collect { case e: ExecSubqueryExpression => e })
+    allSubqueries.foreach {
+      case e: ExecSubqueryExpression =>
+        e.plan.prepare()
+        runningSubqueries += e
     }
   }
@@ -165,21 +161,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   protected def waitForSubqueries(): Unit = synchronized {
     // fill in the result of subqueries
-    subqueryResults.foreach { case (e, futureResult) =>
-      val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
-      if (rows.length > 1) {
-        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-      }
-      if (rows.length == 1) {
-        assert(rows(0).numFields == 1,
-          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
-        e.updateResult(rows(0).get(0, e.dataType))
-      } else {
-        // If there is no rows returned, the result should be null.
-        e.updateResult(null)
-      }
+    runningSubqueries.foreach { sub =>
+      sub.updateResult(sub.plan.executeCollect())
     }
-    subqueryResults.clear()
+    runningSubqueries.clear()
   }

   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 185c79f899e68..31f3904e0cb8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,13 +17,19 @@ package org.apache.spark.sql.execution
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
+import org.apache.spark.sql.types.LongType
+import org.apache.spark.util.ThreadUtils
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

 /** Physical plan for Project. */
@@ -502,15 +508,64 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa
 /**
  * Physical plan for a subquery.
- *
- * This is used to generate tree string for SparkScalarSubquery.
  */
 case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
+
+  override private[sql] lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def sameResult(o: SparkPlan): Boolean = o match {
+    case s: SubqueryExec => child.sameResult(s.child)
+    case _ => false
+  }
+
+  @transient
+  private lazy val relationFuture: Future[Array[InternalRow]] = {
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    Future {
+      // This will run in another thread. Set the execution id so that we can connect these jobs
+      // with the correct execution.
+      SQLExecution.withExecutionId(sparkContext, executionId) {
+        val beforeCollect = System.nanoTime()
+        // Note that we use .executeCollect() because we don't want to convert data to Scala types
+        val rows: Array[InternalRow] = child.executeCollect()
+        val beforeBuild = System.nanoTime()
+        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+        val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+
+        // There are some cases where we don't care about the metrics and call `SparkPlan.doExecute`
+        // directly without setting an execution id. We should be tolerant of that.
+        if (executionId != null) {
+          sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
+            executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
+        }
+
+        rows
+      }
+    }(SubqueryExec.executionContext)
+  }
+
+  protected override def doPrepare(): Unit = {
+    relationFuture
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException
+    child.execute()
   }
+
+  override def executeCollect(): Array[InternalRow] = {
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  }
+}
+
+object SubqueryExec {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 461d3010ada7e..cc9e347a8e9cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,14 +17,38 @@ package org.apache.spark.sql.execution
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
+
+/**
+ * The base class for subquery expressions that are used in a SparkPlan.
+ */
+trait ExecSubqueryExpression extends SubqueryExpression {
+
+  val executedPlan: SubqueryExec
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
+
+  // does not have a logical plan
+  override def query: LogicalPlan = throw new UnsupportedOperationException
+  override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
+    throw new UnsupportedOperationException
+
+  override def plan: SparkPlan = executedPlan
+
+  /**
+   * Fill the expression with result from subquery.
+   */
+  def updateResult(rows: Array[InternalRow]): Unit
+}

 /**
  * A subquery that will return only one row and one column.
@@ -32,27 +56,38 @@ import org.apache.spark.sql.types.DataType
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
 case class ScalarSubquery(
-    executedPlan: SparkPlan,
+    executedPlan: SubqueryExec,
     exprId: ExprId)
-  extends SubqueryExpression {
-
-  override def query: LogicalPlan = throw new UnsupportedOperationException
-  override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
-    throw new UnsupportedOperationException
-  }
-  override def plan: SparkPlan = SubqueryExec(simpleString, executedPlan)
+  extends ExecSubqueryExpression {

   override def dataType: DataType = executedPlan.schema.fields.head.dataType
   override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
-  override def toString: String = s"subquery#${exprId.id}"
+  override def toString: String = executedPlan.simpleString
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case s: ScalarSubquery => executedPlan.sameResult(s.executedPlan)
+    case _ => false
+  }

   // the first column in first row from `query`.
   @volatile private var result: Any = null
   @volatile private var updated: Boolean = false

-  def updateResult(v: Any): Unit = {
-    result = v
+  def updateResult(rows: Array[InternalRow]): Unit = {
+    if (rows.length > 1) {
+      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+    }
+    if (rows.length == 1) {
+      assert(rows(0).numFields == 1,
+        s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+      result = rows(0).get(0, dataType)
+    } else {
+      // If there is no rows returned, the result should be null.
+      result = null
+    }
     updated = true
   }
@@ -67,6 +102,50 @@ case class ScalarSubquery(
   }
 }
+/**
+ * A subquery that checks whether the value of `child` is in the result of a query.
+ */
+case class InSubquery(
+    child: Expression,
+    executedPlan: SubqueryExec,
+    exprId: ExprId,
+    private var result: Array[Any] = null,
+    private var updated: Boolean = false) extends ExecSubqueryExpression {
+
+  override def dataType: DataType = BooleanType
+  override def children: Seq[Expression] = child :: Nil
+  override def nullable: Boolean = child.nullable
+  override def toString: String = s"$child IN ${executedPlan.name}"
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case in: InSubquery => child.semanticEquals(in.child) &&
+      executedPlan.sameResult(in.executedPlan)
+    case _ => false
+  }
+
+  def updateResult(rows: Array[InternalRow]): Unit = {
+    result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
+    updated = true
+  }
+
+  override def eval(input: InternalRow): Any = {
+    require(updated, s"$this has not finished")
+    val v = child.eval(input)
+    if (v == null) {
+      null
+    } else {
+      result.contains(v)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    require(updated, s"$this has not finished")
+    InSet(child, result.toSet).doGenCode(ctx, ev)
+  }
+}
+
 /**
  * Plans scalar subqueries from that are present in the given [[SparkPlan]].
  */
@@ -75,7 +154,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
-        ScalarSubquery(executedPlan, subquery.exprId)
+        ScalarSubquery(
+          SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
+          subquery.exprId)
+      case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
+        val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
+        InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
     }
   }
 }
+
+
+/**
+ * Find out duplicated subqueries in the spark plan, then use the same subquery for all the
+ * references.
+ */
+case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.exchangeReuseEnabled) {
+      return plan
+    }
+    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
+    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
+    plan transformAllExpressions {
+      case sub: ExecSubqueryExpression =>
+        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
+        val sameResult = sameSchema.find(_.sameResult(sub.plan))
+        if (sameResult.isDefined) {
+          sub.withExecutedPlan(sameResult.get)
+        } else {
+          sameSchema += sub.executedPlan
+          sub
+        }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 8f5681bfc7cc6..6e13bc6eb033a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -99,7 +99,11 @@ private[sql] object SparkPlanGraph {
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
         buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
-      case "ReusedExchange" =>
+      case "Subquery" if exchanges.contains(planInfo) =>
+        // Point to the re-used subquery
+        val node = exchanges(planInfo)
+        edges += SparkPlanGraphEdge(node.id, parent.id)
+      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
         // Point to the re-used exchange
         val node = exchanges(planInfo.children.head)
         edges += SparkPlanGraphEdge(node.id, parent.id)
@@ -115,7 +119,7 @@ private[sql] object SparkPlanGraph {
       } else {
         subgraph.nodes += node
       }
-      if (name.contains("Exchange")) {
+      if (name.contains("Exchange") || name == "Subquery") {
         exchanges += planInfo -> node
       }

From 8444447c95969afd757fbd744617fbb0b883b0f1 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Aug 2016 15:20:27 -0700
Subject: [PATCH 2/3] address comments

---
 .../org/apache/spark/sql/execution/SparkPlan.scala | 13 +++++++------
 .../org/apache/spark/sql/execution/subquery.scala  | 10 ++++++----
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 9dcb6aefedf10..d02f8686e2690 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -148,11 +148,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * Finds scalar subquery expressions in this plan node and starts evaluating them.
    */
   protected def prepareSubqueries(): Unit = {
-    val allSubqueries = expressions.flatMap(_.collect { case e: ExecSubqueryExpression => e })
-    allSubqueries.foreach {
-      case e: ExecSubqueryExpression =>
-        e.plan.prepare()
-        runningSubqueries += e
+    expressions.foreach {
+      _.collect {
+        case e: ExecSubqueryExpression =>
+          e.plan.prepare()
+          runningSubqueries += e
+      }
     }
   }
@@ -162,7 +163,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   protected def waitForSubqueries(): Unit = synchronized {
     // fill in the result of subqueries
     runningSubqueries.foreach { sub =>
-      sub.updateResult(sub.plan.executeCollect())
+      sub.updateResult()
     }
     runningSubqueries.clear()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index cc9e347a8e9cc..c730bee6ae050 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -45,9 +45,9 @@ trait ExecSubqueryExpression extends SubqueryExpression {
   override def plan: SparkPlan = executedPlan

   /**
-   * Fill the expression with result from subquery.
+   * Fill the expression with collected result from executed plan.
    */
-  def updateResult(rows: Array[InternalRow]): Unit
+  def updateResult(): Unit
 }

 /**
@@ -76,7 +76,8 @@ case class ScalarSubquery(
   @volatile private var result: Any = null
   @volatile private var updated: Boolean = false

-  def updateResult(rows: Array[InternalRow]): Unit = {
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
     if (rows.length > 1) {
       sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
     }
     if (rows.length == 1) {
       assert(rows(0).numFields == 1,
         s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
       result = rows(0).get(0, dataType)
     } else {
       // If there is no rows returned, the result should be null.
@@ -125,7 +126,8 @@ case class InSubquery(
     case _ => false
   }

-  def updateResult(rows: Array[InternalRow]): Unit = {
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
     result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
     updated = true
   }

From dd1581be4875c481b413654f8c99c97da769fb03 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Aug 2016 15:45:18 -0700
Subject: [PATCH 3/3] fix conflict

---
 .../org/apache/spark/sql/execution/basicPhysicalOperators.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 975007a0ad538..ad8a71689895b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -511,7 +511,7 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa
  */
 case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {

-  override private[sql] lazy val metrics = Map(
+  override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
     "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))
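
Note (not part of the patches above): a minimal, hypothetical sketch of the kind of query this series targets. It assumes a Spark build that includes these patches, a local SparkSession, and an illustrative temp view named `sales`; the object name and data are made up for the example. With PlanSubqueries wrapping each scalar subquery in a SubqueryExec and ReuseSubquery pointing identical occurrences at the same SubqueryExec, the subquery below should be collected only once even though it appears twice.

import org.apache.spark.sql.SparkSession

object ReuseSubqueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("reuse-subquery-example")
      .getOrCreate()
    import spark.implicits._

    // Illustrative data only; the view name `sales` is not taken from the patches.
    Seq((1, 10.0), (2, 20.0), (3, 30.0)).toDF("id", "amount").createOrReplaceTempView("sales")

    // The uncorrelated scalar subquery (SELECT avg(amount) FROM sales) appears twice.
    // After subquery reuse, both predicates should reference a single subquery plan.
    val df = spark.sql(
      """
        |SELECT id, amount
        |FROM sales
        |WHERE amount > (SELECT avg(amount) FROM sales)
        |   OR amount < (SELECT avg(amount) FROM sales)
      """.stripMargin)

    df.explain()  // the physical plan should show the shared subquery
    df.show()
    spark.stop()
  }
}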