Commit

Modified the code according to the review comments.
zhangjiajin committed Jul 15, 2015
1 parent 6560c69 commit baa2885
Showing 1 changed file with 37 additions and 40 deletions.
77 changes: 37 additions & 40 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -84,72 +84,69 @@ class PrefixSpan private (
       logWarning("Input data is not cached.")
     }
     val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
-    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
-
-    var patternsCount = lengthOnePatternsAndCounts.length
-    var allPatternAndCounts = sequences.sparkContext.parallelize(
-      lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    var currentProjectedDatabase = prefixAndProjectedDatabase
-    while (patternsCount <= minPatternsBeforeShuffle &&
-      currentProjectedDatabase.count() != 0) {
-      val (nextPatternAndCounts, nextProjectedDatabase) =
-        getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
+    val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
+    val prefixSuffixPairs = getPrefixSuffixPairs(
+      lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
+    var patternsCount: Long = lengthOnePatternsAndCounts.count()
+    var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))
+    var currentPrefixSuffixPairs = prefixSuffixPairs
+    while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
       patternsCount = nextPatternAndCounts.count().toInt
-      currentProjectedDatabase = nextProjectedDatabase
+      currentPrefixSuffixPairs = nextPrefixSuffixPairs
       allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
     }
     if (patternsCount > 0) {
-      val groupedProjectedDatabase = currentProjectedDatabase
+      val projectedDatabase = currentPrefixSuffixPairs
         .map(x => (x._1.toSeq, x._2))
         .groupByKey()
         .map(x => (x._1.toArray, x._2.toArray))
-      val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
+      val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
       allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
     }
     allPatternAndCounts
   }
 
   /**
-   * Get the pattern and counts, and projected database
+   * Get the pattern and counts, and prefix suffix pairs
    * @param minCount minimum count
-   * @param prefixAndProjectedDatabase prefix and projected database,
-   * @return pattern and counts, and projected database
-   *         (Array[pattern, count], RDD[prefix, projected database ])
+   * @param prefixSuffixPairs prefix and suffix pairs,
+   * @return pattern and counts, and prefix suffix pairs
+   *         (Array[pattern, count], RDD[prefix, suffix ])
    */
-  private def getPatternCountsAndProjectedDatabase(
+  private def getPatternCountsAndPrefixSuffixPairs(
       minCount: Long,
-      prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
+      prefixSuffixPairs: RDD[(Array[Int], Array[Int])]):
     (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
-    val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
-      x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
+    val prefixAndFreqentItemAndCounts = prefixSuffixPairs
+      .flatMap { case (prefix, suffix) =>
+        suffix.distinct.map(y => ((prefix.toSeq, y), 1L))
     }.reduceByKey(_ + _)
       .filter(_._2 >= minCount)
     val patternAndCounts = prefixAndFreqentItemAndCounts
-      .map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
-    val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length
+      .map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
+    val prefixlength = prefixSuffixPairs.first()._1.length
     if (prefixlength + 1 >= maxPatternLength) {
-      (patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
+      (patternAndCounts, prefixSuffixPairs.filter(x => false))
     } else {
       val frequentItemsMap = prefixAndFreqentItemAndCounts
-        .keys.map(x => (x._1, x._2))
+        .keys
         .groupByKey()
         .mapValues(_.toSet)
         .collect
         .toMap
-      val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
+      val nextPrefixSuffixPairs = prefixSuffixPairs
         .filter(x => frequentItemsMap.contains(x._1))
-        .flatMap { x =>
-          val frequentItemSet = frequentItemsMap(x._1)
-          val filteredSequence = x._2.filter(frequentItemSet.contains(_))
-          val subProjectedDabase = frequentItemSet.map{ y =>
-            (y, LocalPrefixSpan.getSuffix(y, filteredSequence))
+        .flatMap { case (prefix, suffix) =>
+          val frequentItemSet = frequentItemsMap(prefix)
+          val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
+          val nextSuffixes = frequentItemSet.map{ item =>
+            (item, LocalPrefixSpan.getSuffix(item, filteredSuffix))
           }.filter(_._2.nonEmpty)
-          subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2))
+          nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
         }
-      (patternAndCounts, nextPrefixAndProjectedDatabase)
+      (patternAndCounts, nextPrefixSuffixPairs)
     }
   }
 
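Note on the hunk above: getPatternCountsAndPrefixSuffixPairs grows every surviving prefix by one frequent item per pass. The following is a minimal, driver-local sketch of that expansion step using plain Scala collections instead of RDDs; the toy data and variable names are illustrative only, not part of the patch.

// One expansion pass over (prefix, suffix) pairs, local collections only.
val minCount = 2L
val prefixSuffixPairs = Seq(
  (Seq(1), Seq(2, 3)),
  (Seq(1), Seq(2)),
  (Seq(2), Seq(3)))

// Count, per prefix, how many suffixes contain each item (distinct per suffix),
// mirroring the flatMap / reduceByKey / filter chain in the patch.
val prefixItemCounts = prefixSuffixPairs
  .flatMap { case (prefix, suffix) => suffix.distinct.map(item => ((prefix, item), 1L)) }
  .groupBy(_._1)
  .map { case (key, occurrences) => (key, occurrences.map(_._2).sum) }
  .filter { case (_, count) => count >= minCount }

// Each surviving (prefix, item) pair becomes a pattern that is one item longer.
val patternAndCounts = prefixItemCounts.map { case ((prefix, item), count) => (prefix :+ item, count) }
// patternAndCounts == Map(Seq(1, 2) -> 2): item 2 follows prefix 1 in two of the three pairs.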
@@ -177,12 +174,12 @@ class PrefixSpan private (
   }
 
   /**
-   * Get the frequent prefixes' projected database.
+   * Get the frequent prefixes and suffix pairs.
    * @param frequentPrefixes frequent prefixes
    * @param sequences sequences data
-   * @return prefixes and projected database
+   * @return prefixes and suffix pairs.
    */
-  private def getPrefixAndProjectedDatabase(
+  private def getPrefixSuffixPairs(
       frequentPrefixes: Array[Int],
       sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
     val filteredSequences = sequences.map { p =>
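Note on the hunk above: getPrefixSuffixPairs pairs each frequent length-one prefix with the remainder of every sequence that contains it. A rough local sketch follows; suffixAfter is a hypothetical stand-in for LocalPrefixSpan.getSuffix, whose semantics are assumed here to be "everything after the first occurrence of the item".

// Hypothetical helper standing in for LocalPrefixSpan.getSuffix (assumed semantics).
def suffixAfter(item: Int, sequence: Array[Int]): Array[Int] = {
  val index = sequence.indexOf(item)
  if (index == -1) Array.empty[Int] else sequence.drop(index + 1)
}

val frequentPrefixes = Array(1, 2)
val sequences = Seq(Array(1, 2, 3), Array(3, 1, 2))

// Emit a (prefix, suffix) pair for every sequence in which the prefix has a non-empty suffix.
val prefixSuffixPairs = for {
  sequence <- sequences
  prefix <- frequentPrefixes
  suffix = suffixAfter(prefix, sequence)
  if suffix.nonEmpty
} yield (Array(prefix), suffix)
// Yields (Array(1), Array(2, 3)), (Array(2), Array(3)), (Array(1), Array(2)).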
@@ -199,7 +196,7 @@
   /**
    * calculate the patterns in local.
    * @param minCount the absolute minimum count
-   * @param data patterns and projected sequences data data
+   * @param data prefixes and projected sequences data data
    * @return patterns
    */
   private def getPatternsInLocal(
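For context, the private helpers touched in this diff back the public run(sequences) entry point. A hedged usage sketch, assuming this revision still exposes the no-arg constructor plus the setMinSupport and setMaxPatternLength setters; the SparkContext sc and the toy sequences are illustrative only.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Sequences are flat Array[Int] at this stage of the implementation.
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 2, 5),
  Array(1, 2, 3),
  Array(6, 1, 2)), 2).cache()

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)       // relative support threshold (assumed interpretation)
  .setMaxPatternLength(5)
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + ": " + count)
}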
