[SPARK-16958] [SQL] Reuse subqueries within the same query #14548

Closed · wants to merge 4 commits
@@ -102,6 +102,13 @@ case class PredicateSubquery(
override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
query.sameResult(p.query) && nullAware == p.nullAware &&
children.length == p.children.length &&
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
}
override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
}

@@ -538,9 +538,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

if (innerChildren.nonEmpty) {
innerChildren.init.foreach(_.generateTreeString(
depth + 2, lastChildren :+ false :+ false, builder, verbose))
depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
innerChildren.last.generateTreeString(
depth + 2, lastChildren :+ false :+ true, builder, verbose)
depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
}

if (children.nonEmpty) {
@@ -101,7 +101,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf))
ReuseExchange(sparkSession.sessionState.conf),
ReuseSubquery(sparkSession.sessionState.conf))

protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -142,21 +142,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

/**
* Finds subquery expressions in this plan node and starts evaluating them.
* The subqueries found are added to [[runningSubqueries]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
subqueryResults += e -> futureResult
val allSubqueries = expressions.flatMap(_.collect { case e: ExecSubqueryExpression => e })
Review comment (Contributor): Why not combine this with the foreach below? (A sketch of that combination follows this method.)

allSubqueries.foreach {
case e: ExecSubqueryExpression =>
e.plan.prepare()
runningSubqueries += e
}
}
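
A minimal sketch of the combination the reviewer asks about, assuming the same `expressions` and `runningSubqueries` members shown above; the single traversal avoids building the intermediate `allSubqueries` sequence:

protected def prepareSubqueries(): Unit = {
  expressions.foreach(_.foreach {
    case e: ExecSubqueryExpression =>
      e.plan.prepare()
      runningSubqueries += e
    case _ => // ignore all other expressions
  })
}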

@@ -165,21 +161,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*/
protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
runningSubqueries.foreach { sub =>
sub.updateResult(sub.plan.executeCollect())
Review comment (Contributor): The subplan is SubqueryExec right? Why not create a method in ExecSubqueryExpression for this? (See the sketch after this method.)

}
subqueryResults.clear()
runningSubqueries.clear()
}
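
A minimal sketch of the reviewer's suggestion; the no-argument `updateResult()` is an assumed name, not part of this PR, and it lets each expression pull its own result from its `SubqueryExec` plan:

trait ExecSubqueryExpression extends SubqueryExpression {
  def plan: SubqueryExec
  // Assumed helper: collect the subquery result and store it in the expression,
  // e.g. implemented on top of plan.executeCollect().
  def updateResult(): Unit
}

protected def waitForSubqueries(): Unit = synchronized {
  runningSubqueries.foreach(_.updateResult())
  runningSubqueries.clear()
}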

/**
@@ -17,13 +17,19 @@

package org.apache.spark.sql.execution

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.SparkException
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/** Physical plan for Project. */
@@ -502,15 +508,64 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa

/**
* Physical plan for a subquery.
*
* This is used to generate tree string for SparkScalarSubquery.
*/
case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
Review comment (Contributor): A large part of this class is shared with BroadcastExchangeExec. Should we try to factor out common functionality? (A rough sketch appears at the end of this file's diff.)
Reply (Author): I think it's OK to have some duplicated code here; over-abstracted code is actually harder to read.


override private[sql] lazy val metrics = Map(
"dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
"collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def sameResult(o: SparkPlan): Boolean = o match {
case s: SubqueryExec => child.sameResult(s.child)
case _ => false
}

@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = child.executeCollect()
val beforeBuild = System.nanoTime()
longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
longMetric("dataSize") += dataSize

// There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
// directly without setting an execution id. We should be tolerant to it.
if (executionId != null) {
sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
}

rows
}
}(SubqueryExec.executionContext)
}

protected override def doPrepare(): Unit = {
relationFuture
}

protected override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException
child.execute()
}

override def executeCollect(): Array[InternalRow] = {
ThreadUtils.awaitResult(relationFuture, Duration.Inf)
}
}

object SubqueryExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}
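
For reference, a rough, illustrative sketch of the factoring-out discussed in the review comment above (the trait and member names are assumptions, not Spark APIs, and the author ultimately kept the duplication); both SubqueryExec and BroadcastExchangeExec collect their child on a separate thread, time the collect, and record a size metric:

trait AsyncCollectExec { self: SparkPlan =>
  // Execution context on which the asynchronous collect runs (assumed member).
  protected def collectionContext: ExecutionContext

  @transient protected lazy val collectedRows: Future[Array[InternalRow]] = Future {
    val start = System.nanoTime()
    val rows = children.head.executeCollect()
    longMetric("collectTime") += (System.nanoTime() - start) / 1000000
    longMetric("dataSize") += rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
    rows
  }(collectionContext)
}
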
143 changes: 127 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,42 +17,77 @@

package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

/**
* The base class for subquery that is used in SparkPlan.
*/
trait ExecSubqueryExpression extends SubqueryExpression {

val executedPlan: SubqueryExec
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression

// does not have logical plan
override def query: LogicalPlan = throw new UnsupportedOperationException
Review comment (hvanhovell, Aug 10, 2016): Maybe we should remove this from the interface? If we are going to use a type parameter anyway (see the comment on withNewPlan).

override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
Review comment (Contributor): Should we try to combine this with withExecutedPlan by introducing a type parameter? (An illustrative sketch appears after this trait.)
Reply (Author): Tried this multiple times; it does not work that well. We could try again later.
Reply (Contributor): Ok, let's address this in a (potential) follow-up.

throw new UnsupportedOperationException

override def plan: SparkPlan = executedPlan

/**
* Fill the expression with result from subquery.
*/
def updateResult(rows: Array[InternalRow]): Unit
}
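
An illustrative sketch of the type-parameter idea from the thread above; `PlanExpression` and its signature are assumptions for the discussion, not part of this PR (it would also need `org.apache.spark.sql.catalyst.plans.QueryPlan` imported):

trait PlanExpression[T <: QueryPlan[_]] extends Expression {
  def plan: T
  def withNewPlan(plan: T): PlanExpression[T]
}

// The execution-side trait then only adds what is specific to physical subqueries.
trait ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
  def updateResult(rows: Array[InternalRow]): Unit
}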

/**
* A subquery that will return only one row and one column.
*
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
*/
case class ScalarSubquery(
executedPlan: SparkPlan,
executedPlan: SubqueryExec,
exprId: ExprId)
extends SubqueryExpression {

override def query: LogicalPlan = throw new UnsupportedOperationException
override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
throw new UnsupportedOperationException
}
override def plan: SparkPlan = SubqueryExec(simpleString, executedPlan)
extends ExecSubqueryExpression {

override def dataType: DataType = executedPlan.schema.fields.head.dataType
override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
override def toString: String = s"subquery#${exprId.id}"
override def toString: String = executedPlan.simpleString

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)

override def semanticEquals(other: Expression): Boolean = other match {
case s: ScalarSubquery => executedPlan.sameResult(s.executedPlan)
case _ => false
}

// the first column in first row from `query`.
@volatile private var result: Any = null
@volatile private var updated: Boolean = false

def updateResult(v: Any): Unit = {
result = v
def updateResult(rows: Array[InternalRow]): Unit = {
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
result = rows(0).get(0, dataType)
} else {
// If no rows are returned, the result should be null.
result = null
}
updated = true
}

@@ -67,6 +102,50 @@ case class ScalarSubquery(
}
}

/**
* A subquery that checks whether the value of `child` is in the result of a query.
*/
case class InSubquery(
child: Expression,
executedPlan: SubqueryExec,
exprId: ExprId,
private var result: Array[Any] = null,
private var updated: Boolean = false) extends ExecSubqueryExpression {

override def dataType: DataType = BooleanType
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = child.nullable
override def toString: String = s"$child IN ${executedPlan.name}"

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)

override def semanticEquals(other: Expression): Boolean = other match {
case in: InSubquery => child.semanticEquals(in.child) &&
executedPlan.sameResult(in.executedPlan)
case _ => false
}

def updateResult(rows: Array[InternalRow]): Unit = {
result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
updated = true
}

override def eval(input: InternalRow): Any = {
require(updated, s"$this has not finished")
val v = child.eval(input)
if (v == null) {
null
} else {
result.contains(v)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
require(updated, s"$this has not finished")
InSet(child, result.toSet).doGenCode(ctx, ev)
}
}

/**
* Plans subqueries that are present in the given [[SparkPlan]].
*/
@@ -75,7 +154,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
Review comment (Contributor): Shouldn't we add some size check here? This might as well materialize a billion rows. (A rough sketch of such a check appears after this rule.)
Review comment (Contributor): Under what circumstance is this triggered? A predicate subquery in Project?
Reply (Author): All the PredicateSubquery expressions from the original SQL query are rewritten as joins; this case could come from other optimization rules, and those rules should make sure there are not billions of rows.

val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
}
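
A minimal sketch of the size check raised in the review thread above, written as a variant of InSubquery.updateResult; the threshold is an assumption (it could be backed by a SQLConf entry), not part of this PR:

def updateResult(rows: Array[InternalRow]): Unit = {
  val maxRows = 100000 // hypothetical limit, e.g. read from SQLConf
  if (rows.length > maxRows) {
    sys.error(s"IN subquery returned ${rows.length} rows, more than the allowed $maxRows")
  }
  result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
  updated = true
}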


/**
* Find duplicated subqueries in the spark plan, then reuse the same subquery plan for all the
* references.
*/
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

def apply(plan: SparkPlan): SparkPlan = {
if (!conf.exchangeReuseEnabled) {
Review comment (Contributor): Nit: we could also add this to the planner conditionally.
Reply (Author): Moving this to QueryExecution actually will make it ugly :(
Reply (Contributor): Ok, let's leave it then.
Review comment (Member): Why share the same conf spark.sql.exchange.reuse with ReuseExchange?

return plan
}
// Build a hash map using the schema of subqueries to avoid O(N*N) sameResult calls.
val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
plan transformAllExpressions {
case sub: ExecSubqueryExpression =>
val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
val sameResult = sameSchema.find(_.sameResult(sub.plan))
if (sameResult.isDefined) {
sub.withExecutedPlan(sameResult.get)
} else {
sameSchema += sub.executedPlan
sub
}
}
}
}
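
Illustrative usage only (the session, table, and column names are assumptions): with ReuseSubquery in the preparation rules, the two identical scalar subqueries below should be planned as a single SubqueryExec that is executed once and referenced twice.

val df = spark.sql(
  """
    |SELECT (SELECT max(id) FROM t) AS a,
    |       (SELECT max(id) FROM t) AS b
    |FROM t
  """.stripMargin)
// The physical plan should show one subquery node reused by both expressions.
df.explain()
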
@@ -99,7 +99,11 @@ private[sql] object SparkPlanGraph {
case "Subquery" if subgraph != null =>
// Subquery should not be included in WholeStageCodegen
buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "ReusedExchange" =>
case "Subquery" if exchanges.contains(planInfo) =>
Review comment (Contributor): Could you post a screenshot of this?

// Point to the re-used subquery
val node = exchanges(planInfo)
edges += SparkPlanGraphEdge(node.id, parent.id)
case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
// Point to the re-used exchange
val node = exchanges(planInfo.children.head)
edges += SparkPlanGraphEdge(node.id, parent.id)
Expand All @@ -115,7 +119,7 @@ private[sql] object SparkPlanGraph {
} else {
subgraph.nodes += node
}
if (name.contains("Exchange")) {
if (name.contains("Exchange") || name == "Subquery") {
exchanges += planInfo -> node
}
