Skip to content

Commit

Permalink
Josh's comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed May 7, 2015
1 parent 487f540 commit 53a5eaa
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ case class Exchange(
def serializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
numPartitions: Int): Serializer = {
// It is true when there is no field that needs to be written out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
Expand All @@ -99,7 +100,7 @@ case class Exchange(

val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
Expand Down Expand Up @@ -142,7 +143,8 @@ case class Exchange(
}
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
shuffled.setSerializer(
serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._2)

Expand All @@ -167,7 +169,8 @@ case class Exchange(
new ShuffledRDD[Row, Null, Null](rdd, part)
}
val keySchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, null, numPartitions))
shuffled.setSerializer(
serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._1)

Expand All @@ -187,7 +190,7 @@ case class Exchange(
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(null, valueSchema, 1))
shuffled.setSerializer(serializer(null, valueSchema, false, 1))
shuffled.map(_._2)

case _ => sys.error(s"Exchange not implemented for $newPartitioning")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.spark.serializer._
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
import org.apache.spark.sql.types._

/**
Expand All @@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
out: OutputStream)
extends SerializationStream with Logging {

val rowOut = new DataOutputStream(new BufferedOutputStream(out))
val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
Expand Down Expand Up @@ -86,24 +86,44 @@ private[sql] class Serializer2SerializationStream(
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
in: InputStream)
extends DeserializationStream with Logging {

val rowIn = new DataInputStream(new BufferedInputStream(in))
private val rowIn = new DataInputStream(new BufferedInputStream(in))

val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
/**
 * Returns a thunk that supplies the row object a deserialization function will populate.
 * For a null schema (no key or no value side), the thunk yields null. When the shuffle
 * has a key ordering, rows may be buffered downstream, so every call must allocate a
 * fresh row; otherwise a single mutable row is allocated once and reused across calls.
 */
private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
  if (schema == null) {
    // Nothing to deserialize on this side.
    () => null
  } else if (hasKeyOrdering) {
    // Key ordering is specified in the ShuffledRDD, so it is not safe to reuse a
    // mutable row: hand out a new one per call.
    () => new GenericMutableRow(schema.length)
  } else {
    // Safe to reuse: allocate once, return the same row every time.
    val reusableRow = new SpecificMutableRow(schema)
    () => reusableRow
  }
}

// Functions used to return rows for key and value.
private val getKey = rowGenerator(keySchema)
private val getValue = rowGenerator(valueSchema)
// Functions used to read a serialized row from the InputStream and deserialize it.
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

override def readObject[T: ClassTag](): T = {
  // Deserialize one (key, value) pair, filling rows supplied by the generators.
  val key = readKeyFunc(getKey())
  val value = readValueFunc(getValue())
  (key, value).asInstanceOf[T]
}

// Deserialize only the key side of the stream into a generator-provided row.
override def readKey[T: ClassTag](): T =
  readKeyFunc(getKey()).asInstanceOf[T]

// Deserialize only the value side of the stream into a generator-provided row.
override def readValue[T: ClassTag](): T =
  readValueFunc(getValue()).asInstanceOf[T]

override def close(): Unit = {
Expand All @@ -113,7 +133,8 @@ private[sql] class Serializer2DeserializationStream(

private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer =
Expand All @@ -130,7 +151,7 @@ private[sql] class SparkSqlSerializer2Instance(
}

/**
 * Wraps the given input stream in a [[Serializer2DeserializationStream]], forwarding the
 * key/value schemas and the key-ordering flag of this serializer instance.
 */
def deserializeStream(s: InputStream): DeserializationStream =
  new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
}

Expand All @@ -141,12 +162,16 @@ private[sql] class SparkSqlSerializer2Instance(
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
*/
private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends Serializer
with Logging
with Serializable{

def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(keySchema, valueSchema)
/** Builds a serializer instance carrying this serializer's schemas and ordering flag. */
def newInstance(): SerializerInstance = {
  new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
}

override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
Expand Down Expand Up @@ -316,12 +341,12 @@ private[sql] object SparkSqlSerializer2 {
*/
def createDeserializationFunction(
schema: Array[DataType],
in: DataInputStream): () => Row = {
() => {
// If the schema is null, the returned function does nothing when it get called.
if (schema != null) {
in: DataInputStream): (MutableRow) => Row = {
if (schema == null) {
(mutableRow: MutableRow) => null
} else {
(mutableRow: MutableRow) => {
var i = 0
val mutableRow = new GenericMutableRow(schema.length)
while (i < schema.length) {
schema(i) match {
// When we read values from the underlying stream, we also first read the null byte
Expand Down Expand Up @@ -435,8 +460,6 @@ private[sql] object SparkSqlSerializer2 {
}

mutableRow
} else {
null
}
}
}
Expand Down

0 comments on commit 53a5eaa

Please sign in to comment.