Skip to content

Commit

Permalink
backported #3041, added tailrec instance for StacksafeMonad and Defer (
Browse files Browse the repository at this point in the history
  • Loading branch information
gagandeepkalra authored Mar 11, 2020
1 parent 8227961 commit 8e50b78
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 10 deletions.
19 changes: 9 additions & 10 deletions bench/src/main/scala/cats/bench/TrampolineBench.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package cats.bench

import org.openjdk.jmh.annotations.{Benchmark, Scope, State}

import cats._
import cats.implicits._
import cats.free.Trampoline

import scala.util.control.TailCalls

@State(Scope.Benchmark)
class TrampolineBench {

Expand All @@ -30,14 +31,12 @@ class TrampolineBench {
y <- Trampoline.defer(trampolineFib(n - 2))
} yield x + y

// TailRec[A] only has .flatMap in 2.11.
@Benchmark
def stdlib(): Int = stdlibFib(N).result

// @Benchmark
// def stdlib(): Int = stdlibFib(N).result
//
// def stdlibFib(n: Int): TailCalls.TailRec[Int] =
// if (n < 2) TailCalls.done(n) else for {
// x <- TailCalls.tailcall(stdlibFib(n - 1))
// y <- TailCalls.tailcall(stdlibFib(n - 2))
// } yield x + y
def stdlibFib(n: Int): TailCalls.TailRec[Int] =
if (n < 2) TailCalls.done(n) else for {
x <- TailCalls.tailcall(stdlibFib(n - 1))
y <- TailCalls.tailcall(stdlibFib(n - 2))
} yield x + y
}
1 change: 1 addition & 0 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 {
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
def extract[A](la: Eval[A]): A = la.value
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
override def unit: Eval[Unit] = Eval.Unit
}

implicit val catsDeferForEval: Defer[Eval] =
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/cats/instances/all.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ trait AllInstancesBinCompat7
with VectorInstancesBinCompat1
with EitherInstancesBinCompat0
with StreamInstancesBinCompat1
with TailRecInstances
with SortedSetInstancesBinCompat2
1 change: 1 addition & 0 deletions core/src/main/scala/cats/instances/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ package object instances {
with SortedSetInstancesBinCompat2
object stream extends StreamInstances with StreamInstancesBinCompat0 with StreamInstancesBinCompat1
object string extends StringInstances
object tailRec extends TailRecInstances
object try_ extends TryInstances
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
object unit extends UnitInstances
Expand Down
26 changes: 26 additions & 0 deletions core/src/main/scala/cats/instances/tailrec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cats.instances

import cats.{Defer, StackSafeMonad}
import scala.util.control.TailCalls.{done, tailcall, TailRec}

trait TailRecInstances {
implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
TailRecInstances.catsInstancesForTailRec
}

private object TailRecInstances {
val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
new StackSafeMonad[TailRec] with Defer[TailRec] {
def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)

def pure[A](a: A): TailRec[A] = done(a)

override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] =
fa.map(f)

def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] =
fa.flatMap(f)

override val unit: TailRec[Unit] = done(())
}
}
29 changes: 29 additions & 0 deletions tests/src/test/scala/cats/tests/TailRecSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package cats.tests

import cats.{Defer, Eq, Monad}
import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests}
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.{Arbitrary, Cogen, Gen}

import scala.util.control.TailCalls.{done, tailcall, TailRec}

class TailRecSuite extends CatsSuite {

implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] =
Arbitrary(
Gen.frequency(
(3, arbitrary[A].map(done)),
(1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })),
(1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) }))
)
)

implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] =
Eq.by[TailRec[A], A](_.result)

checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int])
checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec]))

checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int])
checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec]))
}

0 comments on commit 8e50b78

Please sign in to comment.