[SPARK-6986] [SQL] Use Serializer2 in more cases.
With 0a2b15c, the serialization and deserialization streams have enough information to determine whether they are handling a key-value pair, a key, or a value. It is safe to use `SparkSqlSerializer2` in more cases.

Author: Yin Huai <[email protected]>

Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits:

53a5eaa [Yin Huai] Josh's comments.
487f540 [Yin Huai] Use BufferedOutputStream.
8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join.
c7e2129 [Yin Huai] Update tests.
4513d13 [Yin Huai] Use Serializer2 in more places.

(cherry picked from commit 3af423c)
Signed-off-by: Yin Huai <[email protected]>
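
The change leans on the stream-level key/value hooks referenced above (0a2b15c): a serializer stream is now told explicitly whether it is handed a complete record, only a key, or only a value, instead of having to assume every object is a Product2. The following is a hypothetical, simplified Scala sketch of that contract, not the actual org.apache.spark.serializer traits:

import scala.reflect.ClassTag

// Hypothetical sketch (not the real Spark traits): the caller states what it
// is writing, so a schema-aware serializer such as SparkSqlSerializer2 no
// longer has to assume every object passed to writeObject is a
// Product2[Row, Row].
trait PairAwareSerializationStream {
  def writeObject[T: ClassTag](record: T): this.type // a complete (key, value) record
  def writeKey[T: ClassTag](key: T): this.type       // a key written on its own
  def writeValue[T: ClassTag](value: T): this.type   // a value written on its own
}

trait PairAwareDeserializationStream {
  def readObject[T: ClassTag](): T // read back a (key, value) pair
  def readKey[T: ClassTag](): T    // read back just a key
  def readValue[T: ClassTag](): T  // read back just a value
}

Because the sort-based shuffle path writes keys and values separately (see the removed comment in Exchange below), these explicit hooks are what let the bypass-threshold and ordering restrictions be dropped.
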
yhuai committed May 8, 2015
1 parent 28d4238 commit 9d0d289
Showing 3 changed files with 69 additions and 58 deletions.
Exchange.scala
@@ -84,18 +84,8 @@ case class Exchange(
def serializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
numPartitions: Int): Serializer = {
// In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
// through write(key) and then write(value) instead of write((key, value)). Because
// SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
// it when spillToMergeableFile in ExternalSorter will be used.
// So, we will not use SparkSqlSerializer2 when
// - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
//   than the bypassMergeThreshold; or
// - newOrdering is defined.
val cannotUseSqlSerializer2 =
(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty

// It is true when there is no field that needs to be written out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
@@ -104,14 +94,13 @@ case class Exchange(

val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
!cannotUseSqlSerializer2 && // Safe to use Serializer2.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
!noField

val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
@@ -154,7 +143,8 @@
}
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
shuffled.setSerializer(
serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._2)

@@ -179,7 +169,8 @@
new ShuffledRDD[Row, Null, Null](rdd, part)
}
val keySchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, null, numPartitions))
shuffled.setSerializer(
serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._1)

@@ -199,7 +190,7 @@
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(null, valueSchema, 1))
shuffled.setSerializer(serializer(null, valueSchema, false, 1))
shuffled.map(_._2)

case _ => sys.error(s"Exchange not implemented for $newPartitioning")
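
With the old guard removed, the serializer choice in Exchange reduces to the conjunction shown in the hunk above. A minimal restatement of that predicate, with illustrative parameter names standing in for the expressions in the diff:

// Hedged restatement of the post-change condition (parameter names are
// illustrative stand-ins for the expressions used in Exchange.serializer):
def canUseSqlSerializer2(
    serializer2Enabled: Boolean,   // child.sqlContext.conf.useSqlSerializer2
    keySchemaSupported: Boolean,   // SparkSqlSerializer2.support(keySchema)
    valueSchemaSupported: Boolean, // SparkSqlSerializer2.support(valueSchema)
    noField: Boolean): Boolean =   // no field needs to be written out
  serializer2Enabled && keySchemaSupported && valueSchemaSupported && !noField
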
SparkSqlSerializer2.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.spark.serializer._
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
import org.apache.spark.sql.types._

/**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
out: OutputStream)
extends SerializationStream with Logging {

val rowOut = new DataOutputStream(out)
val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,41 +86,55 @@ private[sql] class Serializer2SerializationStream(
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
in: InputStream)
extends DeserializationStream with Logging {

val rowIn = new DataInputStream(new BufferedInputStream(in))
private val rowIn = new DataInputStream(new BufferedInputStream(in))

private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
if (schema == null) {
() => null
} else {
if (hasKeyOrdering) {
// Key ordering is specified in the ShuffledRDD, so it is not safe to reuse a mutable row.
() => new GenericMutableRow(schema.length)
} else {
// It is safe to reuse the mutable row.
val mutableRow = new SpecificMutableRow(schema)
() => mutableRow
}
}
}

val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
// Functions used to return rows for key and value.
private val getKey = rowGenerator(keySchema)
private val getValue = rowGenerator(valueSchema)
// Functions used to read a serialized row from the InputStream and deserialize it.
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

override def readObject[T: ClassTag](): T = {
readKeyFunc()
readValueFunc()

(key, value).asInstanceOf[T]
(readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
}

override def readKey[T: ClassTag](): T = {
readKeyFunc()
key.asInstanceOf[T]
readKeyFunc(getKey()).asInstanceOf[T]
}

override def readValue[T: ClassTag](): T = {
readValueFunc()
value.asInstanceOf[T]
readValueFunc(getValue()).asInstanceOf[T]
}

override def close(): Unit = {
rowIn.close()
}
}

private[sql] class ShuffleSerializerInstance(
private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -137,7 +151,7 @@
}

def deserializeStream(s: InputStream): DeserializationStream = {
new Serializer2DeserializationStream(keySchema, valueSchema, s)
new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
}
}

@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
*/
private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends Serializer
with Logging
with Serializable {

def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
def newInstance(): SerializerInstance =
new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)

override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
*/
def createDeserializationFunction(
schema: Array[DataType],
in: DataInputStream,
mutableRow: SpecificMutableRow): () => Unit = {
() => {
// If the schema is null, the returned function does nothing when it gets called.
if (schema != null) {
in: DataInputStream): (MutableRow) => Row = {
if (schema == null) {
(mutableRow: MutableRow) => null
} else {
(mutableRow: MutableRow) => {
var i = 0
while (i < schema.length) {
schema(i) match {
Expand Down Expand Up @@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
}
i += 1
}

mutableRow
}
}
}
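
The reason hasKeyOrdering forces a fresh GenericMutableRow per record: when the shuffle output is key-ordered, downstream operators such as sort merge join buffer deserialized rows, and reusing one mutable row would leave every buffered entry pointing at the same object. A small standalone sketch of that hazard (illustrative only, not taken from the patch):

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

// Illustrative sketch of the hazard the hasKeyOrdering flag guards against:
// if the deserialization stream reuses one mutable row, an operator that
// buffers rows (sort merge join does) ends up with every buffered entry
// referencing the same object.
object RowReuseHazard {
  def main(args: Array[String]): Unit = {
    val reused = new GenericMutableRow(1)
    val buffered = (1 to 3).map { i =>
      reused.update(0, i) // the stream overwrites the same row in place...
      reused              // ...and hands back the same reference every time
    }
    // All three buffered "rows" now report the last value that was written.
    assert(buffered.forall(_.getInt(0) == 3))
  }
}
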
SparkSqlSerializer2Suite.scala
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
table("shuffle").collect())
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
}

test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
override def beforeAll(): Unit = {
super.beforeAll()
// Sort merge will not be triggered.
sql("set spark.sql.shuffle.partitions = 200")
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}

/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {

// We are expecting SparkSqlSerializer.
override val serializerClass: Class[Serializer] =
classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]

override def beforeAll(): Unit = {
super.beforeAll()
// To trigger the sort merge.
sql("set spark.sql.shuffle.partitions = 201")
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}
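
The two suites deliberately sit on opposite sides of spark.shuffle.sort.bypassMergeThreshold, so one exercises the bypass-merge path and the other the sort/merge path that the removed guard in Exchange used to avoid. A hedged sketch of that relationship, assuming the Spark 1.x sort-based shuffle behavior described in the removed comment and ignoring the other preconditions (such as no ordering or aggregation) that the real shuffle code also checks:

import org.apache.spark.SparkConf

// Hedged sketch: the bypass-merge path applies only while the number of
// reduce partitions stays at or below the configured threshold (200 is the
// assumed default, matching the fallback used in the suites above).
def usesBypassMergePath(conf: SparkConf, numPartitions: Int): Boolean = {
  val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
  numPartitions <= threshold
}
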
