Skip to content

Commit

Permalink
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Browse files Browse the repository at this point in the history
Utilities for pickling and unpickling SparseMatrices using SerDe

Author: MechCoder <[email protected]>

Closes apache#5775 from MechCoder/spark-7202 and squashes the following commits:

7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
  • Loading branch information
MechCoder authored and jeanlyn committed Jun 12, 2015
1 parent 6ff3519 commit 201aad3
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
}
}

// Pickler for SparseMatrix
private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {

def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val s = obj.asInstanceOf[SparseMatrix]
val order = ByteOrder.nativeOrder()

val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
val valuesBytes = new Array[Byte](8 * s.values.length)
val isTransposed = if (s.isTransposed) 1 else 0
ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)

out.write(Opcodes.MARK)
out.write(Opcodes.BININT)
out.write(PickleUtils.integer_to_bytes(s.numRows))
out.write(Opcodes.BININT)
out.write(PickleUtils.integer_to_bytes(s.numCols))
out.write(Opcodes.BINSTRING)
out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
out.write(colPtrsBytes)
out.write(Opcodes.BINSTRING)
out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
out.write(indicesBytes)
out.write(Opcodes.BINSTRING)
out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
out.write(valuesBytes)
out.write(Opcodes.BININT)
out.write(PickleUtils.integer_to_bytes(isTransposed))
out.write(Opcodes.TUPLE)
}

def construct(args: Array[Object]): Object = {
if (args.length != 6) {
throw new PickleException("should be 6")
}
val order = ByteOrder.nativeOrder()
val colPtrsBytes = getBytes(args(2))
val indicesBytes = getBytes(args(3))
val valuesBytes = getBytes(args(4))
val colPtrs = new Array[Int](colPtrsBytes.length / 4)
val rowIndices = new Array[Int](indicesBytes.length / 4)
val values = new Array[Double](valuesBytes.length / 8)
ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
val isTransposed = args(5).asInstanceOf[Int] == 1
new SparseMatrix(
args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
isTransposed)
}
}

// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {

Expand Down Expand Up @@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
if (!initialized) {
new DenseVectorPickler().register()
new DenseMatrixPickler().register()
new SparseMatrixPickler().register()
new SparseVectorPickler().register()
new LabeledPointPickler().register()
new RatingPickler().register()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating

Expand Down Expand Up @@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
val emptyMatrix = Matrices.dense(0, 0, empty)
val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
assert(emptyMatrix == ne)

val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
assert(sm.toArray === nsm.toArray)

val smt = new SparseMatrix(
3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
isTransposed=true)
val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
assert(smt.toArray === nsmt.toArray)
}

test("pickle rating") {
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def __reduce__(self):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),
self.rowIndices.tostring(), self.values.tostring(),
self.isTransposed)
int(self.isTransposed))

def __getitem__(self, indices):
i, j = indices
Expand Down Expand Up @@ -801,7 +801,7 @@ def toDense(self):

# TODO: More efficient implementation:
def __eq__(self, other):
return np.all(self.toArray == other.toArray)
return np.all(self.toArray() == other.toArray())


class Matrices(object):
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def test_serialize(self):
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
self._test_serialize(SparseVector(3, {}))
self._test_serialize(DenseMatrix(2, 3, range(6)))
sm1 = SparseMatrix(
3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
self._test_serialize(sm1)

def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})
Expand Down

0 comments on commit 201aad3

Please sign in to comment.