Skip to content

Commit

Permalink
Added comment before we start arg max calculation. Updated unit tests…
Browse files Browse the repository at this point in the history
… to cover corner cases
  • Loading branch information
GeorgeDittmar committed May 29, 2015
1 parent f21dcce commit b1f059f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,9 @@ class SparseVector(
-1
} else {

var maxIdx = 0
var maxValue = if(indices(0) != 0) 0.0 else values(0)
//grab first active index and value by default
var maxIdx = indices(0)
var maxValue = values(0)

foreachActive { (i, v) =>
if (v > maxValue) {
Expand All @@ -736,8 +737,8 @@ class SparseVector(
}

// look for inactive values incase all active node values are negative
if(size != values.size && maxValue < 0){
maxIdx = calcInactiveIdx(indices(0))
if(size != values.size && maxValue <= 0){
maxIdx = calcInactiveIdx(0)
maxValue = 0
}
maxIdx
Expand All @@ -748,20 +749,19 @@ class SparseVector(
* Calculates the first instance of an inactive node in a sparse vector and returns the Idx
* of the element.
* @param idx starting index of computation
* @return index of first inactive node or -1 if it cannot find one
* @return index of first inactive node
*/
private[SparseVector] def calcInactiveIdx(idx: Int): Int ={
if(idx < size){
if(!indices.contains(idx)){
private[SparseVector] def calcInactiveIdx(idx: Int): Int = {
if (idx < size) {
if (!indices.contains(idx)) {
idx
}else{
calcInactiveIdx(idx+1)
} else {
calcInactiveIdx(idx + 1)
}
}else{
} else {
-1
}
}

}

object SparseVector {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class VectorsSuite extends FunSuite {
val max4 = vec5.argmax
assert(max4 === 1)

val vec6 = Vectors.sparse(5,Array(0, 1, 2),Array(-1.0, -.025, -.7))
val vec6 = Vectors.sparse(2,Array(0, 1),Array(-1.0, 0.0))
val max5 = vec6.argmax
assert(max5 === 3)
assert(max5 === 1)
}

test("vector equals") {
Expand Down

0 comments on commit b1f059f

Please sign in to comment.