Skip to content

Commit

Permalink
Reuse subquery
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Aug 8, 2016
1 parent 03d46aa commit 1348ba7
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ case class PredicateSubquery(
override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
query.sameResult(p.query) && nullAware == p.nullAware &&
children.length == p.children.length &&
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
}
override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

if (innerChildren.nonEmpty) {
innerChildren.init.foreach(_.generateTreeString(
depth + 2, lastChildren :+ false :+ false, builder, verbose))
depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
innerChildren.last.generateTreeString(
depth + 2, lastChildren :+ false :+ true, builder, verbose)
depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
}

if (children.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf))
ReuseExchange(sparkSession.sessionState.conf),
ReuseSubquery(sparkSession.sessionState.conf))

protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

/**
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries are added to [[subqueryResults]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
subqueryResults += e -> futureResult
val allSubqueries = expressions.flatMap(_.collect { case e: ExecSubqueryExpression => e })
allSubqueries.foreach {
case e: ExecSubqueryExpression =>
e.plan.prepare()
runningSubqueries += e
}
}

Expand All @@ -165,21 +161,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*/
protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
runningSubqueries.foreach { sub =>
sub.updateResult(sub.plan.executeCollect())
}
subqueryResults.clear()
runningSubqueries.clear()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@

package org.apache.spark.sql.execution

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.SparkException
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/** Physical plan for Project. */
Expand Down Expand Up @@ -502,15 +508,64 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa

/**
* Physical plan for a subquery.
*
* This is used to generate tree string for SparkScalarSubquery.
*/
case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {

override private[sql] lazy val metrics = Map(
"dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
"collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def sameResult(o: SparkPlan): Boolean = o match {
case s: SubqueryExec => child.sameResult(s.child)
case _ => false
}

@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = child.executeCollect()
val beforeBuild = System.nanoTime()
longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
longMetric("dataSize") += dataSize

// There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
// directly without setting an execution id. We should be tolerant to it.
if (executionId != null) {
sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
}

rows
}
}(SubqueryExec.executionContext)
}

protected override def doPrepare(): Unit = {
relationFuture
}

protected override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException
child.execute()
}

override def executeCollect(): Array[InternalRow] = {
ThreadUtils.awaitResult(relationFuture, Duration.Inf)
}
}

object SubqueryExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}
143 changes: 127 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,77 @@

package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

/**
* The base class for subquery that is used in SparkPlan.
*/
trait ExecSubqueryExpression extends SubqueryExpression {

val executedPlan: SubqueryExec
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression

// does not have logical plan
override def query: LogicalPlan = throw new UnsupportedOperationException
override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
throw new UnsupportedOperationException

override def plan: SparkPlan = executedPlan

/**
* Fill the expression with result from subquery.
*/
def updateResult(rows: Array[InternalRow]): Unit
}

/**
* A subquery that will return only one row and one column.
*
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
*/
case class ScalarSubquery(
executedPlan: SparkPlan,
executedPlan: SubqueryExec,
exprId: ExprId)
extends SubqueryExpression {

override def query: LogicalPlan = throw new UnsupportedOperationException
override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
throw new UnsupportedOperationException
}
override def plan: SparkPlan = SubqueryExec(simpleString, executedPlan)
extends ExecSubqueryExpression {

override def dataType: DataType = executedPlan.schema.fields.head.dataType
override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
override def toString: String = s"subquery#${exprId.id}"
override def toString: String = executedPlan.simpleString

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)

override def semanticEquals(other: Expression): Boolean = other match {
case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
case _ => false
}

// the first column in first row from `query`.
@volatile private var result: Any = null
@volatile private var updated: Boolean = false

def updateResult(v: Any): Unit = {
result = v
def updateResult(rows: Array[InternalRow]): Unit = {
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
result = rows(0).get(0, dataType)
} else {
// If there is no rows returned, the result should be null.
result = null
}
updated = true
}

Expand All @@ -67,6 +102,50 @@ case class ScalarSubquery(
}
}

/**
* A subquery that will check the value of `child` whether is in the result of a query or not.
*/
case class InSubquery(
child: Expression,
executedPlan: SubqueryExec,
exprId: ExprId,
private var result: Array[Any] = null,
private var updated: Boolean = false) extends ExecSubqueryExpression {

override def dataType: DataType = BooleanType
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = child.nullable
override def toString: String = s"$child IN ${executedPlan.name}"

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)

override def semanticEquals(other: Expression): Boolean = other match {
case in: InSubquery => child.semanticEquals(in.child) &&
executedPlan.sameResult(in.executedPlan)
case _ => false
}

def updateResult(rows: Array[InternalRow]): Unit = {
result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
updated = true
}

override def eval(input: InternalRow): Any = {
require(updated, s"$this has not finished")
val v = child.eval(input)
if (v == null) {
null
} else {
result.contains(v)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
require(updated, s"$this has not finished")
InSet(child, result.toSet).doGenCode(ctx, ev)
}
}

/**
* Plans scalar subqueries from that are present in the given [[SparkPlan]].
*/
Expand All @@ -75,7 +154,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
}


/**
* Find out duplicated exchanges in the spark plan, then use the same exchange for all the
* references.
*/
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

def apply(plan: SparkPlan): SparkPlan = {
if (!conf.exchangeReuseEnabled) {
return plan
}
// Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
plan transformAllExpressions {
case sub: ExecSubqueryExpression =>
val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
val sameResult = sameSchema.find(_.sameResult(sub.plan))
if (sameResult.isDefined) {
sub.withExecutedPlan(sameResult.get)
} else {
sameSchema += sub.executedPlan
sub
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ private[sql] object SparkPlanGraph {
case "Subquery" if subgraph != null =>
// Subquery should not be included in WholeStageCodegen
buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "ReusedExchange" =>
case "Subquery" if exchanges.contains(planInfo) =>
// Point to the re-used subquery
val node = exchanges(planInfo)
edges += SparkPlanGraphEdge(node.id, parent.id)
case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
// Point to the re-used exchange
val node = exchanges(planInfo.children.head)
edges += SparkPlanGraphEdge(node.id, parent.id)
Expand All @@ -115,7 +119,7 @@ private[sql] object SparkPlanGraph {
} else {
subgraph.nodes += node
}
if (name.contains("Exchange")) {
if (name.contains("Exchange") || name == "Subquery") {
exchanges += planInfo -> node
}

Expand Down

0 comments on commit 1348ba7

Please sign in to comment.