diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 53760645a53af..6a418dcc6fe82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -48,19 +48,19 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { if (database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) - val frequentItems = frequentItemAndCounts.map(_._1) + val frequentItems = frequentItemAndCounts.map(_._1).toSet val frequentPatternAndCounts = frequentItemAndCounts .map { case (item, count) => ((item :: prefix).reverse.toArray, count) } val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) if (prefix.length + 1 < maxPatternLength) { - frequentPatternAndCounts ++ frequentItems.flatMap { item => + frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => val nextProjected = project(filteredProjectedDatabase, item) run(minCount, maxPatternLength, item :: prefix, nextProjected) } } else { - frequentPatternAndCounts + frequentPatternAndCounts.iterator } } @@ -93,12 +93,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Iterable[Array[Int]]): Iterator[(Int, Long)] = { + database: Iterable[Array[Int]]): Iterable[(Int, Long)] = { database.flatMap(_.distinct) .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => ctr + (item -> (ctr(item) + 1)) } .filter(_._2 >= minCount) - .iterator } }