diff --git a/free/src/main/scala/cats/free/Cofree.scala b/free/src/main/scala/cats/free/Cofree.scala index 7c8fa93669..c1408a20e8 100644 --- a/free/src/main/scala/cats/free/Cofree.scala +++ b/free/src/main/scala/cats/free/Cofree.scala @@ -34,15 +34,15 @@ final case class Cofree[S[_], A](head: A, tail: Eval[S[Cofree[S, A]]]) { /** Transform the branching functor, using the T functor to perform the recursion. */ def mapBranchingT[T[_]](nat: S ~> T)(implicit T: Functor[T]): Cofree[T, A] = - Cofree[T, A](head, tail.map(v => T.map(nat(v))(_.mapBranchingT(nat)))) + Cofree.anaE(this)(_.tail.map(nat(_)), _.head) /** Map `f` over each subtree of the computation. */ def coflatMap[B](f: Cofree[S, A] => B)(implicit S: Functor[S]): Cofree[S, B] = - Cofree[S, B](f(this), tail.map(S.map(_)(_.coflatMap(f)))) + Cofree.anaE(this)(_.tail, f) /** Replace each node in the computation with the subtree from that node downwards */ def coflatten(implicit S: Functor[S]): Cofree[S, Cofree[S, A]] = - Cofree[S, Cofree[S, A]](this, tail.map(S.map(_)(_.coflatten))) + Cofree.anaE(this)(_.tail, identity) /** Alias for head. */ def extract: A = head @@ -53,7 +53,7 @@ final case class Cofree[S[_], A](head: A, tail: Eval[S[Cofree[S, A]]]) { /** Evaluate the entire Cofree tree. */ def forceAll(implicit S: Functor[S]): Cofree[S, A] = - Cofree[S, A](head, Eval.now(tail.map(S.map(_)(_.forceAll)).value)) + Cofree.anaE(this)(sa => Eval.now(sa.tail.value), _.head) } @@ -65,7 +65,16 @@ object Cofree extends CofreeInstances { /** Cofree anamorphism with a fused map, lazily evaluated. */ def ana[F[_], A, B](a: A)(coalg: A => F[A], f: A => B)(implicit F: Functor[F]): Cofree[F, B] = - Cofree[F, B](f(a), Eval.later(F.map(coalg(a))(ana(_)(coalg, f)))) + anaE(a)(a => Eval.later(coalg(a)), f) + + /** Cofree anamorphism with a fused map. */ + def anaE[F[_], A, B](a: A)(coalg: A => Eval[F[A]], f: A => B)(implicit F: Functor[F]): Cofree[F, B] = + Cofree[F, B](f(a), mapSemilazy(coalg(a))(fa => F.map(fa)(anaE(_)(coalg, f)))) + + private def mapSemilazy[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa match { + case Now(a) => Now(f(a)) + case other => other.map(f) + } /** * A stack-safe algebraic recursive fold out of the cofree comonad.