From 70b93e32d2c1e4b7d09197a136404db661a0f95a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sun, 12 Jul 2015 16:48:27 -0700 Subject: [PATCH] Performance improvements in LocalPrefixSpan, fix tests --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 48 ++++++++++--------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 6 ++- .../spark/mllib/fpm/PrefixSpanSuite.scala | 14 ++---- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 39c48b084e550..892bfa61403e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -20,6 +20,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.Logging import org.apache.spark.annotation.Experimental +import scala.collection.mutable.ArrayBuffer + /** * * :: Experimental :: @@ -42,22 +44,20 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def run( minCount: Long, maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { + prefix: ArrayBuffer[Int], + projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = { val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) + .map(x => ((prefix :+ x._1).toArray, x._2)) val prefixProjectedDatabases = getPatternAndProjectedDatabase( prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => run(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns + if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) { + frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap { + case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB) + } } else { - frequentPatternAndCounts + frequentPatternAndCounts.iterator } } @@ -86,28 +86,30 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, sequences: Array[Array[Int]]): Array[(Int, Long)] = { sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) + .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => + ctr + (item -> (ctr(item) + 1)) + } .filter(_._2 >= minCount) .toArray } /** * Get the frequent prefixes' projected database. - * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database + * @param prefix the frequent prefixes' prefix + * @param frequentPrefixes frequent next prefixes + * @param projDB projected database for given prefix + * @return extensions of prefix by one item and corresponding projected databases */ private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], + prefix: ArrayBuffer[Int], frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) + projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = { + val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_))) + frequentPrefixes.map { nextItem => + val nextProjDB = filteredProjectedDatabase + .map(candidateSeq => getSuffix(nextItem, candidateSeq)) + .filter(_.nonEmpty) + (prefix :+ nextItem, nextProjDB) }.filter(x => x._2.nonEmpty) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9d8c60ef0fc45..0bccb37fa9cd5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -22,6 +22,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import scala.collection.mutable.ArrayBuffer + /** * * :: Experimental :: @@ -150,8 +152,8 @@ class PrefixSpan private ( private def getPatternsInLocal( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { x => - LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) + data.flatMap { case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 413436d3db85f..87b87569e2ec9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { test("PrefixSpan using Integer type") { @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - val sortedExpectedValue = expectedValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - val sortedActualValue = actualValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - sortedExpectedValue.zip(sortedActualValue) - .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) - .reduce(_&&_) + expectedValue.map(x => (x._1.toList, x._2)).toSet == + actualValue.map(x => (x._1.toList, x._2)).toSet } val prefixspan = new PrefixSpan()