diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index dc555001b7778..39c48b084e550 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -30,13 +30,13 @@ import org.apache.spark.annotation.Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** - * Calculate all patterns of a projected database in local. + * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length * @param prefix prefix * @param projectedDatabase the projected dabase * @return a set of sequential pattern pairs, - * the key of pair is pattern (a list of elements), + * the key of pair is sequential pattern (a list of items), * the value of pair is the pattern's count. */ def run( @@ -44,7 +44,21 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { maxPatternLength: Int, prefix: Array[Int], projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - getPatternsWithPrefix(minCount, maxPatternLength, prefix, projectedDatabase) + val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) + val frequentPatternAndCounts = frequentPrefixAndCounts + .map(x => (prefix ++ Array(x._1), x._2)) + val prefixProjectedDatabases = getPatternAndProjectedDatabase( + prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) + + val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength + if (continueProcess) { + val nextPatterns = prefixProjectedDatabases + .map(x => run(minCount, maxPatternLength, x._1, x._2)) + .reduce(_ ++ _) + frequentPatternAndCounts ++ nextPatterns + } else { + frequentPatternAndCounts + } } /** @@ -96,34 +110,4 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { (prePrefix ++ Array(x), sub) }.filter(x => x._2.nonEmpty) } - - /** - * Calculate all patterns of a projected database in local. - * @param minCount the minimum count - * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param projectedDatabase projected database - * @return patterns - */ - private def getPatternsWithPrefix( - minCount: Long, - maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => getPatternsWithPrefix(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns - } else { - frequentPatternAndCounts - } - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 2239aa529695c..9d8c60ef0fc45 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -82,10 +82,15 @@ class PrefixSpan private ( logWarning("Input data is not cached.") } val minCount = getMinCount(sequences) - val (lengthOnePatternsAndCounts, prefixAndCandidates) = - findLengthOnePatterns(minCount, sequences) - val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates) - val nextPatterns = getPatternsInLocal(minCount, projectedDatabase) + val lengthOnePatternsAndCounts = + getFreqItemAndCounts(minCount, sequences).collect() + val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( + lengthOnePatternsAndCounts.map(_._1), sequences) + val groupedProjectedDatabase = prefixAndProjectedDatabase + .map(x => (x._1.toSeq, x._2)) + .groupByKey() + .map(x => (x._1.toArray, x._2.toArray)) + val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) val lengthOnePatternsAndCountsRdd = sequences.sparkContext.parallelize( lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) @@ -122,7 +127,7 @@ class PrefixSpan private ( * @param sequences sequences data * @return prefixes and projected database */ - private def getPatternAndProjectedDatabase( + private def getPrefixAndProjectedDatabase( frequentPrefixes: Array[Int], sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { val filteredSequences = sequences.map { p => @@ -136,33 +141,6 @@ class PrefixSpan private ( } } - /** - * Find the patterns that it's length is one - * @param minCount the minimum count - * @param sequences original sequences data - * @return length-one patterns and projection table - */ - private def findLengthOnePatterns( - minCount: Long, - sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = { - val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences) - val prefixAndProjectedDatabase = getPatternAndProjectedDatabase( - frequentLengthOnePatternAndCounts.keys.collect(), sequences) - (frequentLengthOnePatternAndCounts.collect(), prefixAndProjectedDatabase) - } - - /** - * Constructs prefix-projected databases from (prefix, suffix) pairs. - * @param data patterns and projected sequences data before re-partition - * @return patterns and projected sequences data after re-partition - */ - private def makePrefixProjectedDatabases( - data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = { - data.map(x => (x._1.toSeq, x._2)) - .groupByKey() - .map(x => (x._1.toArray, x._2.toArray)) - } - /** * calculate the patterns in local. * @param minCount the absolute minimum count