Merge pull request #2 from mengxr/SPARK-5520
update to make generic FPGrowth Java-friendly
jackylk committed Feb 3, 2015
2 parents 737d8bb + 63073d0 commit f5acf84
Showing 3 changed files with 86 additions and 87 deletions.
49 changes: 29 additions & 20 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,21 +17,30 @@

package org.apache.spark.mllib.fpm

import java.{util => ju}
import java.lang.{Iterable => JavaIterable}

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
freqItemsets.toJavaRDD()
/**
* Model trained by [[FPGrowth]], which holds frequent itemsets.
* @param freqItemsets frequent itemsets, stored as an RDD of (itemset, frequency) pairs
* @tparam Item item type
*/
class FPGrowthModel[Item: ClassTag](
val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {

/** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
}
}
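A note on the cast in javaFreqItemsets(): it is safe despite appearances. Generic parameters are erased at runtime, and a Scala Long inside a pair RDD is boxed to java.lang.Long, so re-typing the wrapped RDD cannot fail. A minimal standalone sketch of the same pattern (the toJavaPairs helper is hypothetical, not part of this change):

import scala.reflect.ClassTag
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.rdd.RDD

// Re-type an RDD[(K, Long)] for Java callers; erasure makes the cast a no-op.
def toJavaPairs[K: ClassTag](rdd: RDD[(K, Long)]): JavaPairRDD[K, java.lang.Long] =
  JavaPairRDD.fromRDD(rdd).asInstanceOf[JavaPairRDD[K, java.lang.Long]]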

@@ -77,22 +86,22 @@ class FPGrowth private (
* @param data input data set, where each element contains one transaction
* @return an [[FPGrowthModel]]
*/
def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
val count = data.count()
val minCount = math.ceil(minSupport * count).toLong
val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
val partitioner = new HashPartitioner(numParts)
val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
val freqItems = genFreqItems(data, minCount, partitioner)
val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
new FPGrowthModel(freqItemsets)
}

def run[Item: ClassTag, Basket <: JavaIterable[Item]](
data: JavaRDD[Basket]): FPGrowthModel[Item] = {
this.run(data.rdd.map(_.asScala))
def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
implicit val tag = fakeClassTag[Item]
run(data.rdd.map(_.asScala.toArray))
}
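For context, a hedged usage sketch of the new API (the transactions are invented for illustration, and sc is assumed to be an existing SparkContext). The Scala overload infers Item from the element type and the compiler supplies the ClassTag; the Java-friendly overload above fabricates one via fakeClassTag, which works because the Java-facing API is erased anyway:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.FPGrowth

def fpgExample(sc: SparkContext): Unit = {
  // Transactions are plain arrays now; no Basket type parameter is needed.
  val transactions = sc.parallelize(Seq(
    Array("a", "b", "c"),
    Array("a", "b"),
    Array("b", "c"),
    Array("a", "c"))).cache() // cached, so run() skips the warning above

  val model = new FPGrowth()
    .setMinSupport(0.5) // minCount = ceil(0.5 * 4) = 2
    .setNumPartitions(2)
    .run(transactions) // Item = String is inferred

  model.freqItemsets.collect().foreach { case (items, count) =>
    println(items.mkString("{", ",", "}") + ": " + count)
  }
}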

/**
@@ -101,8 +110,8 @@
* @param partitioner partitioner used to distribute items
* @return array of frequent items ordered by their frequencies
*/
private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
data: RDD[Basket],
private def genFreqItems[Item: ClassTag](
data: RDD[Array[Item]],
minCount: Long,
partitioner: Partitioner): Array[Item] = {
data.flatMap { t =>
Expand All @@ -127,8 +136,8 @@ class FPGrowth private (
* @param partitioner partitioner used to distribute transactions
* @return an RDD of (frequent itemset, count)
*/
private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
data: RDD[Basket],
private def genFreqItemsets[Item: ClassTag](
data: RDD[Array[Item]],
minCount: Long,
freqItems: Array[Item],
partitioner: Partitioner): RDD[(Array[Item], Long)] = {
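The body of genFreqItemsets is collapsed in this view. A plausible sketch of the missing pipeline, reconstructed from the visible pieces and this package's FPTree class (hedged, not the verbatim source): each transaction is turned into conditional transactions keyed by partition, each partition aggregates its share into an FP-Tree, and each tree emits the itemsets it owns:

// Assumed shape of the collapsed body:
val itemToRank = freqItems.zipWithIndex.toMap
data.flatMap { transaction =>
  genCondTransactions(transaction, itemToRank, partitioner)
}.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
  (tree, transaction) => tree.add(transaction, 1L),
  (tree1, tree2) => tree1.merge(tree2))
  .flatMap { case (part, tree) =>
    tree.extract(minCount, x => partitioner.getPartition(x) == part)
  }.map { case (ranks, count) =>
    (ranks.map(i => freqItems(i)).toArray, count)
  }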
@@ -152,13 +161,13 @@
* @param partitioner partitioner used to distribute transactions
* @return a map of (target partition, conditional transaction)
*/
private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
transaction: Basket,
private def genCondTransactions[Item: ClassTag](
transaction: Array[Item],
itemToRank: Map[Item, Int],
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
val output = mutable.Map.empty[Int, Array[Int]]
// Keep only the frequent items in the transaction and sort them by rank.
val filtered = transaction.flatMap(itemToRank.get).toArray
val filtered = transaction.flatMap(itemToRank.get)
ju.Arrays.sort(filtered)
val n = filtered.length
var i = n - 1
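The tail of genCondTransactions is likewise collapsed. A hedged sketch of how the loop plausibly completes, given the output, filtered, and i already in scope: walk the rank-sorted transaction from its last (least frequent) item backwards and, for each target partition seen for the first time, record the prefix ending at that item:

// Assumed completion of the collapsed while-loop (sketch, not verbatim source):
while (i >= 0) {
  val item = filtered(i)
  val part = partitioner.getPartition(item)
  if (!output.contains(part)) {
    output(part) = filtered.slice(0, i + 1)
  }
  i -= 1
}
output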
102 changes: 45 additions & 57 deletions mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,78 +19,66 @@

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;
import static org.junit.Assert.*;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void runFPGrowth() {
@SuppressWarnings("unchecked")
JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
Lists.newArrayList("r z h k p".split(" ")),
Lists.newArrayList("z y x w v u t s".split(" ")),
Lists.newArrayList("s x o n r".split(" ")),
Lists.newArrayList("x z y m t s q e".split(" ")),
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);

FPGrowth fpg = new FPGrowth();

/*
FPGrowthModel model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run(rdd);
assert(model6.javaFreqItemsets().count() == 0);
FPGrowthModel<String> model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run(rdd);
assertEquals(0, model6.javaFreqItemsets().count());

FPGrowthModel model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
(items.toSet, count)
}
val expected = Set(
(Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
(Set("r"), 3L),
(Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
(Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
(Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
(Set("t", "y", "x"), 3L),
(Set("t", "y", "x", "z"), 3L))
assert(freqItemsets3.toSet === expected)
FPGrowthModel<String> model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
assertEquals(18, model3.javaFreqItemsets().count());

val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run[String](rdd)
assert(model2.freqItemsets.count() == 54)
FPGrowthModel<String> model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run(rdd);
assertEquals(54, model2.javaFreqItemsets().count());

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run[String](rdd)
assert(model1.freqItemsets.count() == 625) */
}
}
FPGrowthModel<String> model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run(rdd);
assertEquals(625, model1.javaFreqItemsets().count());
}
}
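The expected counts in this suite follow from the minCount arithmetic in run(). A small worked sketch over the six transactions above:

// minCount = ceil(minSupport * numTransactions), per run() in FPGrowth.scala
val n = 6
def minCount(minSupport: Double): Long = math.ceil(minSupport * n).toLong
minCount(0.9) // = 6; no item occurs in all six transactions, hence 0 itemsets
minCount(0.5) // = 3; yields the 18 itemsets asserted above
minCount(0.3) // = 2; yields 54 itemsets
minCount(0.1) // = 1; yields 625 itemsets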
22 changes: 12 additions & 10 deletions mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -31,21 +31,21 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
"x z y m t s q e",
"z",
"x z y r q t p")
.map(_.split(" ").toSeq)
.map(_.split(" "))
val rdd = sc.parallelize(transactions, 2).cache()

val fpg = new FPGrowth()

val model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run[String, Seq[String]](rdd)
.run(rdd)
assert(model6.freqItemsets.count() === 0)

val model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run[String, Seq[String]](rdd)
.run(rdd)
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
(items.toSet, count)
}
@@ -62,13 +62,13 @@
val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run[String, Seq[String]](rdd)
.run(rdd)
assert(model2.freqItemsets.count() === 54)

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run[String, Seq[String]](rdd)
.run(rdd)
assert(model1.freqItemsets.count() === 625)
}

@@ -81,21 +81,23 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
"2 4",
"1 3",
"1 7")
.map(_.split(" ").map(_.toInt).toList)
.map(_.split(" ").map(_.toInt).toArray)
val rdd = sc.parallelize(transactions, 2).cache()

val fpg = new FPGrowth()

val model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run[Int, List[Int]](rdd)
.run(rdd)
assert(model6.freqItemsets.count() === 0)

val model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run[Int, List[Int]](rdd)
.run(rdd)
assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
(items.toSet, count)
}
@@ -108,13 +110,13 @@
val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run[Int, List[Int]](rdd)
.run(rdd)
assert(model2.freqItemsets.count() === 15)

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run[Int, List[Int]](rdd)
.run(rdd)
assert(model1.freqItemsets.count() === 65)
}
}
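The primitive-array assertion in the Int test is exactly what the new ClassTag context bounds provide: with a tag in scope, Array[Item] can be materialized as a primitive array rather than a boxed one. A standalone sketch (makeArray is a hypothetical helper, not part of this change):

import scala.reflect.ClassTag

// With a ClassTag, Scala builds a primitive int[] for Array[Int].
def makeArray[T: ClassTag](xs: T*): Array[T] = xs.toArray

val ints = makeArray(1, 2, 3)
assert(ints.getClass == Array(1).getClass) // int[], the same check as the suite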
