[SPARK-24502][SQL] flaky test: UnsafeRowSerializerSuite
## What changes were proposed in this pull request?

`UnsafeRowSerializerSuite` calls `UnsafeProjection.create`, which accesses `SQLConf.get`. `SQLConf.get` resolves the configuration through the currently active SparkSession, and that session may already be stopped, so we may hit an exception like this:

```
sbt.ForkMain$ForkError: java.lang.IllegalStateException: LiveListenerBus is stopped.
	at org.apache.spark.scheduler.LiveListenerBus.addToQueue(LiveListenerBus.scala:97)
	at org.apache.spark.scheduler.LiveListenerBus.addToStatusQueue(LiveListenerBus.scala:80)
	at org.apache.spark.sql.internal.SharedState.<init>(SharedState.scala:93)
	at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120)
	at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.sql.SparkSession.sharedState$lzycompute(SparkSession.scala:120)
	at org.apache.spark.sql.SparkSession.sharedState(SparkSession.scala:119)
	at org.apache.spark.sql.internal.BaseSessionStateBuilder.build(BaseSessionStateBuilder.scala:286)
	at org.apache.spark.sql.test.TestSparkSession.sessionState$lzycompute(TestSQLContext.scala:42)
	at org.apache.spark.sql.test.TestSparkSession.sessionState(TestSQLContext.scala:41)
	at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95)
	at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95)
	at scala.Option.map(Option.scala:146)
	at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:95)
	at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:94)
	at org.apache.spark.sql.internal.SQLConf$.get(SQLConf.scala:126)
	at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:54)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:157)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:150)
	at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$unsafeRowConverter(UnsafeRowSerializerSuite.scala:54)
	at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$toUnsafeRow(UnsafeRowSerializerSuite.scala:49)
	at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:63)
	at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:60)
...
```
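The fix is to have `UnsafeRowSerializerSuite` extend `LocalSparkSession` instead of `LocalSparkContext`, and to have `LocalSparkSession` clear the active and default SparkSession in its `beforeAll()` and `afterEach()` hooks, so `SQLConf.get` no longer resolves to a stopped session. Below is a minimal sketch of why the clearing helps; the object name and schema are invented for illustration, and only `SparkSession.clearActiveSession`/`clearDefaultSession` and `UnsafeProjection.create` come from the patch and the stack trace above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical sketch (not part of this patch): with no active or default
// SparkSession registered, SQLConf.get falls back to a default conf instead of
// lazily building session state on a possibly-stopped session.
object SessionFallbackSketch {
  def main(args: Array[String]): Unit = {
    SparkSession.clearActiveSession()
    SparkSession.clearDefaultSession()

    // UnsafeProjection.create reads SQLConf.get internally; with the sessions
    // cleared it cannot trip over "LiveListenerBus is stopped".
    val projection = UnsafeProjection.create(StructType(StructField("i", IntegerType) :: Nil))
    println(projection.getClass.getName)
  }
}
```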

## How was this patch tested?

N/A

Author: Wenchen Fan <[email protected]>

Closes apache#21518 from cloud-fan/test.
cloud-fan committed Jun 12, 2018
1 parent dc22465 commit 01452ea
Showing 2 changed files with 35 additions and 49 deletions.
`LocalSparkSession.scala`:

```diff
@@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self
   override def beforeAll() {
     super.beforeAll()
     InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
   }
 
   override def afterEach() {
     try {
       resetSparkContext()
+      SparkSession.clearActiveSession()
+      SparkSession.clearDefaultSession()
     } finally {
       super.afterEach()
     }
```
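For context, a hypothetical suite built on `LocalSparkSession` would look roughly like the sketch below — assuming, as the updated `UnsafeRowSerializerSuite` further down does, that the trait exposes a `spark` variable whose cleanup is handled by its lifecycle hooks; the suite and test names are invented.

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{LocalSparkSession, SparkSession}

// Hypothetical example suite: assign the session to the `spark` variable
// provided by LocalSparkSession and let the trait's beforeAll()/afterEach()
// hooks handle cleanup and clearing of the active/default session.
class ExampleLocalSparkSessionSuite extends SparkFunSuite with LocalSparkSession {
  test("a per-test session is cleaned up by the trait") {
    spark = SparkSession.builder().master("local").appName("example").getOrCreate()
    assert(spark.range(10).count() === 10)
  }
}
```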
`UnsafeRowSerializerSuite.scala`:

```diff
@@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
 import java.util.Properties
 
 import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 /**
@@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
@@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("toUnsafeRow() test helper method") {
-    // This currently doesnt work because the generic getter throws an exception.
+    // This currently doesn't work because the generic getter throws an exception.
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
     assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
```
```diff
@@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    var sc: SparkContext = null
-    var outputFile: File = null
-    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
-    Utils.tryWithSafeFinally {
-      val conf = new SparkConf()
-        .set("spark.shuffle.spill.initialMemoryThreshold", "1")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-        .set("spark.testing.memory", "80000")
-
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      // prepare data
-      val converter = unsafeRowConverter(Array(IntegerType))
-      val data = (1 to 10000).iterator.map { i =>
-        (i, converter(Row(i)))
-      }
-      val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
-
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        taskContext,
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = new UnsafeRowSerializer(numFields = 1))
-
-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+    val conf = new SparkConf()
+      .set("spark.shuffle.spill.initialMemoryThreshold", "1")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      .set("spark.testing.memory", "80000")
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
+    val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+    outputFile.deleteOnExit()
+    // prepare data
+    val converter = unsafeRowConverter(Array(IntegerType))
+    val data = (1 to 10000).iterator.map { i =>
+      (i, converter(Row(i)))
+    }
+    val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
 
-      // Merging spilled files should not throw assertion error
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
-    } {
-      // Clean up
-      if (sc != null) {
-        sc.stop()
-      }
+    val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+      taskContext,
+      partitioner = Some(new HashPartitioner(10)),
+      serializer = new UnsafeRowSerializer(numFields = 1))
 
-      // restore the spark env
-      SparkEnv.set(oldEnv)
+    // Ensure we spilled something and have to merge them later
+    assert(sorter.numSpills === 0)
+    sorter.insertAll(data)
+    assert(sorter.numSpills > 0)
 
-      if (outputFile != null) {
-        outputFile.delete()
-      }
-    }
+    // Merging spilled files should not throw assertion error
+    sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
   }
 
   test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
     val conf = new SparkConf().set("spark.shuffle.manager", "sort")
-    sc = new SparkContext("local", "test", conf)
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
-    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
-      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val rowsRDD = spark.sparkContext.parallelize(
+      Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))
+    ).asInstanceOf[RDD[Product2[Int, InternalRow]]]
     val dependency =
       new ShuffleDependency[Int, InternalRow, InternalRow](
         rowsRDD,
```
