Skip to content

Commit

Permalink
Cofree comonad
Browse files Browse the repository at this point in the history
Removed Cofree[List, A] tests and added attribution

Remove unused imports

Add Reducible instance, separate tests, fix Traverse1 comment

Removed unfoldStrategy, added a test for unfold

Tests for Cofree.mapBranchingRoot/mapBranchingS/T

Add Cofree.tailForced and Cofree.forceAll tests, fix notrunning tests

Make instances classes package-private

Added Cofree.cata/cataM

Add docs

Add test for forceTail

Remove comments

tailEval=>tail, bracket doc references

Add unfoldM
  • Loading branch information
edmundnoble committed Dec 21, 2016
1 parent 73a6481 commit e5e7ab8
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 11 deletions.
162 changes: 162 additions & 0 deletions free/src/main/scala/cats/free/Cofree.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package cats
package free

/**
* A free comonad for some branching functor `S`. Branching is done lazily using [[Eval]].
* A tree with data at the branches, as opposed to [[Free]] which is a tree with data at the leaves.
* Not an instruction set functor made into a program monad as in [[Free]], but an instruction set's outputs as a
* functor made into a tree of the possible worlds reachable using the instruction set.
*
* This Scala implementation of `Cofree` and its usages are derived from
* [[https://github.com/scalaz/scalaz/blob/series/7.3.x/core/src/main/scala/scalaz/Cofree.scala Scalaz's Cofree]],
* originally written by Rúnar Bjarnason.
*/
final case class Cofree[S[_], A](head: A, tail: Eval[S[Cofree[S, A]]]) {

/** Evaluates and returns the tail of the computation. */
def tailForced: S[Cofree[S, A]] = tail.value

/** Applies `f` to the head and `g` to the tail. */
def transform[B](f: A => B, g: Cofree[S, A] => Cofree[S, B])(implicit S: Functor[S]): Cofree[S, B] =
Cofree[S, B](f(head), tail.map(S.map(_)(g)))

/** Map over head and inner `S[_]` branches. */
def map[B](f: A => B)(implicit S: Functor[S]): Cofree[S, B] =
transform(f, _.map(f))

/** Transform the branching functor at the root of the Cofree tree. */
def mapBranchingRoot(nat: S ~> S)(implicit S: Functor[S]): Cofree[S, A] =
Cofree[S, A](head, tail.map(nat(_)))

/** Transform the branching functor, using the S functor to perform the recursion. */
def mapBranchingS[T[_]](nat: S ~> T)(implicit S: Functor[S]): Cofree[T, A] =
Cofree[T, A](head, tail.map(v => nat(S.map(v)(_.mapBranchingS(nat)))))

/** Transform the branching functor, using the T functor to perform the recursion. */
def mapBranchingT[T[_]](nat: S ~> T)(implicit T: Functor[T]): Cofree[T, A] =
Cofree[T, A](head, tail.map(v => T.map(nat(v))(_.mapBranchingT(nat))))

/** Map `f` over each subtree of the computation. */
def coflatMap[B](f: Cofree[S, A] => B)(implicit S: Functor[S]): Cofree[S, B] =
Cofree[S, B](f(this), tail.map(S.map(_)(_.coflatMap(f))))

/** Replace each node in the computation with the subtree from that node downwards */
def coflatten(implicit S: Functor[S]): Cofree[S, Cofree[S, A]] =
Cofree[S, Cofree[S, A]](this, tail.map(S.map(_)(_.coflatten)))

/** Alias for head. */
def extract: A = head

/** Evaluate just the tail. */
def forceTail: Cofree[S, A] =
Cofree[S, A](head, Eval.now(tail.value))

/** Evaluate the entire Cofree tree. */
def forceAll(implicit S: Functor[S]): Cofree[S, A] =
Cofree[S, A](head, Eval.now(tail.map(S.map(_)(_.forceAll)).value))

}

