Skip to content

Commit

Permalink
[SPARK-8997] [MLLIB] Performance improvements in LocalPrefixSpan
Browse files Browse the repository at this point in the history
Improves the performance of LocalPrefixSpan by implementing optimizations proposed in [SPARK-8997](https://issues.apache.org/jira/browse/SPARK-8997)

Author: Feynman Liang <[email protected]>
Author: Feynman Liang <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#7360 from feynmanliang/SPARK-8997-improve-prefixspan and squashes the following commits:

59db2f5 [Feynman Liang] Merge pull request #1 from mengxr/SPARK-8997
91e4357 [Xiangrui Meng] update LocalPrefixSpan impl
9212256 [Feynman Liang] MengXR code review comments
f055d82 [Feynman Liang] Fix failing scalatest
2e00cba [Feynman Liang] Depth first projections
70b93e3 [Feynman Liang] Performance improvements in LocalPrefixSpan, fix tests
  • Loading branch information
Feynman Liang authored and mengxr committed Jul 15, 2015
1 parent f0e1297 commit 1bb8acc
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,78 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental

/**
*
* :: Experimental ::
*
* Calculate all patterns of a projected database in local.
*/
@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {

/**
* Calculate all patterns of a projected database.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
* @param prefix prefix
* @param projectedDatabase the projected dabase
* @param prefixes prefixes in reversed order
* @param database the projected database
* @return a set of sequential pattern pairs,
* the key of pair is sequential pattern (a list of items),
* the key of pair is sequential pattern (a list of items in reversed order),
* the value of pair is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
val frequentPatternAndCounts = frequentPrefixAndCounts
.map(x => (prefix ++ Array(x._1), x._2))
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => run(minCount, maxPatternLength, x._1, x._2))
.reduce(_ ++ _)
frequentPatternAndCounts ++ nextPatterns
} else {
frequentPatternAndCounts
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
val newPrefixes = item :: prefixes
val newProjected = project(filteredDatabase, item)
Iterator.single((newPrefixes, count)) ++
run(minCount, maxPatternLength, newPrefixes, newProjected)
}
}

/**
* calculate suffix sequence following a prefix in a sequence
* @param prefix prefix
* @param sequence sequence
* Calculate suffix sequence immediately after the first occurrence of an item.
* @param item item to get suffix after
* @param sequence sequence to extract suffix from
* @return suffix sequence
*/
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(prefix)
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(item)
if (index == -1) {
Array()
} else {
sequence.drop(index + 1)
}
}

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* @param minCount the absolute minimum count
* @param sequences sequences data
* @return array of item and count pair
* @param minCount the minimum count for an item to be frequent
* @param database database of sequences
* @return freq item to count map
*/
private def getFreqItemAndCounts(
minCount: Long,
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
sequences.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length.toLong)
.filter(_._2 >= minCount)
.toArray
}

/**
* Get the frequent prefixes' projected database.
* @param prePrefix the frequent prefixes' prefix
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
prePrefix: Array[Int],
frequentPrefixes: Array[Int],
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
val filteredProjectedDatabase = sequences
.map(x => x.filter(frequentPrefixes.contains(_)))
frequentPrefixes.map { x =>
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
(prePrefix ++ Array(x), sub)
}.filter(x => x._2.nonEmpty)
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
sequence.distinct.foreach { item =>
counts(item) += 1L
}
}
counts.filter(_._2 >= minCount)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ class PrefixSpan private (
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { x =>
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
data.flatMap { case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
.map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD

class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

test("PrefixSpan using Integer type") {

Expand Down Expand Up @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
}
val sortedActualValue = actualValue.sortWith{ (x, y) =>
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
}
sortedExpectedValue.zip(sortedActualValue)
.map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
.reduce(_&&_)
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
Expand Down

0 comments on commit 1bb8acc

Please sign in to comment.