From f4f5ebb59b10ab09414690498bdeadceb94ca2e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 11 Dec 2014 20:46:00 +0800 Subject: [PATCH] Follow BLAS.dot pattern to replace intersect, diff with while-loop. --- .../org/apache/spark/mllib/util/MLUtils.scala | 77 +++++++++++-------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7b25ef6ecbc62..89fc88a1cb60b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -271,32 +271,34 @@ object MLUtils { var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => - var count = 0 - var indices = v1.indices.intersect(v2.indices) - - while (count < indices.length) { - val idx = indices(count) - val score = v1(idx) - v2(idx) - squaredDistance += score * score - count += 1 - } - - count = 0 - indices = v1.indices.diff(v2.indices) - while (count < indices.length) { - val idx = indices(count) - val score = v1(idx) - squaredDistance += score * score - count += 1 - } + val v1Values = v1.values + val v1Indices = v1.indices + val v2Values = v2.values + val v2Indices = v2.indices + val nnzv1 = v1Indices.size + val nnzv2 = v2Indices.size - count = 0 - indices = v2.indices.diff(v1.indices) - while (count < indices.length) { - val idx = indices(count) - val score = v2(idx) - squaredDistance += score * score - count += 1 + var kv1 = 0 + var kv2 = 0 + var score = 0.0 + while (kv1 < nnzv1) { + val iv1 = v1Indices(kv1) + + if (kv2 >= nnzv2 || iv1 < v2Indices(kv2)) { + score = v1Values(kv1) + squaredDistance += score * score + } + while (kv2 < nnzv2 && v2Indices(kv2) < iv1) { + score = v2Values(kv2) + squaredDistance += score * score + kv2 += 1 + } + if (kv2 < nnzv2 && v2Indices(kv2) == iv1) { + score = v1Values(kv1) - v2Values(kv2) + squaredDistance += score * score + kv2 += 1 + } + kv1 += 1 } case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => @@ -321,21 +323,28 @@ object MLUtils { var squaredDistance = 0.0 var count = 0 var indices = v1.indices - + var score = 0.0 while (count < indices.length) { val idx = indices(count) - val score = v1(idx) - v2(idx) + score = v1(idx) - v2(idx) squaredDistance += score * score count += 1 } - count = 0 - indices = (0 to v2.size - 1).toArray.diff(v1.indices) - while (count < indices.length) { - val idx = indices(count) - val score = v2(idx) - squaredDistance += score * score - count += 1 + var kv1 = 0 + var kv2 = 0 + var iv1 = indices(kv1) + val nnzv2 = v2.size + while (kv2 < nnzv2) { + if (kv2 < iv1 || kv2 > iv1) { + score = v2(kv2) + squaredDistance += score * score + } + if (kv2 == iv1 && kv1 < indices.length - 1) { + kv1 += 1 + iv1 = indices(kv1) + } + kv2 += 1 } squaredDistance }