diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 5b8da9665366b..899078a759f31 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,6 +49,7 @@ class PrefixSpan private ( * The maximum number of items allowed in a projected database before local processing. If a * projected database exceeds this size, another iteration of distributed PrefixSpan is run. */ + // TODO: make configurable with a better default value; 10000 may be too small private val maxLocalProjDBSize: Long = 10000 /** @@ -61,7 +62,7 @@ class PrefixSpan private ( * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered * frequent). */ - def getMinSupport(): Double = this.minSupport + def getMinSupport: Double = this.minSupport /** * Sets the minimal support level (default: `0.1`). @@ -75,7 +76,7 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ - def getMaxPatternLength(): Double = this.maxPatternLength + def getMaxPatternLength: Double = this.maxPatternLength /** * Sets maximal pattern length (default: `10`). @@ -96,6 +97,8 @@ class PrefixSpan private ( * the value of pair is the pattern's count. 
*/ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + val sc = sequences.sparkContext + if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -108,10 +111,11 @@ class PrefixSpan private ( .flatMap(seq => seq.distinct.map(item => (item, 1L))) .reduceByKey(_ + _) .filter(_._2 >= minCount) + .collect() // Pairs of (length 1 prefix, suffix consisting of frequent items) val itemSuffixPairs = { - val freqItems = freqItemCounts.keys.collect().toSet + val freqItems = freqItemCounts.map(_._1).toSet sequences.flatMap { seq => val filteredSeq = seq.filter(freqItems.contains(_)) freqItems.flatMap { item => @@ -141,13 +145,14 @@ class PrefixSpan private ( pairsForDistributed = largerPairsPart pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) pairsForLocal ++= smallerPairsPart - resultsAccumulator ++= nextPatternAndCounts + resultsAccumulator ++= nextPatternAndCounts.collect() } // Process the small projected databases locally - resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey()) + val remainingResults = getPatternsInLocal(minCount, pairsForLocal.groupByKey()) - resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) } + (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) + .map { case (pattern, count) => (pattern.toArray, count) } }