Skip to content

Commit

Permalink
Fix bug and some modifications for comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 16, 2014
1 parent f4f5ebb commit 35db395
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 28 deletions.
50 changes: 22 additions & 28 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,38 +280,35 @@ object MLUtils {

var kv1 = 0
var kv2 = 0
var score = 0.0
while (kv1 < nnzv1) {
val iv1 = v1Indices(kv1)
while (kv1 < nnzv1 || kv2 < nnzv2) {
var score = 0.0

if (kv2 >= nnzv2 || iv1 < v2Indices(kv2)) {
if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
score = v1Values(kv1)
squaredDistance += score * score
}
while (kv2 < nnzv2 && v2Indices(kv2) < iv1) {
kv1 += 1
} else if (kv1 >= nnzv1 || (kv2 < nnzv2 && v2Indices(kv2) < v1Indices(kv1))) {
score = v2Values(kv2)
squaredDistance += score * score
kv2 += 1
}
if (kv2 < nnzv2 && v2Indices(kv2) == iv1) {
} else if ((kv1 < nnzv1 && kv2 < nnzv2) && v1Indices(kv1) == v2Indices(kv2)) {
score = v1Values(kv1) - v2Values(kv2)
squaredDistance += score * score
kv1 += 1
kv2 += 1
}
kv1 += 1
squaredDistance += score * score
}

// The following two cases are used to handle dense and approximately dense vectors
case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
squaredDistance = vectorSquaredDistance(v1, v2)

case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
squaredDistance = vectorSquaredDistance(v2, v1)

case (v1, v2) =>
squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0)((distance, elems) => {
squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){(distance, elems) =>
val score = elems._1 - elems._2
distance + score * score
})
}
}
squaredDistance
}
Expand All @@ -320,29 +317,26 @@ object MLUtils {
* Returns the squared distance between DenseVector and SparseVector.
*/
private[util] def vectorSquaredDistance(v1: SparseVector, v2: DenseVector): Double = {
var squaredDistance = 0.0
var count = 0
var indices = v1.indices
var score = 0.0
while (count < indices.length) {
val idx = indices(count)
score = v1(idx) - v2(idx)
squaredDistance += score * score
count += 1
}

var kv1 = 0
var kv2 = 0
var indices = v1.indices
var squaredDistance = 0.0
var iv1 = indices(kv1)
val nnzv2 = v2.size

while (kv2 < nnzv2) {
var score = 0.0
if (kv2 < iv1 || kv2 > iv1) {
score = v2(kv2)
squaredDistance += score * score
}
if (kv2 == iv1 && kv1 < indices.length - 1) {
kv1 += 1
iv1 = indices(kv1)
if (kv2 == iv1 && kv1 < indices.length) {
score = v1.values(iv1) - v2(kv2)
squaredDistance += score * score
if (kv1 < indices.length - 1) {
kv1 += 1
iv1 = indices(kv1)
}
}
kv2 += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
val fastSquaredDist3 =
fastSquaredDistance(v2, norm2, v3, norm3, precision)
assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m")
if (m > 10) {
val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10))
val norm4 = Vectors.norm(v4, 2.0)
val squaredDist = breezeSquaredDistance(v2.toBreeze, v4.toBreeze)
val fastSquaredDist =
fastSquaredDistance(v2, norm2, v4, norm4, precision)
assert((fastSquaredDist - squaredDist) <= precision * squaredDist, s"failed with m = $m")
}
}
}

Expand Down

0 comments on commit 35db395

Please sign in to comment.