Commit 4dd1c8a
initialize file before rebase.
zhangjiajin committed Jul 15, 2015
1 parent 078d410 commit 4dd1c8a
Showing 1 changed file with 10 additions and 65 deletions.
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala (75 changes: 10 additions & 65 deletions)
@@ -43,8 +43,6 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
-  private val minPatternsBeforeShuffle: Int = 20
-
   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -88,69 +86,16 @@ class PrefixSpan private (
       getFreqItemAndCounts(minCount, sequences).collect()
     val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
       lengthOnePatternsAndCounts.map(_._1), sequences)
-
-    var patternsCount = lengthOnePatternsAndCounts.length
-    var allPatternAndCounts = sequences.sparkContext.parallelize(
-      lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    var currentProjectedDatabase = prefixAndProjectedDatabase
-    while (patternsCount <= minPatternsBeforeShuffle &&
-      currentProjectedDatabase.count() != 0) {
-      val (nextPatternAndCounts, nextProjectedDatabase) =
-        getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
-      patternsCount = nextPatternAndCounts.count().toInt
-      currentProjectedDatabase = nextProjectedDatabase
-      allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
-    }
-    if (patternsCount > 0) {
-      val groupedProjectedDatabase = currentProjectedDatabase
-        .map(x => (x._1.toSeq, x._2))
-        .groupByKey()
-        .map(x => (x._1.toArray, x._2.toArray))
-      val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
-      allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
-    }
-    allPatternAndCounts
-  }
-
-  /**
-   * Get the patterns with counts, and the projected database.
-   * @param minCount minimum count
-   * @param prefixAndProjectedDatabase prefix and projected database
-   * @return patterns with counts, and the projected database:
-   *         (RDD[(pattern, count)], RDD[(prefix, projected database)])
-   */
-  private def getPatternCountsAndProjectedDatabase(
-      minCount: Long,
-      prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
-    (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
-    val prefixAndFrequentItemAndCounts = prefixAndProjectedDatabase.flatMap { x =>
-      x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
-    }.reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
-    val patternAndCounts = prefixAndFrequentItemAndCounts
-      .map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
-    val prefixLength = prefixAndProjectedDatabase.take(1)(0)._1.length
-    if (prefixLength + 1 >= maxPatternLength) {
-      (patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
-    } else {
-      val frequentItemsMap = prefixAndFrequentItemAndCounts
-        .keys.map(x => (x._1, x._2))
-        .groupByKey()
-        .mapValues(_.toSet)
-        .collect()
-        .toMap
-      val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
-        .filter(x => frequentItemsMap.contains(x._1))
-        .flatMap { x =>
-          val frequentItemSet = frequentItemsMap(x._1)
-          val filteredSequence = x._2.filter(frequentItemSet.contains(_))
-          val subProjectedDatabase = frequentItemSet.map { y =>
-            (y, LocalPrefixSpan.getSuffix(y, filteredSequence))
-          }.filter(_._2.nonEmpty)
-          subProjectedDatabase.map(y => (x._1 ++ Array(y._1), y._2))
-        }
-      (patternAndCounts, nextPrefixAndProjectedDatabase)
-    }
+    val groupedProjectedDatabase = prefixAndProjectedDatabase
+      .map(x => (x._1.toSeq, x._2))
+      .groupByKey()
+      .map(x => (x._1.toArray, x._2.toArray))
+    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
+    val lengthOnePatternsAndCountsRdd =
+      sequences.sparkContext.parallelize(
+        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
+    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
+    allPatterns
   }
 
   /**
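The added lines replace the iterative, shuffle-bounded pattern growth with a single step: every projected postfix is grouped by its prefix, and each group is mined locally by getPatternsInLocal. The grouping relies on prefix projection via LocalPrefixSpan.getSuffix. As a rough plain-Scala sketch of that projection idea (the object and helper below are illustrative stand-ins, not the actual MLlib internals):

object ProjectionSketch {
  // Project a sequence onto `item`: keep only the postfix after its
  // first occurrence, or an empty array if the item never appears.
  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    val i = sequence.indexOf(item)
    if (i == -1) Array.empty[Int] else sequence.drop(i + 1)
  }

  def main(args: Array[String]): Unit = {
    val seq = Array(1, 3, 2, 5)
    // Projecting on 3 leaves the postfix (2, 5): extensions of any
    // pattern ending in 3 can only be drawn from this postfix.
    println(getSuffix(3, seq).mkString(","))  // prints "2,5"
  }
}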

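For orientation, run at this commit takes an RDD[Array[Int]] of sequences and returns (pattern, count) pairs. A minimal usage sketch, assuming the default constructor and setters implied by the scaladoc above ({minSupport: `0.1`, maxPatternLength: `10`}); the toy data and local master are illustrative, not part of the commit:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan

object PrefixSpanExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("PrefixSpanExample").setMaster("local[2]"))

    // Toy dataset: each sequence is a flat array of item ids, matching
    // the RDD[Array[Int]] input this version of run() expects.
    val sequences = sc.parallelize(Seq(
      Array(1, 3, 2, 5),
      Array(1, 2, 3),
      Array(1, 2, 5),
      Array(6, 3, 2)), 2)

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)      // keep patterns appearing in >= 50% of sequences
      .setMaxPatternLength(5)  // cap the length of mined patterns

    // Collect and print each frequent pattern with its support count.
    prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
      println(pattern.mkString("[", ",", "]") + ", " + count)
    }
    sc.stop()
  }
}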