From 35db395c34385f5ee88c914467956cda9a71cc7d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 16 Dec 2014 18:02:57 +0800 Subject: [PATCH] Fix bug and some modifications for comments. --- .../org/apache/spark/mllib/util/MLUtils.scala | 50 ++++++++----------- .../spark/mllib/util/MLUtilsSuite.scala | 8 +++ 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 89fc88a1cb60b..84b87b5c78d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -280,27 +280,24 @@ object MLUtils { var kv1 = 0 var kv2 = 0 - var score = 0.0 - while (kv1 < nnzv1) { - val iv1 = v1Indices(kv1) + while (kv1 < nnzv1 || kv2 < nnzv2) { + var score = 0.0 - if (kv2 >= nnzv2 || iv1 < v2Indices(kv2)) { + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) - squaredDistance += score * score - } - while (kv2 < nnzv2 && v2Indices(kv2) < iv1) { + kv1 += 1 + } else if (kv1 >= nnzv1 || (kv2 < nnzv2 && v2Indices(kv2) < v1Indices(kv1))) { score = v2Values(kv2) - squaredDistance += score * score kv2 += 1 - } - if (kv2 < nnzv2 && v2Indices(kv2) == iv1) { + } else if ((kv1 < nnzv1 && kv2 < nnzv2) && v1Indices(kv1) == v2Indices(kv2)) { score = v1Values(kv1) - v2Values(kv2) - squaredDistance += score * score + kv1 += 1 kv2 += 1 } - kv1 += 1 + squaredDistance += score * score } + // The following two cases are used to handle dense and approximately dense vectors case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => squaredDistance = vectorSquaredDistance(v1, v2) @@ -308,10 +305,10 @@ object MLUtils { squaredDistance = vectorSquaredDistance(v2, v1) case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0)((distance, elems) => { + squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){(distance, elems) => val score = elems._1 - elems._2 distance + score * score - }) + } } squaredDistance } @@ -320,29 +317,26 @@ object MLUtils { * Returns the squared distance between DenseVector and SparseVector. */ private[util] def vectorSquaredDistance(v1: SparseVector, v2: DenseVector): Double = { - var squaredDistance = 0.0 - var count = 0 - var indices = v1.indices - var score = 0.0 - while (count < indices.length) { - val idx = indices(count) - score = v1(idx) - v2(idx) - squaredDistance += score * score - count += 1 - } - var kv1 = 0 var kv2 = 0 + var indices = v1.indices + var squaredDistance = 0.0 var iv1 = indices(kv1) val nnzv2 = v2.size + while (kv2 < nnzv2) { + var score = 0.0 if (kv2 < iv1 || kv2 > iv1) { score = v2(kv2) squaredDistance += score * score } - if (kv2 == iv1 && kv1 < indices.length - 1) { - kv1 += 1 - iv1 = indices(kv1) + if (kv2 == iv1 && kv1 < indices.length) { + score = v1.values(iv1) - v2(kv2) + squaredDistance += score * score + if (kv1 < indices.length - 1) { + kv1 += 1 + iv1 = indices(kv1) + } } kv2 += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 304a3e8b7a59d..640041b7bcf8d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -64,6 +64,14 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") + if (m > 10) { + val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) + val norm4 = Vectors.norm(v4, 2.0) + val squaredDist = breezeSquaredDistance(v2.toBreeze, v4.toBreeze) + val fastSquaredDist = + fastSquaredDistance(v2, norm2, v4, norm4, precision) + assert((fastSquaredDist - squaredDist) <= precision * squaredDist, s"failed with m = $m") + } } }