Skip to content

Commit

Permalink
Remove iteratorFoldM fixes #1716 (#1740)
Browse files Browse the repository at this point in the history
* remove iteratorFoldM fixes #1716

* fix foldM to not evaluate one extra uneccessary element

* update comment

* more efficient implementations

* keep the original sequence

* updated NonEmptyVector as well

* use default foldM implemenation to test short-circuit

* style correction / remove redundant overrides
  • Loading branch information
kailuowang authored and peterneyens committed Jun 26, 2017
1 parent 05b9118 commit 4274102
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 57 deletions.
41 changes: 3 additions & 38 deletions core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ import simulacrum.typeclass
def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = {
val src = Foldable.Source.fromFoldable(fa)(self)
G.tailRecM((z, src)) { case (b, src) => src.uncons match {
case Some((a, src)) => G.map(f(b, a))(b => Left((b, src)))
case Some((a, src)) => G.map(f(b, a))(b => Left((b, src.value)))
case None => G.pure(Right(b))
}}
}
Expand Down Expand Up @@ -414,41 +414,6 @@ object Foldable {
loop()
}

/**
* Implementation of [[Foldable.foldM]] which can short-circuit for
* structures with an `Iterator`.
*
* For example we can sum a `Stream` of integers and stop if
* the sum reaches 100 (if we reach the end of the `Stream`
* before getting to 100 we return the total sum) :
*
* {{{
* scala> import cats.implicits._
* scala> type LongOr[A] = Either[Long, A]
* scala> def sumStream(s: Stream[Int]): Long =
* | Foldable.iteratorFoldM[LongOr, Int, Long](s.toIterator, 0L){ (acc, n) =>
* | val sum = acc + n
* | if (sum < 100L) Right(sum) else Left(sum)
* | }.merge
*
* scala> sumStream(Stream.continually(1))
* res0: Long = 100
*
* scala> sumStream(Stream(1,2,3,4))
* res1: Long = 10
* }}}
*
* Note that `Foldable[Stream].foldM` uses this method underneath, so
* you wouldn't call this method explicitly like in the example above.
*/
def iteratorFoldM[M[_], A, B](it: Iterator[A], z: B)(f: (B, A) => M[B])(implicit M: Monad[M]): M[B] = {
val go: B => M[Either[B, B]] = { b =>
if (it.hasNext) M.map(f(b, it.next))(Left(_))
else M.pure(Right(b))
}
M.tailRecM(z)(go)
}


/**
* Isomorphic to
Expand All @@ -461,7 +426,7 @@ object Foldable {
* https://github.com/scala/bug/issues/9600 is resolved.
*/
private sealed abstract class Source[+A] {
def uncons: Option[(A, Source[A])]
def uncons: Option[(A, Eval[Source[A]])]
}

