[SPARK-32528][SQL][TEST] The analyze method should make sure the plan is analyzed

### What changes were proposed in this pull request?

This PR updates the test DSL's `analyze` method to make sure the returned plan is fully resolved, by running `checkAnalysis` on it before stripping subquery aliases. It also fixes several miswritten optimizer tests that the new check exposed.

### Why are the changes needed?

It's error-prone if the `analyze` method can return an unresolved plan.
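As a minimal sketch of the failure mode (hypothetical test code, assuming the Catalyst DSL plus the `Optimize` rule executor and `comparePlans` helper that the optimizer suites below use): when both sides of a comparison contain the same unresolved attribute, a test can pass without exercising the rule on a valid plan.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val testRelation = LocalRelation('a.int) // there is no column `b`

// Before this patch, `analyze` returned a plan still containing the
// unresolved attribute 'b. Both sides below are equally broken, so the
// comparison says nothing about the rule under test.
val actual = Optimize.execute(testRelation.select('b).analyze)
val expected = testRelation.select('b).analyze
comparePlans(actual, expected, checkAnalysis = false) // passes vacuously
```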

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Test only.

Closes #29349 from cloud-fan/test.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
cloud-fan committed Aug 7, 2020
1 parent aa4d3c1 commit d5682c1
Showing 10 changed files with 68 additions and 71 deletions.
In `package object dsl`:

```diff
@@ -418,8 +418,11 @@ package object dsl {
     def distribute(exprs: Expression*)(n: Int): LogicalPlan =
       RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
-    def analyze: LogicalPlan =
-      EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
+    def analyze: LogicalPlan = {
+      val analyzed = analysis.SimpleAnalyzer.execute(logicalPlan)
+      analysis.SimpleAnalyzer.checkAnalysis(analyzed)
+      EliminateSubqueryAliases(analyzed)
+    }
 
     def hint(name: String, parameters: Any*): LogicalPlan =
       UnresolvedHint(name, parameters, logicalPlan)
```
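The helper now fails fast: analyze, verify with `checkAnalysis` (which throws `AnalysisException` on unresolved or otherwise invalid plans), and only then strip subquery aliases. A hypothetical REPL-style sketch of the new behavior, with the same DSL imports as above:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val testRelation = LocalRelation('a.int)

testRelation.select('a).analyze // unchanged: returns the resolved plan
testRelation.select('b).analyze // now throws AnalysisException instead of
                                // silently returning an unresolved plan
```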
In `AnalysisSuite`:

```diff
@@ -26,6 +26,7 @@ import org.apache.log4j.Level
 import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -47,6 +48,13 @@ import org.apache.spark.sql.types._
 class AnalysisSuite extends AnalysisTest with Matchers {
   import org.apache.spark.sql.catalyst.analysis.TestRelations._
 
+  test("fail for unresolved plan") {
+    intercept[AnalysisException] {
+      // `testRelation` does not have column `b`.
+      testRelation.select('b).analyze
+    }
+  }
+
   test("union project *") {
     val plan = (1 to 120)
       .map(_ => testRelation)
```
In `BinaryComparisonSimplificationSuite`:

```diff
@@ -119,9 +119,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
 
     // GetStructField with different names are semantically equal; thus, `EqualTo(fieldA1, fieldA2)`
     // will be optimized to `TrueLiteral` by `SimplifyBinaryComparison`.
-    val originalQuery = nonNullableRelation
-      .where(EqualTo(fieldA1, fieldA2))
-      .analyze
+    val originalQuery = nonNullableRelation.where(EqualTo(fieldA1, fieldA2))
 
     val optimized = Optimize.execute(originalQuery)
     val correctAnswer = nonNullableRelation.analyze
```
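Dropping `.analyze` here looks intentional: `fieldA1` and `fieldA2` are hand-built, already-resolved `GetStructField`s whose differing name hints are exactly what the test exercises (see the comment above), so the plan can go to the optimizer as constructed rather than taking another round trip through the now-stricter analyzer. (This reading is inferred from the surrounding comment; the patch does not state it.)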
In `FilterPushdownSuite`:

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, IntegerType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType, TimestampType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class FilterPushdownSuite extends PlanTest {
```
```diff
@@ -678,14 +678,14 @@ class FilterPushdownSuite extends PlanTest {
     val generator = Explode('c_arr)
     val originalQuery = {
       testRelationWithArrayType
-        .generate(generator, alias = Some("arr"))
+        .generate(generator, alias = Some("arr"), outputNames = Seq("c"))
         .where(('b >= 5) && ('c > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val referenceResult = {
       testRelationWithArrayType
         .where('b >= 5)
-        .generate(generator, alias = Some("arr"))
+        .generate(generator, alias = Some("arr"), outputNames = Seq("c"))
         .where('c > 6).analyze
     }
```
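The new `outputNames = Seq("c")` is what lets this plan survive the stricter `analyze`: the filters reference `'c`, so the generator's output column must actually be named `c` rather than carrying a default generated name. The relevant fragment, using the same DSL call as the diff:

```scala
// Naming the generator output explicitly lets the `'c > 6` filter resolve:
testRelationWithArrayType
  .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("c"))
  .where('c > 6)
```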
```diff
@@ -1149,75 +1149,60 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer)
   }
 
-  test("join condition pushdown: deterministic and non-deterministic") {
-    val x = testRelation.subquery('x)
-    val y = testRelation.subquery('y)
-
-    // Verify that all conditions except the watermark touching condition are pushed down
-    // by the optimizer and others are not.
-    val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 &&
-      "x.a".attr === Rand(10) && "y.b".attr === 5))
-    val correctAnswer =
-      x.where("x.a".attr === 5).join(y.where("y.a".attr === 5 && "y.b".attr === 5),
-        condition = Some("x.a".attr === Rand(10)))
-
-    // CheckAnalysis will ensure nondeterministic expressions not appear in join condition.
-    // TODO support nondeterministic expressions in join condition.
-    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
-      checkAnalysis = false)
-  }
-
   test("watermark pushdown: no pushdown on watermark attribute #1") {
     val interval = new CalendarInterval(2, 2, 2000L)
+    val relation = LocalRelation(attrA, 'b.timestamp, attrC)
 
     // Verify that all conditions except the watermark touching condition are pushed down
     // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark('b, interval, testRelation)
-      .where('a === 5 && 'b === 10 && 'c === 5)
+    val originalQuery = EventTimeWatermark('b, interval, relation)
+      .where('a === 5 && 'b === new java.sql.Timestamp(0) && 'c === 5)
     val correctAnswer = EventTimeWatermark(
-      'b, interval, testRelation.where('a === 5 && 'c === 5))
-      .where('b === 10)
+      'b, interval, relation.where('a === 5 && 'c === 5))
+      .where('b === new java.sql.Timestamp(0))
 
-    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
-      checkAnalysis = false)
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
   }
 
   test("watermark pushdown: no pushdown for nondeterministic filter") {
     val interval = new CalendarInterval(2, 2, 2000L)
+    val relation = LocalRelation(attrA, attrB, 'c.timestamp)
 
     // Verify that all conditions except the watermark touching condition are pushed down
     // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark('c, interval, testRelation)
-      .where('a === 5 && 'b === Rand(10) && 'c === 5)
+    val originalQuery = EventTimeWatermark('c, interval, relation)
+      .where('a === 5 && 'b === Rand(10) && 'c === new java.sql.Timestamp(0))
     val correctAnswer = EventTimeWatermark(
-      'c, interval, testRelation.where('a === 5))
-      .where('b === Rand(10) && 'c === 5)
+      'c, interval, relation.where('a === 5))
+      .where('b === Rand(10) && 'c === new java.sql.Timestamp(0))
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
       checkAnalysis = false)
   }
 
   test("watermark pushdown: full pushdown") {
     val interval = new CalendarInterval(2, 2, 2000L)
+    val relation = LocalRelation(attrA, attrB, 'c.timestamp)
 
     // Verify that all conditions except the watermark touching condition are pushed down
     // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark('c, interval, testRelation)
+    val originalQuery = EventTimeWatermark('c, interval, relation)
       .where('a === 5 && 'b === 10)
     val correctAnswer = EventTimeWatermark(
-      'c, interval, testRelation.where('a === 5 && 'b === 10))
+      'c, interval, relation.where('a === 5 && 'b === 10))
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
       checkAnalysis = false)
   }
 
   test("watermark pushdown: no pushdown on watermark attribute #2") {
     val interval = new CalendarInterval(2, 2, 2000L)
+    val relation = LocalRelation('a.timestamp, attrB, attrC)
 
-    val originalQuery = EventTimeWatermark('a, interval, testRelation)
-      .where('a === 5 && 'b === 10)
+    val originalQuery = EventTimeWatermark('a, interval, relation)
+      .where('a === new java.sql.Timestamp(0) && 'b === 10)
     val correctAnswer = EventTimeWatermark(
-      'a, interval, testRelation.where('b === 10)).where('a === 5)
+      'a, interval, relation.where('b === 10)).where('a === new java.sql.Timestamp(0))
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
       checkAnalysis = false)
```
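Two kinds of fixes in this hunk. The join-condition test is deleted outright: as its own comment notes, `CheckAnalysis` rejects nondeterministic expressions such as `Rand(10)` in a join condition, so `originalQuery.analyze` itself would now throw. The four watermark tests each gain a dedicated relation whose event-time column really is a timestamp, with filters comparing it against timestamp values; the first test can then rely on `comparePlans`' default analysis check. A sketch of the recurring pattern (suite-scoped names, `attrA`/`attrC` being the suite's int attributes):

```scala
// The watermark column 'b is now timestamp-typed, and so is the value it
// is compared against, keeping the plan type-consistent under analyze:
val relation = LocalRelation(attrA, 'b.timestamp, attrC)

EventTimeWatermark('b, interval, relation)
  .where('a === 5 && 'b === new java.sql.Timestamp(0) && 'c === 5)
```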
In `JoinReorderSuite`:

```diff
@@ -307,9 +307,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     val originalPlan2 =
       t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
         .hint("broadcast")
-        .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+        .join(t4, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
         .hint("broadcast")
-        .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
 
     assertEqualPlans(originalPlan2, originalPlan2)
 
```
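The original conditions were attached to the wrong joins: the first `.join(t4, ...)` predicate referenced `t3`, which had not been joined yet, so the plan could never resolve. The fix pairs each predicate with the join whose inputs it actually references. A minimal sketch of the scoping rule (hypothetical relations, Catalyst DSL):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val left = LocalRelation('x.int)
val right = LocalRelation('y.int)

// A join condition may only reference columns of the join's own inputs:
left.join(right, Inner, Some('x === 'y)).analyze // resolves
left.join(right, Inner, Some('x === 'z)).analyze // AnalysisException:
                                                 // neither side produces 'z
```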
In `PullupCorrelatedPredicatesSuite`:

```diff
@@ -91,9 +91,8 @@ class PullupCorrelatedPredicatesSuite extends PlanTest {
         .select(max('d))
     val scalarSubquery =
       testRelation
-        .where(ScalarSubquery(subPlan))
+        .where(ScalarSubquery(subPlan) === 1)
         .select('a).analyze
-    assert(scalarSubquery.resolved)
 
     val optimized = Optimize.execute(scalarSubquery)
     val doubleOptimized = Optimize.execute(optimized)
```
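A `Filter` condition must be boolean-typed. The bare `ScalarSubquery(subPlan)` has the type of the subquery's single output column (`max('d)`), so the old plan failed the new analysis check; comparing it to a literal turns it into a real predicate, and the explicit `resolved` assertion becomes redundant because `analyze` now guarantees it. With the diff's own names:

```scala
// Not a boolean expression: rejected once `analyze` runs checkAnalysis.
testRelation.where(ScalarSubquery(subPlan))

// Boolean predicate: analyzes fine.
testRelation.where(ScalarSubquery(subPlan) === 1).select('a).analyze
```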
In `ReplaceNullWithFalseInPredicateSuite`:

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -50,10 +51,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
   }
 
   test("Not expected type - replaceNullWithFalse") {
-    val e = intercept[IllegalArgumentException] {
+    val e = intercept[AnalysisException] {
       testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral)
     }.getMessage
-    assert(e.contains("but got the type `int` in `CAST(NULL AS INT)"))
+    assert(e.contains("'CAST(NULL AS INT)' of type int is not a boolean"))
   }
 
   test("replace null in branches of If") {
```
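The expected failure moves earlier in the pipeline: previously the optimizer rule threw `IllegalArgumentException` when handed a non-boolean filter condition, but the condition now trips `checkAnalysis` first, whose filter-type error is what the updated assertion matches. Hence both the new `AnalysisException` intercept and the added import above.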
In `SimplifyCastsSuite`:

```diff
@@ -42,7 +42,8 @@ class SimplifyCastsSuite extends PlanTest {
 
   test("nullable element to non-nullable element array cast") {
     val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
-    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
+    val attr = input.output.head
+    val plan = input.select(attr.cast(ArrayType(IntegerType, false)).as("casted"))
     val optimized = Optimize.execute(plan)
     // Though cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not
     // allowed, here we just ensure that `SimplifyCasts` rule respect the plan.
@@ -60,13 +61,13 @@
 
   test("nullable value map to non-nullable value map cast") {
     val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
-    val plan = input.select('m.cast(MapType(StringType, StringType, false))
-      .as("casted")).analyze
+    val attr = input.output.head
+    val plan = input.select(attr.cast(MapType(StringType, StringType, false))
+      .as("casted"))
     val optimized = Optimize.execute(plan)
     // Though cast from `MapType(StringType, StringType, true)` to
     // `MapType(StringType, StringType, false)` is not allowed, here we just ensure that
     // `SimplifyCasts` rule respect the plan.
     comparePlans(optimized, plan, checkAnalysis = false)
   }
 }
```

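Casting away nullability is not a valid cast, so these plans can no longer pass through `analyze`: the `Cast` never resolves and the new check would throw. The tests instead build the projection directly from the relation's already-resolved output attribute and keep `checkAnalysis = false` in `comparePlans`. The array case in isolation:

```scala
// ArrayType(IntegerType, containsNull = true) cannot be cast to
// ArrayType(IntegerType, containsNull = false); building the plan from the
// resolved attribute sidesteps the analyzer entirely:
val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
val attr = input.output.head
val plan = input.select(attr.cast(ArrayType(IntegerType, false)).as("casted"))
```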
In `SimplifyConditionalSuite`:

```diff
@@ -18,13 +18,14 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.{IntegerType, NullType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
 
 
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -34,20 +35,18 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
       BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
   }
 
+  private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
+
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
-    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
-    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
+    val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze
+    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, relation).analyze)
     comparePlans(actual, correctAnswer)
   }
 
   private val trueBranch = (TrueLiteral, Literal(5))
   private val normalBranch = (NonFoldableLiteral(true), Literal(10))
   private val unreachableBranch = (FalseLiteral, Literal(20))
-  private val nullBranch = (Literal.create(null, NullType), Literal(30))
-
-  val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
-  val isNullCond = IsNull(UnresolvedAttribute("b"))
-  val notCond = Not(UnresolvedAttribute("c"))
+  private val nullBranch = (Literal.create(null, BooleanType), Literal(30))
 
   test("simplify if") {
     assertEquivalent(
@@ -59,7 +58,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
       Literal(20))
 
     assertEquivalent(
-      If(Literal.create(null, NullType), Literal(10), Literal(20)),
+      If(Literal.create(null, BooleanType), Literal(10), Literal(20)),
       Literal(20))
   }
 
@@ -127,9 +126,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   test("simplify CaseWhen if all the outputs are semantic equivalence") {
     // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed.
     assertEquivalent(
-      CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) ::
-        (isNullCond, Literal(1)) ::
-        (notCond, Add(Literal(6), Literal(-5))) ::
+      CaseWhen(('a.isNotNull, Subtract(Literal(3), Literal(2))) ::
+        ('b.isNull, Literal(1)) ::
+        (!'c, Add(Literal(6), Literal(-5))) ::
         Nil,
         Add(Literal(2), Literal(-1))),
       Literal(1)
```
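Two adjustments for the stricter `analyze`: conditions must be boolean, so the null-branch literal moves from `NullType` to `BooleanType`; and the free-standing `UnresolvedAttribute` conditions over `OneRowRelation()` (which could never resolve) are replaced by DSL predicates over a concrete `relation`, whose `'c` is boolean so that `!'c` type-checks. A sketch of the now-resolvable shapes:

```scala
// Against relation = LocalRelation('a.int, 'b.int, 'c.boolean):
val branches =
  ('a.isNotNull, Subtract(Literal(3), Literal(2))) ::
  ('b.isNull, Literal(1)) ::
  (!'c, Add(Literal(6), Literal(-5))) :: Nil
CaseWhen(branches, Add(Literal(2), Literal(-1))) // boolean conditions, resolvable
If(Literal.create(null, BooleanType), Literal(10), Literal(20)) // boolean null
```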
In `StarJoinReorderSuite`:

```diff
@@ -226,7 +226,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     assertEqualPlans(query, expected)
   }
 
-  test("Test 3: Star join on a subset of dimensions since join column is not unique") {
+  test("Test 3: Star join on a subset of dimensions since join column is not unique") {
     // Star join:
     //   (=)  (=)
     // d1 - f1 - d2
@@ -254,9 +254,9 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     val expected =
       f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
         Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
-        .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
-        .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
-        .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2")))
+        .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+        .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
+        .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
         .select(outputsOf(d1, f1, d2, s3, d3): _*)
 
     assertEqualPlans(query, expected)
@@ -316,20 +316,23 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     // Positional join reordering: d3_ns, f1, d1, d2, s3
     // Star join reordering: empty
 
+    val d3_pk1 = d3_ns.output.find(_.name == "d3_pk1").get
+    val d3_fk1 = d3_ns.output.find(_.name == "d3_fk1").get
+
     val query =
       d3_ns.join(f1).join(d1).join(d2).join(s3)
         .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
           (nameToAttr("d2_c2") === 2) &&
           (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
-          (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
-          (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+          (nameToAttr("f1_fk3") === d3_pk1) &&
+          (d3_fk1 === nameToAttr("s3_pk1")))
 
     val equivQuery =
-      d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+      d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === d3_pk1))
         .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
         .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
           Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
-        .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+        .join(s3, Inner, Some(d3_fk1 === nameToAttr("s3_pk1")))
 
     assertEqualPlans(query, equivQuery)
   }
```
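Catalyst identifies attributes by `exprId`, not by name, and this suite contains both `d3` and `d3_ns` with identical column names, so the name-based `nameToAttr` lookup could hand back an attribute that the plan never produces. Pulling the attributes out of `d3_ns.output` pins the right ones; the earlier hunks appear to correct the expected plans along the same lines, and the `Test 3` title line is a whitespace-only change. The pattern, with the suite's own names:

```scala
// Pin the attributes of the relation that is actually in the plan:
val d3_pk1 = d3_ns.output.find(_.name == "d3_pk1").get
val d3_fk1 = d3_ns.output.find(_.name == "d3_fk1").get

d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === d3_pk1))
  .join(s3, Inner, Some(d3_fk1 === nameToAttr("s3_pk1")))
```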
