Skip to content

Commit

Permalink
Follow the BLAS.dot pattern to replace intersect and diff with a while-loop.
Browse files: browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 11, 2014
1 parent a36e09f commit f4f5ebb
Showing 1 changed file with 43 additions and 34 deletions.
77 changes: 43 additions & 34 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -271,32 +271,34 @@ object MLUtils {
var squaredDistance = 0.0
(v1, v2) match {
case (v1: SparseVector, v2: SparseVector) =>
var count = 0
var indices = v1.indices.intersect(v2.indices)

while (count < indices.length) {
val idx = indices(count)
val score = v1(idx) - v2(idx)
squaredDistance += score * score
count += 1
}

count = 0
indices = v1.indices.diff(v2.indices)
while (count < indices.length) {
val idx = indices(count)
val score = v1(idx)
squaredDistance += score * score
count += 1
}
val v1Values = v1.values
val v1Indices = v1.indices
val v2Values = v2.values
val v2Indices = v2.indices
val nnzv1 = v1Indices.size
val nnzv2 = v2Indices.size

count = 0
indices = v2.indices.diff(v1.indices)
while (count < indices.length) {
val idx = indices(count)
val score = v2(idx)
squaredDistance += score * score
count += 1
var kv1 = 0
var kv2 = 0
var score = 0.0
while (kv1 < nnzv1) {
val iv1 = v1Indices(kv1)

if (kv2 >= nnzv2 || iv1 < v2Indices(kv2)) {
score = v1Values(kv1)
squaredDistance += score * score
}
while (kv2 < nnzv2 && v2Indices(kv2) < iv1) {
score = v2Values(kv2)
squaredDistance += score * score
kv2 += 1
}
if (kv2 < nnzv2 && v2Indices(kv2) == iv1) {
score = v1Values(kv1) - v2Values(kv2)
squaredDistance += score * score
kv2 += 1
}
kv1 += 1
}

case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
Expand All @@ -321,21 +323,28 @@ object MLUtils {
var squaredDistance = 0.0
var count = 0
var indices = v1.indices

var score = 0.0
while (count < indices.length) {
val idx = indices(count)
val score = v1(idx) - v2(idx)
score = v1(idx) - v2(idx)
squaredDistance += score * score
count += 1
}

count = 0
indices = (0 to v2.size - 1).toArray.diff(v1.indices)
while (count < indices.length) {
val idx = indices(count)
val score = v2(idx)
squaredDistance += score * score
count += 1
var kv1 = 0
var kv2 = 0
var iv1 = indices(kv1)
val nnzv2 = v2.size
while (kv2 < nnzv2) {
if (kv2 < iv1 || kv2 > iv1) {
score = v2(kv2)
squaredDistance += score * score
}
if (kv2 == iv1 && kv1 < indices.length - 1) {
kv1 += 1
iv1 = indices(kv1)
}
kv2 += 1
}
squaredDistance
}
Expand Down

0 comments on commit f4f5ebb

Please sign in to comment.