private object Source {
Expand All @@ -470,7 +435,7 @@ object Foldable {
}

def cons[A](a: A, src: Eval[Source[A]]): Source[A] = new Source[A] {
def uncons = Some((a, src.value))
def uncons = Some((a, src))
}

def fromFoldable[F[_], A](fa: F[A])(implicit F: Foldable[F]): Source[A] =
Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/cats/data/NonEmptyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,6 @@ private[data] sealed trait NonEmptyListInstances extends NonEmptyListInstances0
override def fold[A](fa: NonEmptyList[A])(implicit A: Monoid[A]): A =
fa.reduce

override def foldM[G[_], A, B](fa: NonEmptyList[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toList.toIterator, z)(f)

override def find[A](fa: NonEmptyList[A])(f: A => Boolean): Option[A] =
fa find f

Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/cats/data/NonEmptyVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,6 @@ private[data] sealed trait NonEmptyVectorInstances {
override def fold[A](fa: NonEmptyVector[A])(implicit A: Monoid[A]): A =
fa.reduce

override def foldM[G[_], A, B](fa: NonEmptyVector[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toVector.toIterator, z)(f)

override def find[A](fa: NonEmptyVector[A])(f: A => Boolean): Option[A] =
fa.find(f)

Expand Down
10 changes: 8 additions & 2 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ trait ListInstances extends cats.kernel.instances.ListInstances {

override def filter[A](fa: List[A])(f: A => Boolean): List[A] = fa.filter(f)

override def foldM[G[_], A, B](fa: List[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
override def foldM[G[_], A, B](fa: List[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = {
def step(in: (List[A], B)): G[Either[(List[A], B), B]] = in match {
case (Nil, b) => G.pure(Right(b))
case (a :: tail, b) => G.map(f(b, a)) { bnext => Left((tail, bnext)) }
}

G.tailRecM((fa, z))(step)
}

override def fold[A](fa: List[A])(implicit A: Monoid[A]): A = A.combineAll(fa)

Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/cats/instances/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ trait MapInstances extends cats.kernel.instances.MapInstances {

override def isEmpty[A](fa: Map[K, A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Map[K, A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.valuesIterator, z)(f)

override def fold[A](fa: Map[K, A])(implicit A: Monoid[A]): A =
A.combineAll(fa.values)

Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/cats/instances/set.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ trait SetInstances extends cats.kernel.instances.SetInstances {

override def isEmpty[A](fa: Set[A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Set[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)

override def fold[A](fa: Set[A])(implicit A: Monoid[A]): A = A.combineAll(fa)

override def toList[A](fa: Set[A]): List[A] = fa.toList
Expand Down
14 changes: 12 additions & 2 deletions core/src/main/scala/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package instances

import cats.syntax.show._

import scala.annotation.tailrec

trait StreamInstances extends cats.kernel.instances.StreamInstances {
Expand Down Expand Up @@ -118,8 +119,17 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances {

override def collect[A, B](fa: Stream[A])(f: PartialFunction[A, B]): Stream[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Stream[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
override def foldM[G[_], A, B](fa: Stream[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = {
def step(in: (Stream[A], B)): G[Either[(Stream[A], B), B]] = {
val (s, b) = in
if (s.isEmpty)
G.pure(Right(b))
else
G.map(f(b, s.head)) { bnext => Left((s.tail, bnext)) }
}

G.tailRecM((fa, z))(step)
}

override def fold[A](fa: Stream[A])(implicit A: Monoid[A]): A = A.combineAll(fa)

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import cats.syntax.show._
import scala.annotation.tailrec
import scala.collection.+:
import scala.collection.immutable.VectorBuilder
import list._

trait VectorInstances extends cats.kernel.instances.VectorInstances {
implicit val catsStdInstancesForVector: TraverseFilter[Vector] with MonadCombine[Vector] with CoflatMap[Vector] =
Expand Down Expand Up @@ -91,7 +92,7 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
override def collect[A, B](fa: Vector[A])(f: PartialFunction[A, B]): Vector[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Vector[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
Foldable[List].foldM(fa.toList, z)(f)

override def fold[A](fa: Vector[A])(implicit A: Monoid[A]): A = A.combineAll(fa)

Expand Down
22 changes: 20 additions & 2 deletions tests/src/test/scala/cats/tests/FoldableTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,33 @@ class FoldableTestsAdditional extends CatsSuite {

// test laziness of foldM
dangerous.foldM(0)((acc, a) => if (a < 2) Some(acc + a) else None) should === (None)

}

def foldableStreamWithDefaultImpl = new Foldable[Stream] {
def foldLeft[A, B](fa: Stream[A], b: B)(f: (B, A) => B): B =
instances.stream.catsStdInstancesForStream.foldLeft(fa, b)(f)

def foldRight[A, B](fa: Stream[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
instances.stream.catsStdInstancesForStream.foldRight(fa, lb)(f)
}

test(".foldLeftM short-circuiting") {
implicit val F = foldableStreamWithDefaultImpl
val ns = Stream.continually(1)
val res = Foldable[Stream].foldLeftM[Either[Int, ?], Int, Int](ns, 0) { (sum, n) =>
val res = F.foldLeftM[Either[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 100000) Left(sum) else Right(sum + n)
}
assert(res == Left(100000))
}

test(".foldLeftM short-circuiting optimality") {
implicit val F = foldableStreamWithDefaultImpl

// test that no more elements are evaluated than absolutely necessary

def concatUntil(ss: Stream[String], stop: String): Either[String, String] =
Foldable[Stream].foldLeftM[Either[String, ?], String, String](ss, "") { (acc, s) =>
F.foldLeftM[Either[String, ?], String, String](ss, "") { (acc, s) =>
if (s == stop) Left(acc) else Right(acc + s)
}

Expand All @@ -226,6 +238,12 @@ class FoldableTestsAdditional extends CatsSuite {
assert(concatUntil("Zero" #:: "STOP" #:: boom, "STOP") == Left("Zero"))
assert(concatUntil("Zero" #:: "One" #:: "STOP" #:: boom, "STOP") == Left("ZeroOne"))
}

test("Foldable[List] doesn't break substitution") {
val result = List.range(0,10).foldM(List.empty[Int])((accum, elt) => Eval.always(elt :: accum))

assert(result.value == result.value)
}
}

class FoldableListCheck extends FoldableCheck[List]("list") {
Expand Down

0 comments on commit 4274102

Please sign in to comment.