Skip to content

Commit

Permalink
Add Foldable and Traversable instances for Free
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron levin committed Aug 23, 2017
1 parent f4aa32d commit 360025f
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 2 deletions.
70 changes: 69 additions & 1 deletion free/src/main/scala/cats/free/Free.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ sealed abstract class Free[S[_], A] extends Product with Serializable {
}
}

/**
* A combination of step and fold.
*/
private[free] final def foldStep[B](
onPure: A => B,
onSuspend: S[A] => B,
onFlatMapped: ((S[X], X => Free[S, A]) forSome { type X }) => B
): B = this.step match {
case Pure(a) => onPure(a)
case Suspend(a) => onSuspend(a)
case FlatMapped(Suspend(fa), f) => onFlatMapped((fa, f))
case _ => sys.error("FlatMapped should be right associative after step")
}

/**
* Run to completion, using a function that extracts the resumption
* from its suspension functor.
Expand Down Expand Up @@ -161,7 +175,7 @@ sealed abstract class Free[S[_], A] extends Product with Serializable {
"Free(...)"
}

object Free {
object Free extends FreeInstances {

/**
* Return from the computation with the given value.
Expand Down Expand Up @@ -250,3 +264,57 @@ object Free {
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
}
}

private trait FreeFoldable[F[_]] extends Foldable[Free[F, ?]] {

implicit def F: Foldable[F]

override final def foldLeft[A, B](fa: Free[F, A], b: B)(f: (B, A) => B): B =
fa.foldStep(
a => f(b, a),
fa => F.foldLeft(fa, b)(f),
{ case (fx, g) => F.foldLeft(fx, b)((bb, x) => foldLeft(g(x), bb)(f)) }
)

override final def foldRight[A, B](fa: Free[F, A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
fa.foldStep(
a => f(a, lb),
fa => F.foldRight(fa, lb)(f),
{ case (fx, g) => F.foldRight(fx, lb)((a, lbb) => lbb.flatMap(bb => foldRight(g(a), Eval.now(bb))(f))) }
)
}

private trait FreeTraverse[F[_]] extends Traverse[Free[F, ?]] with FreeFoldable[F] {
implicit def TraversableF: Traverse[F]

def F: Foldable[F] = TraversableF

override final def traverse[G[_], A, B](fa: Free[F, A])(f: A => G[B])(implicit G: Applicative[G]): G[Free[F, B]] =
fa.resume match {
case Right(a) => G.map(f(a))(Free.pure(_))
case Left(ffreeA) => G.map(TraversableF.traverse(ffreeA)(traverse(_)(f)))(Free.roll(_))
}

// Override Traverse's map to use Free's map for better performance
override final def map[A, B](fa: Free[F, A])(f: A => B): Free[F, B] = fa.map(f)
}

sealed private[free] abstract class FreeInstances {

implicit def catsFreeFoldableForFree[F[_]](
implicit
foldableF: Foldable[F]
): Foldable[Free[F, ?]] =
new FreeFoldable[F] {
val F = foldableF
}

implicit def catsFreeTraverseForFree[F[_]](
implicit
traversableF: Traverse[F]
): Traverse[Free[F, ?]] =
new FreeTraverse[F] {
val TraversableF = traversableF
val FunctorF = traversableF
}
}
25 changes: 24 additions & 1 deletion free/src/test/scala/cats/free/FreeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package free

import cats.arrow.FunctionK
import cats.data.EitherK
import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests}
import cats.laws.discipline.{CartesianTests, FoldableTests, MonadTests, SerializableTests, TraverseTests}
import cats.laws.discipline.arbitrary.catsLawsArbitraryForFn0
import cats.tests.CatsSuite

Expand All @@ -18,6 +18,19 @@ class FreeTests extends CatsSuite {
checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int])
checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]]))

locally {
implicit val instance = Free.catsFreeFoldableForFree[Option]

checkAll("Free[Option, ?]", FoldableTests[Free[Option,?]].foldable[Int,Int])
checkAll("Foldable[Free[Option,?]]", SerializableTests.serializable(Foldable[Free[Option,?]]))
}

locally {
implicit val instance = Free.catsFreeTraverseForFree[Option]
checkAll("Free[Option,?]", TraverseTests[Free[Option,?]].traverse[Int, Int, Int, Int, Option, Option])
checkAll("Traverse[Free[Option,?]]", SerializableTests.serializable(Traverse[Free[Option,?]]))
}

test("toString is stack-safe") {
val r = Free.pure[List, Int](333)
val rr = (1 to 1000000).foldLeft(r)((r, _) => r.map(_ + 1))
Expand Down Expand Up @@ -82,6 +95,16 @@ class FreeTests extends CatsSuite {
assert(10000 == a(0).foldMap(runner))
}

test("foldRight is stack safe") {
val instance = Free.catsFreeFoldableForFree[Option]
val n = 50000
val freeOption: Int => Free[Option, Int] = x => Free.pure(x)
val free = (1 to n).foldLeft(freeOption(0))((r, _) => r.flatMap(n => freeOption(n + 1)))
val result = instance.foldRight(free, Eval.now(0))((a, lb) => lb.map(_ + a)).value

assert(n == result)
}

test(".runTailRec") {
val r = Free.pure[List, Int](12358)
def recurse(r: Free[List, Int], n: Int): Free[List, Int] =
Expand Down

0 comments on commit 360025f

Please sign in to comment.