diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 307034f7cd607..7ead6327486cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -30,34 +30,25 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param database the projected dabase + * @param prefixes prefixes in reversed order + * @param database the projected database * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items), + * the key of pair is sequential pattern (a list of items in reversed order), * the value of pair is the pattern's count. */ def run( minCount: Long, maxPatternLength: Int, - prefix: List[Int], + prefixes: List[Int], database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { - - if (database.isEmpty) return Iterator.empty - + if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) - val frequentItems = frequentItemAndCounts.map(_._1).toSet - val frequentPatternAndCounts = frequentItemAndCounts - .map { case (item, count) => ((item :: prefix), count) } - - - if (prefix.length + 1 < maxPatternLength) { - val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) - frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => - val nextProjected = project(filteredProjectedDatabase, item) - run(minCount, maxPatternLength, item :: prefix, nextProjected) - } - } else { - frequentPatternAndCounts.iterator + val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) + frequentItemAndCounts.iterator.flatMap { case (item, count) => + val newPrefixes = item :: prefixes + val newProjected = project(filteredDatabase, item) + Iterator.single((newPrefixes, count)) ++ + run(minCount, maxPatternLength, newPrefixes, newProjected) } } @@ -78,7 +69,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { database - .map(candidateSeq => getSuffix(prefix, candidateSeq)) + .map(getSuffix(prefix, _)) .filter(_.nonEmpty) } @@ -86,16 +77,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { * Generates frequent items by filtering the input data using minimal count level. * @param minCount the minimum count for an item to be frequent * @param database database of sequences - * @return item and count pairs + * @return freq item to count map */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): Iterable[(Int, Long)] = { - database.flatMap(_.distinct) - .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => - ctr(item) += 1 - ctr + database: Array[Array[Int]]): mutable.Map[Int, Long] = { + // TODO: use PrimitiveKeyOpenHashMap + val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + database.foreach { sequence => + sequence.distinct.foreach { item => + counts(item) += 1L } - .filter(_._2 >= minCount) + } + counts.filter(_._2 >= minCount) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 87b87569e2ec9..9f107c89f6d80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -47,8 +47,8 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toList, x._2)).toSet == - actualValue.map(x => (x._1.toList, x._2)).toSet + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } val prefixspan = new PrefixSpan()