diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index dce1420793169..c9a7dd0b7ecc7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -728,7 +728,7 @@ class SparseVector(
       var maxValue = values(0)
 
       foreachActive { (i, v) =>
-        if(v > maxValue){
+        if(v != 0.0 && v > maxValue){
           maxIdx = i
           maxValue = v
         }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 7d35186df62cd..24d614b0eb973 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -71,6 +71,10 @@ class VectorsSuite extends FunSuite {
     val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
     val max = vec2.argmax
     assert(max === 3)
+
+    val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
+    val max2 = vec3.argmax
+    assert(max2 === 3)
   }
 
   test("sparse to array") {
@@ -87,9 +91,10 @@ class VectorsSuite extends FunSuite {
     val max = vec2.argmax
     assert(max === 3)
 
-    val vec3 = Vectors.sparse(5,Array(1,3,4),Array(1.0,.5,.7))
+    // Check the case where a sparse vector is created with an explicitly stored zero value.
+    val vec3 = Vectors.sparse(5,Array(0, 2, 4),Array(-1.0, 0.0, -.7))
     val max2 = vec3.argmax
-    assert(max2 === 1)
+    assert(max2 === 4)
   }
 
   test("vector equals") {