diff --git a/core/src/main/scala/cats/Foldable.scala b/core/src/main/scala/cats/Foldable.scala index acb89624dd..64db5d0ece 100644 --- a/core/src/main/scala/cats/Foldable.scala +++ b/core/src/main/scala/cats/Foldable.scala @@ -173,6 +173,12 @@ import simulacrum.typeclass /** * Left associative monadic folding on `F`. + * + * The default implementation of this is based on `foldLeft`, and thus will + * always fold across the entire structure. Certain structures are able to + * implement this in such a way that folds can be short-circuited (not + * traverse the entirety of the structure), depending on the `G` result + * produced at a given step. */ def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = foldLeft(fa, G.pure(z))((gb, a) => G.flatMap(gb)(f(_, a))) @@ -372,4 +378,39 @@ object Foldable { Eval.defer(if (it.hasNext) f(it.next, loop()) else lb) loop() } + + /** + * Implementation of [[Foldable.foldM]] which can short-circuit for + * structures with an `Iterator`. + * + * For example we can sum a `Stream` of integers and stop if + * the sum reaches 100 (if we reach the end of the `Stream` + * before getting to 100 we return the total sum) : + * + * {{{ + * scala> import cats.implicits._ + * scala> type LongOr[A] = Either[Long, A] + * scala> def sumStream(s: Stream[Int]): Long = + * | Foldable.iteratorFoldM[LongOr, Int, Long](s.toIterator, 0L){ (acc, n) => + * | val sum = acc + n + * | if (sum < 100L) Right(sum) else Left(sum) + * | }.merge + * + * scala> sumStream(Stream.continually(1)) + * res0: Long = 100 + * + * scala> sumStream(Stream(1,2,3,4)) + * res1: Long = 10 + * }}} + * + * Note that `Foldable[Stream].foldM` uses this method underneath, so + * you wouldn't call this method explicitly like in the example above. + */ + def iteratorFoldM[M[_], A, B](it: Iterator[A], z: B)(f: (B, A) => M[B])(implicit M: Monad[M]): M[B] = { + val go: B => M[Either[B, B]] = { b => + if (it.hasNext) M.map(f(b, it.next))(Left(_)) + else M.pure(Right(b)) + } + M.tailRecM(z)(go) + } } diff --git a/core/src/main/scala/cats/instances/list.scala b/core/src/main/scala/cats/instances/list.scala index 3ff790c146..343e534cdc 100644 --- a/core/src/main/scala/cats/instances/list.scala +++ b/core/src/main/scala/cats/instances/list.scala @@ -79,6 +79,9 @@ trait ListInstances extends cats.kernel.instances.ListInstances { override def isEmpty[A](fa: List[A]): Boolean = fa.isEmpty override def filter[A](fa: List[A])(f: A => Boolean): List[A] = fa.filter(f) + + override def foldM[G[_], A, B](fa: List[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = + Foldable.iteratorFoldM(fa.toIterator, z)(f) } implicit def catsStdShowForList[A:Show]: Show[List[A]] = diff --git a/core/src/main/scala/cats/instances/map.scala b/core/src/main/scala/cats/instances/map.scala index 7784f428e5..149432cae8 100644 --- a/core/src/main/scala/cats/instances/map.scala +++ b/core/src/main/scala/cats/instances/map.scala @@ -78,6 +78,9 @@ trait MapInstances extends cats.kernel.instances.MapInstances { override def size[A](fa: Map[K, A]): Long = fa.size.toLong override def isEmpty[A](fa: Map[K, A]): Boolean = fa.isEmpty + + override def foldM[G[_], A, B](fa: Map[K, A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = + Foldable.iteratorFoldM(fa.valuesIterator, z)(f) } // scalastyle:on method.length } diff --git a/core/src/main/scala/cats/instances/set.scala b/core/src/main/scala/cats/instances/set.scala index 0629b5bbb1..7b376fb5e3 100644 --- a/core/src/main/scala/cats/instances/set.scala +++ b/core/src/main/scala/cats/instances/set.scala @@ -27,6 +27,9 @@ trait SetInstances extends cats.kernel.instances.SetInstances { fa.forall(p) override def isEmpty[A](fa: Set[A]): Boolean = fa.isEmpty + + override def foldM[G[_], A, B](fa: Set[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = + Foldable.iteratorFoldM(fa.toIterator, z)(f) } implicit def catsStdShowForSet[A:Show]: Show[Set[A]] = new Show[Set[A]] { diff --git a/core/src/main/scala/cats/instances/stream.scala b/core/src/main/scala/cats/instances/stream.scala index 5ffcd171df..6fdd95cd8a 100644 --- a/core/src/main/scala/cats/instances/stream.scala +++ b/core/src/main/scala/cats/instances/stream.scala @@ -106,6 +106,9 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances { override def filter[A](fa: Stream[A])(f: A => Boolean): Stream[A] = fa.filter(f) override def collect[A, B](fa: Stream[A])(f: PartialFunction[A, B]): Stream[B] = fa.collect(f) + + override def foldM[G[_], A, B](fa: Stream[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = + Foldable.iteratorFoldM(fa.toIterator, z)(f) } implicit def catsStdShowForStream[A: Show]: Show[Stream[A]] = diff --git a/core/src/main/scala/cats/instances/vector.scala b/core/src/main/scala/cats/instances/vector.scala index 4b91a414c4..dfde5b10f7 100644 --- a/core/src/main/scala/cats/instances/vector.scala +++ b/core/src/main/scala/cats/instances/vector.scala @@ -86,6 +86,9 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances { override def filter[A](fa: Vector[A])(f: A => Boolean): Vector[A] = fa.filter(f) override def collect[A, B](fa: Vector[A])(f: PartialFunction[A, B]): Vector[B] = fa.collect(f) + + override def foldM[G[_], A, B](fa: Vector[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = + Foldable.iteratorFoldM(fa.toIterator, z)(f) } implicit def catsStdShowForVector[A:Show]: Show[Vector[A]] = diff --git a/tests/src/test/scala/cats/tests/FoldableTests.scala b/tests/src/test/scala/cats/tests/FoldableTests.scala index 4287d0f2fb..381eecb4b0 100644 --- a/tests/src/test/scala/cats/tests/FoldableTests.scala +++ b/tests/src/test/scala/cats/tests/FoldableTests.scala @@ -107,14 +107,35 @@ class FoldableTestsAdditional extends CatsSuite { larger.value should === (large.map(_ + 1)) } - test("Foldable[List].foldM stack safety") { - def nonzero(acc: Long, x: Long): Option[Long] = + def checkFoldMStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Foldable[F]): Unit = { + def nonzero(acc: Long, x: Int): Option[Long] = if (x == 0) None else Some(acc + x) - val n = 100000L - val expected = n*(n+1)/2 - val actual = Foldable[List].foldM((1L to n).toList, 0L)(nonzero) - assert(actual.get == expected) + val n = 100000 + val expected = n.toLong*(n.toLong+1)/2 + val foldMResult = F.foldM(fromRange(1 to n), 0L)(nonzero) + assert(foldMResult.get == expected) + () + } + + test("Foldable[List].foldM stack safety") { + checkFoldMStackSafety[List](_.toList) + } + + test("Foldable[Stream].foldM stack safety") { + checkFoldMStackSafety[Stream](_.toStream) + } + + test("Foldable[Vector].foldM stack safety") { + checkFoldMStackSafety[Vector](_.toVector) + } + + test("Foldable[Set].foldM stack safety") { + checkFoldMStackSafety[Set](_.toSet) + } + + test("Foldable[Map[String, ?]].foldM stack safety") { + checkFoldMStackSafety[Map[String, ?]](_.map(x => x.toString -> x).toMap) } test("Foldable[Stream]") { @@ -141,6 +162,9 @@ class FoldableTestsAdditional extends CatsSuite { // test trampolining val large = Stream((1 to 10000): _*) assert(contains(large, 10000).value) + + // test laziness of foldM + dangerous.foldM(0)((acc, a) => if (a < 2) Some(acc + a) else None) should === (None) } }