diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 65150e3a2a..4f8b6f8b0c 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -11,26 +11,46 @@ import cats.syntax.either._ */ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable { - def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => + def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] = + StateT.applyF(F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => fas(a).run(s) } - }) + } + }) - def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => - F.map(faf(a))((s, _)) - } + def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] = + StateT.applyF(F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } } - ) + }) - def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = + def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } + def map2[B, Z](sb: StateT[F, S, B])(fn: (A, B) => Z)(implicit F: FlatMap[F]): StateT[F, S, Z] = + StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) => + ssa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => + F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) } + } + } + }) + + def map2Eval[B, Z](sb: Eval[StateT[F, S, B]])(fn: (A, B) => Z)(implicit F: FlatMap[F]): Eval[StateT[F, S, Z]] = + F.map2Eval(runF, sb.map(_.runF)) { (ssa, ssb) => + ssa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => + F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) } + } + } + }.map(StateT.applyF) + + def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] = + map2(sb)((_, _)) + /** * Run with the provided initial state value */ @@ -69,10 +89,13 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable /** * Like [[map]], but also allows the state (`S`) value to be modified. */ - def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] = - transformF { fsa => - F.map(fsa){ case (s, a) => f(s, a) } - } + def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] = + StateT.applyF( + F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.map(fsa) { case (s, a) => f(s, a) } + } + }) /** * Like [[transform]], but allows the context to change from `F` to `G`. @@ -98,31 +121,31 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable * res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0)) * }}} */ - def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] = - StateT { r => - F.flatMap(runF) { ff => + def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] = + StateT.applyF(F.map(runF) { sfsa => + { r: R => val s = f(r) - val nextState = ff(s) - F.map(nextState) { case (s, a) => (g(r, s), a) } + val fsa = sfsa(s) + F.map(fsa) { case (s, a) => (g(r, s), a) } } - } + }) /** * Modify the state (`S`) component. */ - def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] = + def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] = transform((s, a) => (f(s), a)) /** * Inspect a value from the input state, without modifying the state. */ - def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] = + def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] = transform((s, _) => (s, f(s))) /** * Get the input state, without modifying the state. */ - def get(implicit F: Monad[F]): StateT[F, S, S] = + def get(implicit F: Functor[F]): StateT[F, S, S] = inspect(identity) } @@ -182,11 +205,16 @@ private[data] sealed trait StateTInstances2 extends StateTInstances3 { new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 } } -private[data] sealed trait StateTInstances3 { +private[data] sealed trait StateTInstances3 extends StateTInstances4 { implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] = new StateTMonad[F, S] { implicit def F = F0 } } +private[data] sealed trait StateTInstances4 { + implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] = + new StateTFunctor[F, S] { implicit def F = F0 } +} + // To workaround SI-7139 `object State` needs to be defined inside the package object // together with the type alias. private[data] abstract class StateFunctions { @@ -220,6 +248,12 @@ private[data] abstract class StateFunctions { def set[S](s: S): State[S, Unit] = State(_ => (s, ())) } +private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] { + implicit def F: Functor[F] + + def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) +} + private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { implicit def F: Monad[F] @@ -229,8 +263,20 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] = fa.flatMap(f) + override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] = + ff.map2(fa) { case (f, a) => f(a) } + override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) + override def map2[A, B, Z](fa: StateT[F, S, A], fb: StateT[F, S, B])(fn: (A, B) => Z): StateT[F, S, Z] = + fa.map2(fb)(fn) + + override def map2Eval[A, B, Z](fa: StateT[F, S, A], fb: Eval[StateT[F, S, B]])(fn: (A, B) => Z): Eval[StateT[F, S, Z]] = + fa.map2Eval(fb)(fn) + + override def product[A, B](fa: StateT[F, S, A], fb: StateT[F, S, B]): StateT[F, S, (A, B)] = + fa.product(fb) + def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] = StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) { case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 8d1e876a92..28bb8f2403 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -30,7 +30,7 @@ class StateTTests extends CatsSuite { } test("State.get and StateT.get are consistent") { - forAll{ (s: String) => + forAll{ (s: String) => val state: State[String, String] = State.get val stateT: State[String, String] = StateT.get state.run(s) should === (stateT.run(s)) @@ -195,7 +195,25 @@ class StateTTests extends CatsSuite { } - implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataMonadForStateT(ListWrapper.monad)) + implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataFunctorForStateT(ListWrapper.monad)) + + { + // F has a Functor + implicit val F: Functor[ListWrapper] = ListWrapper.monad + // We only need a Functor on F to find a Functor on StateT + Functor[StateT[ListWrapper, Int, ?]] + } + + { + // F needs a Monad to do Eq on StateT + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Functor[StateT[ListWrapper, Int, ?]] = StateT.catsDataFunctorForStateT + + checkAll("StateT[ListWrapper, Int, Int]", FunctorTests[StateT[ListWrapper, Int, ?]].functor[Int, Int, Int]) + checkAll("Functor[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[StateT[ListWrapper, Int, ?]])) + + Functor[StateT[ListWrapper, Int, ?]] + } { // F has a Monad @@ -265,7 +283,7 @@ class StateTTests extends CatsSuite { // F has a MonadError implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int] - + checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int , ?], Unit].monadError[Int, Int, Int]) checkAll("MonadError[StateT[Option, Int , ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit])) }