From 2e00cba1ef52eed77d6df1b2acfafb741ac427ef Mon Sep 17 00:00:00 2001
From: Feynman Liang
Date: Mon, 13 Jul 2015 14:50:52 -0700
Subject: [PATCH] Depth first projections

---
 .../spark/mllib/fpm/LocalPrefixSpan.scala | 79 ++++++++-----
 .../apache/spark/mllib/fpm/PrefixSpan.scala | 2 +-
 2 files changed, 35 insertions(+), 46 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 892bfa61403e0..53760645a53af 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -20,8 +20,6 @@ package org.apache.spark.mllib.fpm
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
  *
  * :: Experimental ::
@@ -36,7 +34,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    * @param minCount minimum count
    * @param maxPatternLength maximum pattern length
    * @param prefix prefix
-   * @param projectedDatabase the projected dabase
+   * @param database the projected database
    * @return a set of sequential pattern pairs,
    *         the key of pair is sequential pattern (a list of items),
    *         the value of pair is the pattern's count.
@@ -44,31 +42,36 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
   def run(
       minCount: Long,
       maxPatternLength: Int,
-      prefix: ArrayBuffer[Int],
-      projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = {
-    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
-    val frequentPatternAndCounts = frequentPrefixAndCounts
-      .map(x => ((prefix :+ x._1).toArray, x._2))
-    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
-      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
+      prefix: List[Int],
+      database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = {
+
+    if (database.isEmpty) return Iterator.empty
+
+    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database).toArray
+    val frequentItems = frequentItemAndCounts.map(_._1)
+    val frequentPatternAndCounts = frequentItemAndCounts.iterator
+      .map { case (item, count) => ((item :: prefix).reverse.toArray, count) }
 
-    if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) {
-      frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap {
-        case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB)
+    val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
+
+    if (prefix.length + 1 < maxPatternLength) {
+      frequentPatternAndCounts ++ frequentItems.flatMap { item =>
+        val nextProjected = project(filteredProjectedDatabase, item)
+        run(minCount, maxPatternLength, item :: prefix, nextProjected)
       }
     } else {
-      frequentPatternAndCounts.iterator
+      frequentPatternAndCounts
     }
   }
 
   /**
-   * calculate suffix sequence following a prefix in a sequence
-   * @param prefix prefix
-   * @param sequence sequence
+   * Calculate suffix sequence immediately after the first occurrence of an item.
+   * @param item item to get suffix after
+   * @param sequence sequence to extract suffix from
    * @return suffix sequence
    */
-  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(prefix)
+  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
+    val index = sequence.indexOf(item)
     if (index == -1) {
       Array()
     } else {
@@ -76,40 +79,26 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
+    database
+      .map(candidateSeq => getSuffix(prefix, candidateSeq))
+      .filter(_.nonEmpty)
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences sequences data
-   * @return array of item and count pair
+   * @param minCount the minimum count for an item to be frequent
+   * @param database database of sequences
+   * @return item and count pairs
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
-    sequences.flatMap(_.distinct)
+      database: Iterable[Array[Int]]): Iterator[(Int, Long)] = {
+    database.flatMap(_.distinct)
       .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
         ctr + (item -> (ctr(item) + 1))
       }
       .filter(_._2 >= minCount)
-      .toArray
-  }
-
-  /**
-   * Get the frequent prefixes' projected database.
-   * @param prefix the frequent prefixes' prefix
-   * @param frequentPrefixes frequent next prefixes
-   * @param projDB projected database for given prefix
-   * @return extensions of prefix by one item and corresponding projected databases
-   */
-  private def getPatternAndProjectedDatabase(
-      prefix: ArrayBuffer[Int],
-      frequentPrefixes: Array[Int],
-      projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = {
-    val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_)))
-    frequentPrefixes.map { nextItem =>
-      val nextProjDB = filteredProjectedDatabase
-        .map(candidateSeq => getSuffix(nextItem, candidateSeq))
-        .filter(_.nonEmpty)
-      (prefix :+ nextItem, nextProjDB)
-    }.filter(x => x._2.nonEmpty)
+      .iterator
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 0bccb37fa9cd5..73ba3bb63dfcb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -153,7 +153,7 @@ class PrefixSpan private (
       minCount: Long,
       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB)
+      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
     }
   }
 }
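
For readers skimming the patch, the standalone sketch below illustrates the depth-first projection scheme that the new LocalPrefixSpan.run follows: count the frequent items of the current (projected) database, emit each one-item extension of the prefix, then project the database past that item and recurse. It is only an illustration under simplified assumptions; the object name, the freqItemCounts helper, and the toy dataset in main are invented for the example and are not part of the Spark sources.

// Standalone sketch of depth-first prefix projection over integer sequences.
// Names and the toy dataset are illustrative; this is not the Spark code itself.
object DepthFirstProjectionSketch {

  // Suffix of `sequence` strictly after the first occurrence of `item`; empty if absent.
  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(item)
    if (index == -1) Array.empty[Int] else sequence.drop(index + 1)
  }

  // Project every sequence onto `prefix`, keeping only non-empty suffixes.
  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] =
    database.map(getSuffix(prefix, _)).filter(_.nonEmpty)

  // Items appearing in at least `minCount` sequences, with their support counts.
  def freqItemCounts(minCount: Long, database: Iterable[Array[Int]]): Map[Int, Long] =
    database.flatMap(_.distinct)
      .groupBy(identity)
      .map { case (item, occurrences) => item -> occurrences.size.toLong }
      .filter { case (_, count) => count >= minCount }

  // Depth-first enumeration of frequent patterns extending `prefix` (kept reversed).
  def run(
      minCount: Long,
      maxPatternLength: Int,
      prefix: List[Int],
      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
    if (database.isEmpty) return Iterator.empty
    // Materialized map, reused both for filtering and for driving the recursion.
    val counts = freqItemCounts(minCount, database)
    val patterns = counts.iterator.map { case (item, c) => ((item :: prefix).reverse, c) }
    if (prefix.length + 1 < maxPatternLength) {
      // Drop infrequent items once, then project and recurse one item deeper.
      val filtered = database.map(_.filter(counts.contains))
      patterns ++ counts.keysIterator.flatMap { item =>
        run(minCount, maxPatternLength, item :: prefix, project(filtered, item))
      }
    } else {
      patterns
    }
  }

  def main(args: Array[String]): Unit = {
    val db = Seq(Array(1, 2, 3), Array(1, 3, 2), Array(1, 2, 2, 3))
    run(minCount = 2L, maxPatternLength = 3, prefix = Nil, database = db)
      .foreach { case (pattern, count) => println(s"${pattern.mkString("<", " ", ">")} : $count") }
  }
}

Running main prints each frequent sequential pattern together with its support count. The counts map is deliberately materialized up front so it can be traversed once to filter infrequent items out of the projected database and again to drive the recursion, mirroring the multi-pass use of the frequent-item counts in the patched run method.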