Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy foldM for "Iterables" #1414

Merged
merged 4 commits into from
Oct 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ import simulacrum.typeclass

/**
* Left associative monadic folding on `F`.
*
* The default implementation of this is based on `foldLeft`, and thus will
* always fold across the entire structure. Certain structures are able to
* implement this in such a way that folds can be short-circuited (not
* traverse the entirety of the structure), depending on the `G` result
* produced at a given step.
*/
def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
foldLeft(fa, G.pure(z))((gb, a) => G.flatMap(gb)(f(_, a)))
Expand Down Expand Up @@ -372,4 +378,39 @@ object Foldable {
Eval.defer(if (it.hasNext) f(it.next, loop()) else lb)
loop()
}

/**
* Implementation of [[Foldable.foldM]] which can short-circuit for
* structures with an `Iterator`.
*
* For example we can sum a `Stream` of integers and stop if
* the sum reaches 100 (if we reach the end of the `Stream`
* before getting to 100 we return the total sum) :
*
* {{{
* scala> import cats.implicits._
* scala> type LongOr[A] = Either[Long, A]
* scala> def sumStream(s: Stream[Int]): Long =
* | Foldable.iteratorFoldM[LongOr, Int, Long](s.toIterator, 0L){ (acc, n) =>
* | val sum = acc + n
* | if (sum < 100L) Right(sum) else Left(sum)
* | }.merge
*
* scala> sumStream(Stream.continually(1))
* res0: Long = 100
*
* scala> sumStream(Stream(1,2,3,4))
* res1: Long = 10
* }}}
*
* Note that `Foldable[Stream].foldM` uses this method underneath, so
* you wouldn't call this method explicitly like in the example above.
*/
def iteratorFoldM[M[_], A, B](it: Iterator[A], z: B)(f: (B, A) => M[B])(implicit M: Monad[M]): M[B] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we also mention here in the doc the incentive of this special implementation for Iteratables?

val go: B => M[Either[B, B]] = { b =>
if (it.hasNext) M.map(f(b, it.next))(Left(_))
else M.pure(Right(b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Iterator is mutable, I would keep it private to the method (otherwise we lose referential transparency). I would keep the signature as you had before, and only use Iterator inside.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if this was not clear from my previous comment.

}
M.tailRecM(z)(go)
}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ trait ListInstances extends cats.kernel.instances.ListInstances {
override def isEmpty[A](fa: List[A]): Boolean = fa.isEmpty

override def filter[A](fa: List[A])(f: A => Boolean): List[A] = fa.filter(f)

override def foldM[G[_], A, B](fa: List[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForList[A:Show]: Show[List[A]] =
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ trait MapInstances extends cats.kernel.instances.MapInstances {
override def size[A](fa: Map[K, A]): Long = fa.size.toLong

override def isEmpty[A](fa: Map[K, A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Map[K, A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.valuesIterator, z)(f)
}
// scalastyle:on method.length
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/set.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ trait SetInstances extends cats.kernel.instances.SetInstances {
fa.forall(p)

override def isEmpty[A](fa: Set[A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Set[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForSet[A:Show]: Show[Set[A]] = new Show[Set[A]] {
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances {
override def filter[A](fa: Stream[A])(f: A => Boolean): Stream[A] = fa.filter(f)

override def collect[A, B](fa: Stream[A])(f: PartialFunction[A, B]): Stream[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Stream[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForStream[A: Show]: Show[Stream[A]] =
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
override def filter[A](fa: Vector[A])(f: A => Boolean): Vector[A] = fa.filter(f)

override def collect[A, B](fa: Vector[A])(f: PartialFunction[A, B]): Vector[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Vector[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForVector[A:Show]: Show[Vector[A]] =
Expand Down
36 changes: 30 additions & 6 deletions tests/src/test/scala/cats/tests/FoldableTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,35 @@ class FoldableTestsAdditional extends CatsSuite {
larger.value should === (large.map(_ + 1))
}

test("Foldable[List].foldM stack safety") {
def nonzero(acc: Long, x: Long): Option[Long] =
def checkFoldMStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Foldable[F]): Unit = {
def nonzero(acc: Long, x: Int): Option[Long] =
if (x == 0) None else Some(acc + x)

val n = 100000L
val expected = n*(n+1)/2
val actual = Foldable[List].foldM((1L to n).toList, 0L)(nonzero)
assert(actual.get == expected)
val n = 100000
val expected = n.toLong*(n.toLong+1)/2
val foldMResult = F.foldM(fromRange(1 to n), 0L)(nonzero)
assert(foldMResult.get == expected)
()
}

test("Foldable[List].foldM stack safety") {
checkFoldMStackSafety[List](_.toList)
}

test("Foldable[Stream].foldM stack safety") {
checkFoldMStackSafety[Stream](_.toStream)
}

test("Foldable[Vector].foldM stack safety") {
checkFoldMStackSafety[Vector](_.toVector)
}

test("Foldable[Set].foldM stack safety") {
checkFoldMStackSafety[Set](_.toSet)
}

test("Foldable[Map[String, ?]].foldM stack safety") {
checkFoldMStackSafety[Map[String, ?]](_.map(x => x.toString -> x).toMap)
}

test("Foldable[Stream]") {
Expand All @@ -141,6 +162,9 @@ class FoldableTestsAdditional extends CatsSuite {
// test trampolining
val large = Stream((1 to 10000): _*)
assert(contains(large, 10000).value)

// test laziness of foldM
dangerous.foldM(0)((acc, a) => if (a < 2) Some(acc + a) else None) should === (None)
}
}

Expand Down