diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index bbdc75532ae6f..8a15a867910a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -103,45 +103,49 @@ class PrefixSpan private (
     // Convert min support to a min number of transactions for this dataset
     val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
 
-    val itemCounts = sequences
+    // Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
       .flatMap(seq => seq.distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-    var allPatternAndCounts = itemCounts.map(x => (List(x._1), x._2))
-    val prefixSuffixPairs = {
-      val frequentItems = itemCounts.map(_._1).collect()
-      val candidates = sequences.map { p =>
-        p.filter (frequentItems.contains(_) )
-      }
-      candidates.flatMap { x =>
-        frequentItems.map { y =>
-          val sub = LocalPrefixSpan.getSuffix(y, x)
-          (List(y), sub)
-        }.filter(_._2.nonEmpty)
+
+    // Pairs of (length-1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.keys.collect().toSet
+      sequences.flatMap { seq =>
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, seq.filter(freqItems.contains(_)))
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
       }
     }
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = partitionByProjDBSize(prefixSuffixPairs)
-    while (largePrefixSuffixPairs.count() != 0) {
+
+    // Accumulator for the computed results to be returned
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be processed locally and distributively, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+    // projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
-        getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
-      largePrefixSuffixPairs.unpersist()
+        getPatternCountsAndPrefixSuffixPairs(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
       val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
-      largePrefixSuffixPairs = largerPairsPart
-      largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-      smallPrefixSuffixPairs ++= smallerPairsPart
-      allPatternAndCounts ++= nextPatternAndCounts
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts
     }
-    if (smallPrefixSuffixPairs.count() > 0) {
-      val projectedDatabase = smallPrefixSuffixPairs
-        // TODO aggregateByKey
-        .groupByKey()
-      val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
-      allPatternAndCounts ++= nextPatternAndCounts
-    }
-    allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
+
+    // Process the small projected databases locally
+    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+
+    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
   }
 
@@ -177,8 +181,8 @@ class PrefixSpan private (
    */
   private def getPatternCountsAndPrefixSuffixPairs(
       minCount: Long,
-      prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
-    (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
     val prefixAndFrequentItemAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
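
For readers following the refactor, here is a minimal usage sketch of the run() method modified above. The builder setters (setMinSupport, setMaxPatternLength) and the RDD[Array[Int]] input type are assumptions inferred from MLlib's conventions and from the (pattern.toArray, count) return shape in the first hunk; they are not shown in this diff.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Hypothetical driver; assumes a live SparkContext `sc` and integer item IDs.
def prefixSpanExample(sc: SparkContext): Unit = {
  // Each input sequence is an array of item IDs.
  val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
    Array(1, 3, 2, 1),
    Array(1, 2, 3, 1),
    Array(3, 1, 2)), numSlices = 2)

  val prefixSpan = new PrefixSpan()
    .setMinSupport(0.5) // a pattern must occur in at least half the sequences
    .setMaxPatternLength(5)

  // run() yields (pattern, support count) pairs, per the hunk's last line.
  prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
    println(pattern.mkString("[", ",", "]") + ": " + count)
  }
}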
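
The helper partitionByProjDBSize is called but not defined in these hunks. From the loop comment (prefixes whose projected database sizes exceed `maxLocalProjDBSize` stay distributed), its behavior can be sketched as below; the aggregateByKey formulation and exact types are assumptions, not the file's actual implementation.

// Plausible shape of the helper, intended to live inside PrefixSpan where
// `maxLocalProjDBSize` (the size threshold) is in scope.
private def partitionByProjDBSize(data: RDD[(List[Int], Array[Int])])
  : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
  // A prefix's projected database size = total length of its suffixes.
  val smallPrefixes = data
    .aggregateByKey(0L)(
      seqOp = (size, suffix) => size + suffix.length,
      combOp = _ + _)
    .filter { case (_, size) => size <= maxLocalProjDBSize }
    .keys
    .collect()
    .toSet
  // First element: pairs small enough to mine locally; second: the rest.
  (data.filter { case (prefix, _) => smallPrefixes.contains(prefix) },
    data.filter { case (prefix, _) => !smallPrefixes.contains(prefix) })
}

Collecting only the small-prefix keys to the driver keeps the broadcast set bounded by the number of frequent prefixes rather than by the data size, which is what makes the local/distributed split in run() cheap to compute each iteration.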