
Commit 8385f95

Always create a new row at the deserialization side to work with sort merge join.
yhuai committed May 7, 2015
1 parent c7e2129 commit 8385f95
Showing 1 changed file with 14 additions and 17 deletions.
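
For context: sort merge join buffers several deserialized rows at a time (all rows that share the current join key), so a deserializer that keeps returning one shared mutable row silently corrupts every row the operator has already buffered. Below is a minimal, self-contained Scala sketch of that aliasing pitfall; the names are illustrative, not Spark's API.

    import scala.collection.mutable.ArrayBuffer

    object SharedRowPitfall {
      def main(args: Array[String]): Unit = {
        val input = Seq(1, 2, 3)

        // Reused buffer, analogous to the old shared SpecificMutableRow:
        // every "row" handed out is the same underlying object.
        val shared = Array(0)
        val reused = input.iterator.map { v => shared(0) = v; shared }

        val buffered = ArrayBuffer.empty[Array[Int]]
        reused.foreach(buffered += _)
        println(buffered.map(_(0))) // ArrayBuffer(3, 3, 3): every entry aliases the last record

        // A fresh object per record, analogous to the new GenericMutableRow per call.
        val fresh = input.iterator.map(v => Array(v))
        println(fresh.map(_(0)).toSeq) // List(1, 2, 3)
      }
    }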
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._

 /**
@@ -91,34 +91,27 @@ private[sql] class Serializer2DeserializationStream(

   val rowIn = new DataInputStream(new BufferedInputStream(in))

-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(), readValueFunc()).asInstanceOf[T]
   }

   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc().asInstanceOf[T]
   }

   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc().asInstanceOf[T]
   }

   override def close(): Unit = {
     rowIn.close()
   }
 }

-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
     valueSchema: Array[DataType])
   extends SerializerInstance {
@@ -153,7 +146,7 @@ private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema:
   with Logging
   with Serializable {

-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(keySchema, valueSchema)

   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,12 +316,12 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
+      in: DataInputStream): () => Row = {
     () => {
       // If the schema is null, the returned function does nothing when it gets called.
       if (schema != null) {
         var i = 0
+        val mutableRow = new GenericMutableRow(schema.length)
         while (i < schema.length) {
           schema(i) match {
             // When we read values from the underlying stream, we also first read the null byte
@@ -440,6 +433,10 @@ private[sql] object SparkSqlSerializer2 {
           }
           i += 1
         }
+
+        mutableRow
+      } else {
+        null
       }
     }
   }
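
With this change, each invocation of the returned closure allocates its own GenericMutableRow (or returns null when the schema is null), so the results of readKey()/readValue() can safely be retained across calls. A hedged usage sketch: s and recordCount are assumed to exist (a Serializer2DeserializationStream and an externally known record count); EOF handling is elided.

    // Hypothetical consumer: buffering the returned keys is safe after this
    // commit, because each call yields a distinct row rather than mutating a
    // shared one.
    val bufferedKeys = scala.collection.mutable.ArrayBuffer.empty[Row]
    var n = 0
    while (n < recordCount) { // recordCount: assumed known out of band
      bufferedKeys += s.readKey[Row]()
      s.readValue[Row]() // values are read but ignored in this sketch
      n += 1
    }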
