From 7783351736bd3a6d766ab44a9c2a19e6d5411d74 Mon Sep 17 00:00:00 2001 From: Jacky Li Date: Wed, 4 Feb 2015 00:54:41 +0800 Subject: [PATCH 1/4] add generic support in FPGrowth --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 46 +++++++++------ .../spark/mllib/fpm/FPGrowthSuite.scala | 59 +++++++++++++++++-- 2 files changed, 82 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9591c7966e06a..96baac07cb5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,15 +17,23 @@ package org.apache.spark.mllib.fpm +import java.lang.{Iterable => JavaIterable} import java.{util => ju} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.reflect.ClassTag -import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable +class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable { + def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = { + freqItemsets.toJavaRDD() + } +} /** * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. @@ -69,7 +77,7 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] */ - def run(data: RDD[Array[String]]): FPGrowthModel = { + def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -77,24 +85,28 @@ class FPGrowth private ( val minCount = math.ceil(minSupport * count).toLong val numParts = if (numPartitions > 0) numPartitions else data.partitions.length val partitioner = new HashPartitioner(numParts) - val freqItems = genFreqItems(data, minCount, partitioner) - val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) + val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner) + val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner) new FPGrowthModel(freqItemsets) } + def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + this.run(data.rdd.map(_.asScala)) + } + /** * Generates frequent items by filtering the input data using minimal support level. 
* @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items * @return array of frequent pattern ordered by their frequencies */ - private def genFreqItems( - data: RDD[Array[String]], + private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]]( + data: RDD[Basket], minCount: Long, - partitioner: Partitioner): Array[String] = { + partitioner: Partitioner): Array[Item] = { data.flatMap { t => val uniq = t.toSet - if (t.length != uniq.size) { + if (t.size != uniq.size) { throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") } t @@ -114,11 +126,11 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return an RDD of (frequent itemset, count) */ - private def genFreqItemsets( - data: RDD[Array[String]], + private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]]( + data: RDD[Basket], minCount: Long, - freqItems: Array[String], - partitioner: Partitioner): RDD[(Array[String], Long)] = { + freqItems: Array[Item], + partitioner: Partitioner): RDD[(Array[Item], Long)] = { val itemToRank = freqItems.zipWithIndex.toMap data.flatMap { transaction => genCondTransactions(transaction, itemToRank, partitioner) @@ -139,13 +151,13 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return a map of (target partition, conditional transaction) */ - private def genCondTransactions( - transaction: Array[String], - itemToRank: Map[String, Int], + private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]]( + transaction: Basket, + itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern and sort their ranks. 
- val filtered = transaction.flatMap(itemToRank.get) + val filtered = transaction.flatMap(itemToRank.get).toArray ju.Arrays.sort(filtered) val n = filtered.length var i = n - 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 71ef60da6dd32..67dc246abfb24 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { - test("FP-Growth") { + + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -30,7 +31,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { "x z y m t s q e", "z", "x z y r q t p") - .map(_.split(" ")) + .map(_.split(" ").toSeq) val rdd = sc.parallelize(transactions, 2).cache() val fpg = new FPGrowth() @@ -38,13 +39,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .run(rdd) + .run[String, Seq[String]](rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .run(rdd) + .run[String, Seq[String]](rdd) val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => (items.toSet, count) } @@ -61,13 +62,59 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .run(rdd) + .run[String, Seq[String]](rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .run(rdd) + .run[String, Seq[String]](rdd) assert(model1.freqItemsets.count() === 625) } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toList) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run[Int, List[Int]](rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run[Int, List[Int]](rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => + (items.toSet, count) + } + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run[Int, List[Int]](rdd) + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run[Int, List[Int]](rdd) + assert(model1.freqItemsets.count() === 65) + } } From 793f85c131eade9b58a94655f318f1a00a7f2d9f Mon Sep 17 00:00:00 2001 From: Jacky Li Date: Wed, 4 Feb 2015 01:01:12 +0800 Subject: [PATCH 2/4] add Java test case --- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 0000000000000..c0b55691983ae --- /dev/null +++ 
b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaFPGrowthSuite implements Serializable {
+    private transient JavaSparkContext sc;
+
+    @Before
+    public void setUp() {
+      sc = new JavaSparkContext("local", "JavaFPGrowth");
+    }
+
+    @After
+    public void tearDown() {
+      sc.stop();
+      sc = null;
+    }
+
+    @Test
+    public void runFPGrowth() {
+      JavaRDD<List<String>> rdd = sc.parallelize(Lists.newArrayList(
+        Lists.newArrayList("r z h k p".split(" ")),
+        Lists.newArrayList("z y x w v u t s".split(" ")),
+        Lists.newArrayList("s x o n r".split(" ")),
+        Lists.newArrayList("x z y m t s q e".split(" ")),
+        Lists.newArrayList("z".split(" ")),
+        Lists.newArrayList("x z y r q t p".split(" "))), 2);
+
+      FPGrowth fpg = new FPGrowth();
+
+      /*
+      FPGrowthModel model6 = fpg
+        .setMinSupport(0.9)
+        .setNumPartitions(1)
+        .run(rdd);
+      assert(model6.javaFreqItemsets().count() == 0);
+
+      FPGrowthModel model3 = fpg
+        .setMinSupport(0.5)
+        .setNumPartitions(2)
+        .run(rdd);
+      val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+        (items.toSet, count)
+      }
+      val expected = Set(
+        (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
+        (Set("r"), 3L),
+        (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
+        (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
+        (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
+        (Set("t", "y", "x"), 3L),
+        (Set("t", "y", "x", "z"), 3L))
+      assert(freqItemsets3.toSet === expected)
+
+      val model2 = fpg
+        .setMinSupport(0.3)
+        .setNumPartitions(4)
+        .run[String](rdd)
+      assert(model2.freqItemsets.count() == 54)
+
+      val model1 = fpg
+        .setMinSupport(0.1)
+        .setNumPartitions(8)
+        .run[String](rdd)
+      assert(model1.freqItemsets.count() == 625) */
+    }
+}
\ No newline at end of file

From 737d8bb7b3a2e1a8a4fe55d3be4dbf21d3b82c3c Mon Sep 17 00:00:00 2001
From: Jacky Li
Date: Wed, 4 Feb 2015 01:25:15 +0800
Subject: [PATCH 3/4] fix scalastyle

---
 mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 96baac07cb5d4..70955bffa206b 100644
---
a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -90,7 +90,8 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } - def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + def run[Item: ClassTag, Basket <: JavaIterable[Item]]( + data: JavaRDD[Basket]): FPGrowthModel[Item] = { this.run(data.rdd.map(_.asScala)) } From 63073d011fd5bf80e3ac1e57e475329a12899ca1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 3 Feb 2015 12:25:55 -0800 Subject: [PATCH 4/4] update to make generic FPGrowth Java-friendly --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 49 +++++---- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 102 ++++++++---------- .../spark/mllib/fpm/FPGrowthSuite.scala | 22 ++-- 3 files changed, 86 insertions(+), 87 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 70955bffa206b..1433ee9a0dd5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,21 +17,30 @@ package org.apache.spark.mllib.fpm -import java.lang.{Iterable => JavaIterable} import java.{util => ju} +import java.lang.{Iterable => JavaIterable} -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable { - def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = { - freqItemsets.toJavaRDD() +/** + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs + * @tparam Item item type + */ +class FPGrowthModel[Item: ClassTag]( + val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable { + + /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. 
*/ + def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = { + JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]] } } @@ -77,7 +86,7 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] */ - def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = { + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -85,14 +94,14 @@ class FPGrowth private ( val minCount = math.ceil(minSupport * count).toLong val numParts = if (numPartitions > 0) numPartitions else data.partitions.length val partitioner = new HashPartitioner(numParts) - val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner) - val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner) + val freqItems = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) new FPGrowthModel(freqItemsets) } - def run[Item: ClassTag, Basket <: JavaIterable[Item]]( - data: JavaRDD[Basket]): FPGrowthModel[Item] = { - this.run(data.rdd.map(_.asScala)) + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) } /** @@ -101,8 +110,8 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute items * @return array of frequent pattern ordered by their frequencies */ - private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]]( - data: RDD[Basket], + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, partitioner: Partitioner): Array[Item] = { data.flatMap { t => @@ -127,8 +136,8 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return an RDD of (frequent itemset, count) */ - private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]]( - data: RDD[Basket], + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, freqItems: Array[Item], partitioner: Partitioner): RDD[(Array[Item], Long)] = { @@ -152,13 +161,13 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return a map of (target partition, conditional transaction) */ - private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]]( - transaction: Basket, + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern and sort their ranks. 
-    val filtered = transaction.flatMap(itemToRank.get).toArray
+    val filtered = transaction.flatMap(itemToRank.get)
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index c0b55691983ae..851707c8a19c4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,78 +19,66 @@
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.*;
-
 import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 
 public class JavaFPGrowthSuite implements Serializable {
-    private transient JavaSparkContext sc;
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
 
-    @Before
-    public void setUp() {
-      sc = new JavaSparkContext("local", "JavaFPGrowth");
-    }
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
 
-    @After
-    public void tearDown() {
-      sc.stop();
-      sc = null;
-    }
+  @Test
+  public void runFPGrowth() {
 
-    @Test
-    public void runFPGrowth() {
-      JavaRDD<List<String>> rdd = sc.parallelize(Lists.newArrayList(
-        Lists.newArrayList("r z h k p".split(" ")),
-        Lists.newArrayList("z y x w v u t s".split(" ")),
-        Lists.newArrayList("s x o n r".split(" ")),
-        Lists.newArrayList("x z y m t s q e".split(" ")),
-        Lists.newArrayList("z".split(" ")),
-        Lists.newArrayList("x z y r q t p".split(" "))), 2);
+    @SuppressWarnings("unchecked")
+    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("r z h k p".split(" ")),
+      Lists.newArrayList("z y x w v u t s".split(" ")),
+      Lists.newArrayList("s x o n r".split(" ")),
+      Lists.newArrayList("x z y m t s q e".split(" ")),
+      Lists.newArrayList("z".split(" ")),
+      Lists.newArrayList("x z y r q t p".split(" "))), 2);
 
-      FPGrowth fpg = new FPGrowth();
+    FPGrowth fpg = new FPGrowth();
 
-      /*
-      FPGrowthModel model6 = fpg
-        .setMinSupport(0.9)
-        .setNumPartitions(1)
-        .run(rdd);
-      assert(model6.javaFreqItemsets().count() == 0);
+    FPGrowthModel<String> model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd);
+    assertEquals(0, model6.javaFreqItemsets().count());
 
-      FPGrowthModel model3 = fpg
-        .setMinSupport(0.5)
-        .setNumPartitions(2)
-        .run(rdd);
-      val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
-        (items.toSet, count)
-      }
-      val expected = Set(
-        (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
-        (Set("r"), 3L),
-        (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
-        (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
-        (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
-        (Set("t", "y", "x"), 3L),
-        (Set("t", "y", "x", "z"), 3L))
-      assert(freqItemsets3.toSet === expected)
+    FPGrowthModel<String> model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+    assertEquals(18, model3.javaFreqItemsets().count());
 
-      val model2 = fpg
-        .setMinSupport(0.3)
-        .setNumPartitions(4)
-        .run[String](rdd)
-      assert(model2.freqItemsets.count() == 54)
+    FPGrowthModel<String> model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd);
+    assertEquals(54, model2.javaFreqItemsets().count());
 
-      val model1 = fpg
-        .setMinSupport(0.1)
-        .setNumPartitions(8)
-        .run[String](rdd)
-      assert(model1.freqItemsets.count() == 625) */
-    }
-}
\ No newline at end of file
+    FPGrowthModel<String> model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd);
+    assertEquals(625, model1.javaFreqItemsets().count());
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 67dc246abfb24..68128284b8608 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -31,7 +31,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       "x z y m t s q e",
       "z",
       "x z y r q t p")
-      .map(_.split(" ").toSeq)
+      .map(_.split(" "))
     val rdd = sc.parallelize(transactions, 2).cache()
 
     val fpg = new FPGrowth()
@@ -39,13 +39,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
      .setNumPartitions(1)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -62,13 +62,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model2.freqItemsets.count() === 54)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model1.freqItemsets.count() === 625)
   }
@@ -81,7 +81,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       "2 4",
       "1 3",
       "1 7")
-      .map(_.split(" ").map(_.toInt).toList)
+      .map(_.split(" ").map(_.toInt).toArray)
     val rdd = sc.parallelize(transactions, 2).cache()
 
     val fpg = new FPGrowth()
@@ -89,13 +89,15 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
+    assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+      "frequent itemsets should use primitive arrays")
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -108,13 +110,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
      .setNumPartitions(4)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model2.freqItemsets.count() === 15)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model1.freqItemsets.count() === 65)
   }
 }
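
Usage sketch (not part of the patches above): with the API as it stands after PATCH 4/4, a caller passes an RDD of item arrays and the item type is inferred at the call site. The transactions, the 0.5 support threshold, and the assumption of an existing SparkContext `sc` below are illustrative only.

    import org.apache.spark.mllib.fpm.FPGrowth

    // Assumes an existing SparkContext `sc`; data and parameters are made up for the example.
    // Caching matters: run() logs a warning when the input RDD is not cached.
    val transactions = sc.parallelize(Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r").map(_.split(" ")), 2).cache()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(transactions)   // inferred as FPGrowthModel[String]

    // freqItemsets is an RDD[(Array[String], Long)] of (itemset, frequency) pairs.
    model.freqItemsets.collect().foreach { case (itemset, count) =>
      println(itemset.mkString("{", ",", "}") + ": " + count)
    }

Java callers go through the JavaIterable overload of run() and read results back with javaFreqItemsets(), as exercised by JavaFPGrowthSuite above.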