diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index e048b01d92462..9067b3ba9a7bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -150,6 +150,12 @@ sealed trait Vector extends Serializable { toDense } } + + /** + * Find the index of a maximal element. Returns the first maximal element in case of a tie. + * Returns -1 if vector has length 0. + */ + def argmax: Int } /** @@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector { new SparseVector(size, ii, vv) } - /** - * Find the index of a maximal element. Returns the first maximal element in case of a tie. - * Returns -1 if vector has length 0. - */ - private[spark] def argmax: Int = { + override def argmax: Int = { if (size == 0) { -1 } else { @@ -717,6 +719,51 @@ class SparseVector( new SparseVector(size, ii, vv) } } + + override def argmax: Int = { + if (size == 0) { + -1 + } else { + // Find the max active entry. + var maxIdx = indices(0) + var maxValue = values(0) + var maxJ = 0 + var j = 1 + val na = numActives + while (j < na) { + val v = values(j) + if (v > maxValue) { + maxValue = v + maxIdx = indices(j) + maxJ = j + } + j += 1 + } + + // If the max active entry is nonpositive and there exists inactive ones, find the first zero. + if (maxValue <= 0.0 && na < size) { + if (maxValue == 0.0) { + // If there exists an inactive entry before maxIdx, find it and return its index. + if (maxJ < maxIdx) { + var k = 0 + while (k < maxJ && indices(k) == k) { + k += 1 + } + maxIdx = k + } + } else { + // If the max active value is negative, find and return the first inactive index. + var k = 0 + while (k < na && indices(k) == k) { + k += 1 + } + maxIdx = k + } + } + + maxIdx + } + } } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 178d95a7b94ec..03be4119bdaca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(vec.toArray.eq(arr)) } + test("dense argmax") { + val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector] + assert(vec.argmax === -1) + + val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector] + assert(vec2.argmax === 3) + + val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector] + assert(vec3.argmax === 3) + } + test("sparse to array") { val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] assert(vec.toArray === arr) } + test("sparse argmax") { + val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec.argmax === -1) + + val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] + assert(vec2.argmax === 3) + + val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7)) + assert(vec3.argmax === 2) + + // check for case that sparse vector is created with + // only negative values {0.0, 0.0,-1.0, -0.7, 0.0} + val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7)) + assert(vec4.argmax === 0) + + val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0)) + assert(vec5.argmax === 1) + + val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0)) + assert(vec6.argmax === 2) + + val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7)) + assert(vec7.argmax === 1) + + val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) + assert(vec8.argmax === 0) + } + test("vector equals") { val dv1 = Vectors.dense(arr.clone()) val dv2 = Vectors.dense(arr.clone()) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 36417f5df9f2d..dd852547492aa 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -98,6 +98,10 @@ object MimaExcludes { "org.apache.spark.api.r.StringRRDD.this"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.r.BaseRRDD.this") + ) ++ Seq( + // SPARK-7422 add argmax for sparse vectors + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.argmax") ) case v if v.startsWith("1.4") =>