add generic support in FPGrowth
jackylk committed Feb 3, 2015
1 parent bebf4c4 commit 7783351
Showing 2 changed files with 82 additions and 23 deletions.
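In summary, this commit generalizes FP-Growth from String-only transactions to an arbitrary item type: FPGrowthModel and FPGrowth.run gain a type parameter Item (with a ClassTag bound), run now accepts baskets of any Iterable type, and Java callers get a JavaRDD-based run overload plus a javaFreqItemsets() accessor on the model.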
46 changes: 29 additions & 17 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,15 +17,23 @@
 
 package org.apache.spark.mllib.fpm
 
+import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
+    freqItemsets.toJavaRDD()
+  }
+}
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,32 +77,36 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
     val count = data.count()
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems(data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    this.run(data.rdd.map(_.asScala))
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
    * @param partitioner partitioner used to distribute items
    * @return array of frequent pattern ordered by their frequencies
    */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
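The two run overloads above are the public entry points: the first takes a Scala RDD of Iterable baskets, the second adapts a JavaRDD of java.lang.Iterable baskets by converting each one with asScala. A minimal sketch of the Scala-side call, assuming an existing SparkContext named sc (the data and support threshold are illustrative, mirroring the test suite below):

    import org.apache.spark.mllib.fpm.FPGrowth

    // Baskets may now be any Iterable of any item type, not just Array[String].
    val baskets = sc.parallelize(Seq(
      Seq("a", "b", "c"),
      Seq("a", "b"),
      Seq("b", "c"))).cache()  // cache the input to avoid the "not cached" warning

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run[String, Seq[String]](baskets)  // explicit type arguments, as in the tests

    model.freqItemsets.collect().foreach { case (items, count) =>
      println(items.mkString("{", ", ", "}") + " -> " + count)
    }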
@@ -114,11 +126,11 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
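genFreqItemsets re-encodes every transaction as integer ranks into the freqItems array before growing the FP-trees, via the itemToRank map built above. A small worked sketch of that encoding, using hypothetical items:

    // Frequent items, assumed already ordered by descending frequency.
    val freqItems = Array("z", "x", "y")
    val itemToRank = freqItems.zipWithIndex.toMap  // Map(z -> 0, x -> 1, y -> 2)

    // A basket keeps only its frequent items, expressed as sorted ranks;
    // the infrequent item "q" simply disappears.
    val ranks = Seq("y", "q", "z").flatMap(itemToRank.get).sorted  // List(0, 2)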
@@ -139,13 +151,13 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
+      transaction: Basket,
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get)
+    val filtered = transaction.flatMap(itemToRank.get).toArray
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
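For Java interoperability, the JavaRDD overload of run accepts any java.lang.Iterable basket, and javaFreqItemsets() re-wraps the result as a JavaRDD. A sketch exercising that path from Scala, continuing the earlier sketch's assumptions (java.util.List is a java.lang.Iterable, so it qualifies as a Basket):

    import java.util.{Arrays => JArrays, List => JList}
    import org.apache.spark.api.java.JavaRDD

    val javaBaskets: JavaRDD[JList[String]] = sc.parallelize(Seq(
      JArrays.asList("a", "b", "c"),
      JArrays.asList("a", "b"))).toJavaRDD()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .run[String, JList[String]](javaBaskets)

    // javaFreqItemsets() exposes the itemsets as a JavaRDD for Java callers.
    val javaItemsets: JavaRDD[(Array[String], Long)] = model.javaFreqItemsets()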
59 changes: 53 additions & 6 deletions mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,29 +22,30 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("FP-Growth") {
+
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
       "s x o n r",
       "x z y m t s q e",
       "z",
       "x z y r q t p")
-      .map(_.split(" "))
+      .map(_.split(" ").toSeq)
     val rdd = sc.parallelize(transactions, 2).cache()
 
     val fpg = new FPGrowth()
 
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -61,13 +62,59 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model2.freqItemsets.count() === 54)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model1.freqItemsets.count() === 625)
   }
+
+  test("FP-Growth using Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toList)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val fpg = new FPGrowth()
+
+    val model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run[Int, List[Int]](rdd)
+    assert(model6.freqItemsets.count() === 0)
+
+    val model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run[Int, List[Int]](rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+      (items.toSet, count)
+    }
+    val expected = Set(
+      (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+      (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+      (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+    assert(freqItemsets3.toSet === expected)
+
+    val model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run[Int, List[Int]](rdd)
+    assert(model2.freqItemsets.count() === 15)
+
+    val model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run[Int, List[Int]](rdd)
+    assert(model1.freqItemsets.count() === 65)
+  }
 }
