diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 97568c9c659e7..f5e325330701e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -134,7 +134,7 @@ object LinearDataGenerator { val rnd = new Random(seed) val rndG = new Random(seed) - if (sparsity <= 0.0) { + if (sparsity == 0.0) { (0 until nPoints).map { _ => val features = Vectors.dense((0 until weights.length).map { i => (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) @@ -148,7 +148,7 @@ object LinearDataGenerator { val sparseRnd = new Random(seed) (0 until nPoints).map { _ => val (values, indices) = (0 until weights.length).filter { _ => - sparseRnd.nextDouble() >= sparsity }.map { i => + sparseRnd.nextDouble() <= sparsity }.map { i => ((rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i), i) }.unzip val features = Vectors.sparse(weights.length, indices.toArray, values.toArray)