object Cofree extends CofreeInstances {

/** Cofree anamorphism, lazily evaluated. */
def unfold[F[_], A](a: A)(f: A => F[A])(implicit F: Functor[F]): Cofree[F, A] =
Cofree[F, A](a, Eval.later(F.map(f(a))(unfold(_)(f))))

/** Cofree monadic anamorphism, lazily evaluated. */
def unfoldM[F[_], M[_], A](a: A)(f: A => M[F[A]])(implicit F: Traverse[F], M: Monad[M]): M[Cofree[F, A]] = {
M.flatMap(f(a)) { (fa: F[A]) =>
val looped: M[F[Cofree[F, A]]] = F.traverse(fa)(unfoldM[F, M, A](_)(f))
val rolled: M[Cofree[F, A]] = M.map(looped) { (fcf: F[Cofree[F, A]]) =>
Cofree[F, A](a, Eval.now(fcf))
}
rolled
}
}

/**
* A stack-safe algebraic recursive fold out of the cofree comonad.
*/
def cata[F[_], A, B](cof: Cofree[F, A])(folder: (A, F[B]) => Eval[B])(implicit F: Traverse[F]): Eval[B] =
F.traverse(cof.tailForced)(cata(_)(folder)).flatMap(folder(cof.head, _))

/**
* A monadic recursive fold out of the cofree comonad into a monad which can express Eval's stack-safety.
*/
def cataM[F[_], M[_], A, B](cof: Cofree[F, A])(folder: (A, F[B]) => M[B])(inclusion: Eval ~> M)(implicit F: Traverse[F], M: Monad[M]): M[B] = {
def loop(fr: Cofree[F, A]): Eval[M[B]] = {
val looped: M[F[B]] = F.traverse[M, Cofree[F, A], B](fr.tailForced)(fr => M.flatten(inclusion(Eval.defer(loop(fr)))))
val folded: M[B] = M.flatMap(looped)(fb => folder(fr.head, fb))
Eval.now(folded)
}
M.flatten(inclusion(loop(cof)))
}

}

sealed private[free] abstract class CofreeInstances2 {
implicit def catsReducibleForCofree[F[_] : Foldable]: Reducible[Cofree[F, ?]] =
new CofreeReducible[F] {
def F = implicitly
}
}

sealed private[free] abstract class CofreeInstances1 extends CofreeInstances2 {
implicit def catsTraverseForCofree[F[_] : Traverse]: Traverse[Cofree[F, ?]] =
new CofreeTraverse[F] {
def F = implicitly
}
}

sealed private[free] abstract class CofreeInstances extends CofreeInstances1 {
implicit def catsFreeComonadForCofree[S[_] : Functor]: Comonad[Cofree[S, ?]] = new CofreeComonad[S] {
def F = implicitly
}
}

private trait CofreeComonad[S[_]] extends Comonad[Cofree[S, ?]] {
implicit def F: Functor[S]

override final def extract[A](p: Cofree[S, A]): A = p.extract

override final def coflatMap[A, B](a: Cofree[S, A])(f: Cofree[S, A] => B): Cofree[S, B] = a.coflatMap(f)

override final def coflatten[A](a: Cofree[S, A]): Cofree[S, Cofree[S, A]] = a.coflatten

override final def map[A, B](a: Cofree[S, A])(f: A => B): Cofree[S, B] = a.map(f)
}

