Skip to content

Commit

Permalink
Update tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed May 7, 2015
1 parent 4513d13 commit c7e2129
Showing 1 changed file with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
table("shuffle").collect())
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
}

test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
Expand All @@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
override def beforeAll(): Unit = {
super.beforeAll()
// Sort merge will not be triggered.
sql("set spark.sql.shuffle.partitions = 200")
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}

/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {

// We are expecting SparkSqlSerializer.
override val serializerClass: Class[Serializer] =
classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]

override def beforeAll(): Unit = {
super.beforeAll()
// To trigger the sort merge.
sql("set spark.sql.shuffle.partitions = 201")
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}

0 comments on commit c7e2129

Please sign in to comment.