Skip to content

Commit

Permalink
improve RowMatrix multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
Li Pu committed Jun 26, 2014
1 parent 5543cce commit 7148426
Showing 1 changed file with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,24 @@ class RowMatrix(
*/
def multiply(B: Matrix): RowMatrix = {
val n = numCols().toInt
val k = B.numCols
require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")

require(B.isInstanceOf[DenseMatrix],
s"Only support dense matrix at this time but found ${B.getClass.getName}.")

val Bb = rows.context.broadcast(B)
val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
val AB = rows.mapPartitions({ iter =>
val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]]
iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze))
val Bi = Bb.value
iter.map(row => {
val v = BDV.zeros[Double](k)
var i = 0
while (i < k) {
v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n))
i += 1
}
Vectors.fromBreeze(v)
})
}, preservesPartitioning = true)

new RowMatrix(AB, nRows, B.numCols)
Expand Down

0 comments on commit 7148426

Please sign in to comment.