private trait CofreeReducible[F[_]] extends Reducible[Cofree[F, ?]] {
implicit def F: Foldable[F]

override final def foldMap[A, B](fa: Cofree[F, A])(f: A => B)(implicit M: Monoid[B]): B =
M.combine(f(fa.head), F.foldMap(fa.tailForced)(foldMap(_)(f)))

override final def foldRight[A, B](fa: Cofree[F, A], z: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
f(fa.head, fa.tail.flatMap(F.foldRight(_, z)(foldRight(_, _)(f))))

override final def foldLeft[A, B](fa: Cofree[F, A], z: B)(f: (B, A) => B): B =
F.foldLeft(fa.tailForced, f(z, fa.head))((b, cof) => foldLeft(cof, b)(f))

override final def reduceLeftTo[A, B](fa: Cofree[F, A])(z: A => B)(f: (B, A) => B): B =
F.foldLeft(fa.tailForced, z(fa.head))((b, cof) => foldLeft(cof, b)(f))

override def reduceRightTo[A, B](fa: Cofree[F, A])(z: A => B)(f: (A, Eval[B]) => Eval[B]): Eval[B] = {
foldRight(fa, Eval.now((None: Option[B]))) {
case (l, e) => e.flatMap {
case None => Eval.now(Some(z(l)))
case Some(r) => f(l, Eval.now(r)).map(Some(_))
}
}.map(_.getOrElse(sys.error("reduceRightTo")))
}

}

private trait CofreeTraverse[F[_]] extends Traverse[Cofree[F, ?]] with CofreeReducible[F] with CofreeComonad[F] {
implicit def F: Traverse[F]

override final def traverse[G[_], A, B](fa: Cofree[F, A])(f: A => G[B])(implicit G: Applicative[G]): G[Cofree[F, B]] =
G.map2(f(fa.head), F.traverse(fa.tailForced)(traverse(_)(f)))((h, t) => Cofree[F, B](h, Eval.now(t)))

}

183 changes: 183 additions & 0 deletions free/src/test/scala/cats/free/CofreeTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package cats
package free

import cats.data.{NonEmptyList, OptionT}
import cats.laws.discipline.{CartesianTests, ComonadTests, ReducibleTests, SerializableTests, TraverseTests}
import cats.syntax.list._
import cats.tests.{CatsSuite, Spooky}
import org.scalacheck.{Arbitrary, Cogen, Gen}

class CofreeTests extends CatsSuite {

import CofreeTests._

implicit val iso = CartesianTests.Isomorphisms.invariant[Cofree[Option, ?]]

checkAll("Cofree[Option, ?]", ComonadTests[Cofree[Option, ?]].comonad[Int, Int, Int])
locally {
implicit val instance = Cofree.catsTraverseForCofree[Option]
checkAll("Cofree[Option, ?]", TraverseTests[Cofree[Option, ?]].traverse[Int, Int, Int, Int, Option, Option])
checkAll("Traverse[Cofree[Option, ?]]", SerializableTests.serializable(Traverse[Cofree[Option, ?]]))
}
locally {
implicit val instance = Cofree.catsReducibleForCofree[Option]
checkAll("Cofree[Option, ?]", ReducibleTests[Cofree[Option, ?]].reducible[Option, Int, Int])
checkAll("Reducible[Cofree[Option, ?]]", SerializableTests.serializable(Reducible[Cofree[Option, ?]]))
}
checkAll("Comonad[Cofree[Option, ?]]", SerializableTests.serializable(Comonad[Cofree[Option, ?]]))

test("Cofree.unfold") {
val unfoldedHundred: CofreeNel[Int] = Cofree.unfold[Option, Int](0)(i => if (i == 100) None else Some(i + 1))
val nelUnfoldedHundred: NonEmptyList[Int] = NonEmptyList.fromListUnsafe(List.tabulate(101)(identity))
cofNelToNel(unfoldedHundred) should ===(nelUnfoldedHundred)
}

test("Cofree.unfoldM") {
val unfoldedHundred: Option[CofreeNel[Int]] =
Cofree.unfoldM[Option, Option, Int](0)(i => if (i == 100) Some(None) else Some(Some(i + 1)))
val unfoldedNone =
Cofree.unfoldM[Option, Option, Int](0)(i => if (i == 100) None else Some(Some(i + 1)))
val nelUnfoldedHundred: NonEmptyList[Int] =
NonEmptyList.fromListUnsafe(List.tabulate(101)(identity))
unfoldedHundred.map(cofNelToNel(_)) should ===(Some(nelUnfoldedHundred))
unfoldedNone should ===(None)
}

test("Cofree.tailForced") {
val spooky = new Spooky
val incrementor =
Cofree.unfold[Id, Int](spooky.counter) { _ => spooky.increment(); spooky.counter }
spooky.counter should ===(0)
incrementor.tailForced
spooky.counter should ===(1)
}

test("Cofree.forceTail") {
val spooky = new Spooky
val incrementor =
Cofree.unfold[Id, Int](spooky.counter) { _ => spooky.increment(); spooky.counter }
spooky.counter should ===(0)
incrementor.forceTail
spooky.counter should ===(1)
}

test("Cofree.forceAll") {
val spooky = new Spooky
val incrementor =
Cofree.unfold[Option, Int](spooky.counter)(i =>
if (i == 5) {
None
} else {
spooky.increment()
Some(spooky.counter)
})
spooky.counter should ===(0)
incrementor.forceAll
spooky.counter should ===(5)
}

test("Cofree.mapBranchingRoot") {
val unfoldedHundred: CofreeNel[Int] = Cofree.unfold[Option, Int](0)(i => if (i == 100) None else Some(i + 1))
val withNoneRoot = unfoldedHundred.mapBranchingRoot(new (Option ~> Option) {
override def apply[A](opt: Option[A]): Option[A] = None
})
val nelUnfoldedOne: NonEmptyList[Int] = NonEmptyList.of(0)
cofNelToNel(withNoneRoot) should ===(nelUnfoldedOne)
}

val unfoldedHundred: Cofree[Option, Int] = Cofree.unfold[Option, Int](0)(i => if (i == 100) None else Some(i + 1))
test("Cofree.mapBranchingS/T") {
val toList = new (Option ~> List) {
override def apply[A](lst: Option[A]): List[A] = lst.fold[List[A]](Nil)(_ :: Nil)
}
val toNelS = unfoldedHundred.mapBranchingS(toList)
val toNelT = unfoldedHundred.mapBranchingT(toList)
val nelUnfoldedOne: NonEmptyList[Int] = NonEmptyList.fromListUnsafe(List.tabulate(101)(identity))
cofRoseTreeToNel(toNelS) should ===(nelUnfoldedOne)
cofRoseTreeToNel(toNelT) should ===(nelUnfoldedOne)
}

val nelUnfoldedHundred: NonEmptyList[Int] = NonEmptyList.fromListUnsafe(List.tabulate(101)(identity))

test("Cofree.cata") {
val cata =
Cofree.cata[Option, Int, NonEmptyList[Int]](unfoldedHundred)(
(i, lb) => Eval.now(NonEmptyList(i, lb.fold[List[Int]](Nil)(_.toList)))
).value
cata should ===(nelUnfoldedHundred)
}

test("Cofree.cataM") {

type EvalOption[A] = OptionT[Eval, A]

val folder: (Int, Option[NonEmptyList[Int]]) => EvalOption[NonEmptyList[Int]] =
(i, lb) => if (i > 100) OptionT.none else OptionT.some(NonEmptyList(i, lb.fold[List[Int]](Nil)(_.toList)))
val inclusion = new (Eval ~> EvalOption) {
override def apply[A](fa: Eval[A]): EvalOption[A] = OptionT.liftF(fa)
}

val cataHundred =
Cofree.cataM[Option, EvalOption, Int, NonEmptyList[Int]](unfoldedHundred)(folder)(inclusion).value.value
val cataHundredOne =
Cofree.cataM[Option, EvalOption, Int, NonEmptyList[Int]](
Cofree[Option, Int](101, Eval.now(Some(unfoldedHundred)))
)(folder)(inclusion).value.value
cataHundred should ===(Some(nelUnfoldedHundred))
cataHundredOne should ===(None)
}

}

object CofreeTests extends CofreeTestsInstances

sealed trait CofreeTestsInstances {

type CofreeNel[A] = Cofree[Option, A]
type CofreeRoseTree[A] = Cofree[List, A]

implicit def cofNelEq[A](implicit e: Eq[A]): Eq[CofreeNel[A]] = new Eq[CofreeNel[A]] {
override def eqv(a: CofreeNel[A], b: CofreeNel[A]): Boolean = {
def tr(a: CofreeNel[A], b: CofreeNel[A]): Boolean =
(a.tailForced, b.tailForced) match {
case (Some(at), Some(bt)) if e.eqv(a.head, b.head) => tr(at, bt)
case (None, None) if e.eqv(a.head, b.head) => true
case _ => false
}
tr(a, b)
}
}


implicit def CofreeOptionCogen[A: Cogen]: Cogen[CofreeNel[A]] =
implicitly[Cogen[List[A]]].contramap[CofreeNel[A]](cofNelToNel(_).toList)

implicit def CofreeOptionArb[A: Arbitrary]: Arbitrary[CofreeNel[A]] = {
val arb = Arbitrary {
Gen.resize(20, Gen.nonEmptyListOf(implicitly[Arbitrary[A]].arbitrary))
}
Arbitrary {
arb.arbitrary.map(l => (l.head, l.tail) match {
case (h, Nil) => nelToCofNel(NonEmptyList(h, Nil))
case (h, t) => nelToCofNel(NonEmptyList(h, t))
})
}
}

val nelToCofNel = new (NonEmptyList ~> CofreeNel) {
override def apply[A](fa: NonEmptyList[A]): CofreeNel[A] =
Cofree[Option, A](fa.head, Eval.later(fa.tail.toNel.map(apply)))
}

val cofNelToNel = new (CofreeNel ~> NonEmptyList) {
override def apply[A](fa: CofreeNel[A]): NonEmptyList[A] =
NonEmptyList[A](fa.head, fa.tailForced.fold[List[A]](Nil)(apply(_).toList))
}

val cofRoseTreeToNel = new (CofreeRoseTree ~> NonEmptyList) {
override def apply[A](fa: CofreeRoseTree[A]): NonEmptyList[A] =
NonEmptyList[A](fa.head, fa.tailForced.flatMap(apply(_).toList))
}


}
11 changes: 0 additions & 11 deletions tests/src/test/scala/cats/tests/EvalTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,6 @@ import cats.laws.discipline.arbitrary._
import cats.kernel.laws.{GroupLaws, OrderLaws}

class EvalTests extends CatsSuite {

/**
* Class for spooky side-effects and action-at-a-distance.
*
* It is basically a mutable counter that can be used to measure how
* many times an otherwise pure function is being evaluted.
*/
class Spooky(var counter: Int = 0) {
def increment(): Unit = counter += 1
}

/**
* This method creates a Eval[A] instance (along with a
* corresponding Spooky instance) from an initial `value` using the
Expand Down
13 changes: 13 additions & 0 deletions tests/src/test/scala/cats/tests/Spooky.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package cats
package tests

/**
* Class for spooky side-effects and action-at-a-distance.
*
* It is basically a mutable counter that can be used to measure how
* many times an otherwise pure function is being evaluted.
*/
class Spooky(var counter: Int = 0) {
def increment(): Unit = counter += 1
}

0 comments on commit e5e7ab8

Please sign in to comment.