[SPARK-32945][SQL] Avoid collapsing projects if reaching max allowed common exprs #29950

Closed
wants to merge 12 commits
@@ -214,7 +214,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Check Cartesian Products", Once,
CheckCartesianProducts) :+
Batch("RewriteSubquery", Once,
// `CollapseProject` cannot collapse all projects in one pass, so we need `fixedPoint` here.
Batch("RewriteSubquery", fixedPoint,
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
@@ -724,20 +725,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
/**
* Combines two [[Project]] operators into one and performs alias substitution,
* merging the expressions into one single expression for the following cases.
* 1. When two [[Project]] operators are adjacent.
* 1. When two [[Project]] operators are adjacent, if the number of common expressions in the
* combined [[Project]] is not more than `spark.sql.optimizer.maxCommonExprsInCollapseProject`.
* 2. When two [[Project]] operators have LocalLimit/Sample/Repartition operator between them
* and the upper project consists of the same number of columns which is equal or aliasing.
* `GlobalLimit(LocalLimit)` pattern is also considered.
*/
object CollapseProject extends Rule[LogicalPlan] with AliasHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
Review comment (Member): Is there a reason to change from transformUp to transformDown? If all the tests pass, it would be safer to keep the original one.

Review comment (Member): I found the previous comment about supporting withColumn. If this is designed for that, shall we add a test case for it?
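A hypothetical illustration of the withColumn point above (not taken from this PR; it assumes a SparkSession named spark): each withColumn call wraps the plan in another Project, which is exactly the nesting this rule is expected to collapse.

import org.apache.spark.sql.functions.col

// Each withColumn adds a Project on top of the previous plan, so after analysis the
// logical plan is roughly Project -> Project -> Project -> Range.
val df = spark.range(1)
  .withColumn("a", col("id") + 1)
  .withColumn("b", col("a") + 1)
  .withColumn("c", col("b") + 1)
// With transformDown plus the recursive collapseProjects helper added below, the rule
// can collapse the whole chain in a single pass while re-checking the common-expression
// limit at each step.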

case p @ Project(_, _: Project) =>
collapseProjects(p)
case p @ Project(_, agg: Aggregate) =>
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
@@ -758,6 +756,42 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
}

private def collapseProjects(plan: LogicalPlan): LogicalPlan = plan match {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
moreThanMaxAllowedCommonOutput(p1.projectList, p2.projectList)) {
p1
} else {
collapseProjects(
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)))
}
case _ => plan
}

private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
AttributeMap(projectList.collect {
case a: Alias => a.toAttribute -> a
})
}
Review comment (Member Author): We could extend this to other cases like case p @ Project(_, agg: Aggregate), but leave it untouched for now.
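A rough sketch of the extension mentioned above, purely illustrative and not part of this change; it assumes the existing Project-over-Aggregate case in apply could reuse the same moreThanMaxAllowedCommonOutput helper:

// Illustrative sketch only: also skip collapsing when aggregate aliases would be
// duplicated in the upper Project more times than the configured limit.
case p @ Project(_, agg: Aggregate) =>
  if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) ||
      moreThanMaxAllowedCommonOutput(p.projectList, agg.aggregateExpressions)) {
    p
  } else {
    agg.copy(aggregateExpressions = buildCleanedProjectList(
      p.projectList, agg.aggregateExpressions))
  }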


// Whether the largest times common outputs from lower operator used in upper operators is
// larger than allowed.

Review comment (Member): upper operators -> upper operator?

Review comment (Member): than allowed -> than the maximum?
private def moreThanMaxAllowedCommonOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
val aliases = collectAliases(lower)
val exprMap = mutable.HashMap.empty[Attribute, Int]

upper.foreach(_.collect {
case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
})

if (exprMap.nonEmpty) {
exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
} else {
false
}
}
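// Illustrative note (not part of this PR): with upper = [x * 2 AS y, x * 3 AS z] and
// lower = [a + b AS x], the alias map above contains x -> (a + b) and x is referenced
// twice in upper, so this check returns true whenever the configured limit is below 2.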

private def haveCommonNonDeterministicOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
val aliases = getAliasMap(lower)
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.planning

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
@@ -124,14 +126,32 @@ object ScanOperation extends OperationHelper with PredicateHelper {
}.exists(!_.deterministic))
}

private def moreThanMaxAllowedCommonOutput(
expr: Seq[NamedExpression],
aliases: AttributeMap[Expression]): Boolean = {
val exprMap = mutable.HashMap.empty[Attribute, Int]

expr.foreach(_.collect {
case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
})

if (exprMap.nonEmpty) {
exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
} else {
false
}
}

private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
plan match {
case Project(fields, child) =>
collectProjectsAndFilters(child) match {
case Some((_, filters, other, aliases)) =>
// Follow CollapseProject and only keep going if the collected Projects
// do not have common non-deterministic expressions.
if (!hasCommonNonDeterministic(fields, aliases)) {
// do not have common non-deterministic expressions, and do not have more than
// maximum allowed common outputs.
if (!hasCommonNonDeterministic(fields, aliases) &&
!moreThanMaxAllowedCommonOutput(fields, aliases)) {
val substitutedFields =
fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
@@ -1963,6 +1963,27 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT =
buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject")
Review comment (Member): If we set this value to 1, can all the existing tests pass?

Review comment (Member Author): I guess not. We might have at least a few common expressions in a collapsed projection. If it is set to 1, no duplicated expression is allowed at all.
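For illustration, a hypothetical plan built with the catalyst test DSL (not taken from this PR) showing the case the reply above describes, where a single reused alias already counts as a duplicate:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// 'x is an alias produced by the lower Project and is referenced twice in the upper
// Project. With the limit set to 1 the collapse is skipped and both Projects are kept;
// with a limit of 2 or more (including the default Int.MaxValue) they are merged and
// 'a + 'b is duplicated into both "y" and "z".
val relation = LocalRelation('a.int, 'b.int)
val plan = relation
  .select(('a + 'b).as("x"))
  .select(('x * 2).as("y"), ('x * 3).as("z"))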

.doc("An integer number indicating the maximum allowed number of common input " +
"expressions from the lower Project when it is collapsed into the upper Project by " +
"the optimizer rule `CollapseProject`. Normally `CollapseProject` collapses adjacent " +
"Projects and merges expressions. But in some edge cases, expensive expressions might " +
"be duplicated many times in the merged Project by this optimization. This config " +
"sets a maximum number: once an expression would be duplicated more than this number " +
"of times when merging two Projects, Spark SQL skips the merging. Note that normally " +
"in whole-stage codegen the Project operator de-duplicates expressions internally, " +
"but in edge cases Spark cannot do whole-stage codegen and falls back to interpreted " +
"mode. In such cases, users can use this config to avoid duplicate expressions. " +
"Note that even if users exclude the `CollapseProject` rule via " +
"`spark.sql.optimizer.excludedRules`, Spark still collapses projections during " +
"physical planning. This config is also effective on that collapsing in the " +
"physical planning phase.")

Review comment (@maropu, Oct 21, 2020): (Just a comment) Even if we set spark.sql.optimizer.excludedRules to CollapseProject, it seems like Spark still respects this value in ScanOperation? That behaviour might be okay, but it looks a bit weird to me.

Review comment (Member Author): Yeah, but currently if we exclude CollapseProject, ScanOperation will still collapse projections. Maybe update this doc?

Review comment (Member): Hm, I see. Yeah, updating the doc sounds good to me.
.version("3.1.0")
.intConf
.checkValue(_ > 0, "The value of maxCommonExprsInCollapseProject must be larger than zero.")
Review comment (Member): larger than zero -> positive.

.createWithDefault(Int.MaxValue)
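A minimal usage sketch for the new config (assumptions: a SparkSession named spark and the test_table created in the DataFrameSuite test further below; not part of the diff):

import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types.StructType

// Cap how many times an alias from the lower Project may be referenced in the upper
// Project before the optimizer skips collapsing the two Projects.
spark.conf.set("spark.sql.optimizer.maxCommonExprsInCollapseProject", "2")

val schema = StructType.fromDDL("a int, b int, c long, d string")
// The struct alias is referenced four times in the upper Project, which exceeds the
// limit of 2, so the two Projects stay separate and from_json is evaluated only once.
val df = spark.read.table("test_table")
  .select(from_json(col("json"), schema).as("struct"))
  .select(col("struct").getField("a"), col("struct").getField("b"),
    col("struct").getField("c"), col("struct").getField("d"))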

val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
buildConf("spark.sql.decimalOperations.allowPrecisionLoss")
.internal()
Expand Down Expand Up @@ -3405,6 +3426,8 @@ class SQLConf extends Serializable with Logging {

def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)

def maxCommonExprsInCollapseProject: Int = getConf(MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT)

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{MetadataBuilder, StructType}

class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -170,4 +171,59 @@ class CollapseProjectSuite extends PlanTest {
val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
comparePlans(optimized, expected)
}

test("SPARK-32945: avoid collapsing projects if reaching max allowed common exprs") {
val options = Map.empty[String, String]
val schema = StructType.fromDDL("a int, b int, c string, d long")

Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
// If we collapse two Projects, `JsonToStructs` will be repeated three times.
val relation = LocalRelation('json.string)
val query1 = relation.select(
JsonToStructs(schema, options, 'json).as("struct"))
Review comment (@dongjoon-hyun, Nov 13, 2020): Indentation? Maybe the following is better?

- val query1 = relation.select(
-   JsonToStructs(schema, options, 'json).as("struct"))
-   .select(
+ val query1 = relation.select(JsonToStructs(schema, options, 'json).as("struct"))
+   .select(
.select(
GetStructField('struct, 0).as("a"),
GetStructField('struct, 1).as("b"),
GetStructField('struct, 2).as("c"),
GetStructField('struct, 3).as("d")).analyze
val optimized1 = Optimize.execute(query1)

val query2 = relation
.select('json, JsonToStructs(schema, options, 'json).as("struct"))
.select('json, 'struct, GetStructField('struct, 0).as("a"))
.select('json, 'struct, 'a, GetStructField('struct, 1).as("b"))
.select('json, 'struct, 'a, 'b, GetStructField('struct, 2).as("c"))
.analyze
val optimized2 = Optimize.execute(query2)

if (maxCommonExprs.toInt < 4) {
val expected1 = query1
comparePlans(optimized1, expected1)

val expected2 = relation
.select('json, JsonToStructs(schema, options, 'json).as("struct"))
.select('json, 'struct,
GetStructField('struct, 0).as("a"),
GetStructField('struct, 1).as("b"),
GetStructField('struct, 2).as("c"))
.analyze
comparePlans(optimized2, expected2)
} else {
val expected1 = relation.select(
GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
GetStructField(JsonToStructs(schema, options, 'json), 2).as("c"),
GetStructField(JsonToStructs(schema, options, 'json), 3).as("d")).analyze
comparePlans(optimized1, expected1)

val expected2 = relation.select('json, JsonToStructs(schema, options, 'json).as("struct"),
GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
comparePlans(optimized2, expected2)
}
}
}
}
}
50 changes: 48 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,12 +32,13 @@ import org.scalatest.matchers.should.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2567,6 +2568,51 @@ class DataFrameSuite extends QueryTest
val df = l.join(r, $"col2" === $"col4", "LeftOuter")
checkAnswer(df, Row("2", "2"))
}

test("SPARK-32945: Avoid collapsing projects if reaching max allowed common exprs") {
val options = Map.empty[String, String]
val schema = StructType.fromDDL("a int, b int, c long, d string")

withTable("test_table") {
val jsonDf = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""").toDF("json")
jsonDf.write.saveAsTable("test_table")

Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {

val jsonDf = spark.read.table("test_table")
val jsonStruct = UnresolvedAttribute("struct")
val df = jsonDf
.select(from_json('json, schema, options).as("struct"))
.select(
Column(GetStructField(jsonStruct, 0)).as("a"),
Column(GetStructField(jsonStruct, 1)).as("b"),
Column(GetStructField(jsonStruct, 2)).as("c"),
Column(GetStructField(jsonStruct, 3)).as("d"))

val numProjects = df.queryExecution.executedPlan.collect {
case p: ProjectExec => p
}.size

val numFromJson = df.queryExecution.executedPlan.collect {
case p: ProjectExec => p.projectList.flatMap(_.collect {
case j: JsonToStructs => j
})
}.flatten.size

if (maxCommonExprs.toInt < 4) {
assert(numProjects == 2)
assert(numFromJson == 1)
} else {
assert(numProjects == 1)
assert(numFromJson == 4)
}

checkAnswer(df, Row(1, 2, 123L, "test"))
}
}
}
}
}

case class GroupByKey(a: Int, b: Int)