Skip to content

Commit

Permalink
Merge pull request #4602 from jozic/monad.flatmaporkeep
Browse files Browse the repository at this point in the history
Add `flatMapOrKeep` to `Monad`
  • Loading branch information
satorg authored May 25, 2024
2 parents cff7d75 + 8a94478 commit a1fab57
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
2 changes: 2 additions & 0 deletions core/src/main/scala-2/cats/syntax/MonadOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal {
def untilM_(p: F[Boolean])(implicit M: Monad[F]): F[Unit] = M.untilM_(fa)(p)
def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p)
def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p)
def flatMapOrKeep[A1 >: A](pfa: PartialFunction[A, F[A1]])(implicit M: Monad[F]): F[A1] =
M.flatMapOrKeep[A, A1](fa)(pfa)
}
2 changes: 2 additions & 0 deletions core/src/main/scala-3/cats/syntax/MonadOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal {
def untilM_(using M: Monad[F])(p: F[Boolean]): F[Unit] = M.untilM_(fa)(p)
def iterateWhile(using M: Monad[F])(p: A => Boolean): F[A] = M.iterateWhile(fa)(p)
def iterateUntil(using M: Monad[F])(p: A => Boolean): F[A] = M.iterateUntil(fa)(p)
def flatMapOrKeep[A1 >: A](using M: Monad[F])(pfa: PartialFunction[A, F[A1]]): F[A1] =
M.flatMapOrKeep[A, A1](fa)(pfa)
}
14 changes: 14 additions & 0 deletions core/src/main/scala/cats/Monad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ trait Monad[F[_]] extends FlatMap[F] with Applicative[F] {

tailRecM(branches.toList)(step)
}

/**
* Modifies the `A` value in `F[A]` with the supplied function, if the function is defined for the value.
* Example:
* {{{
* scala> import cats.syntax.all._
*
* scala> List(1, 2, 3).flatMapOrKeep{ case 2 => List(2, 22, 222) }
* res0: List[Int] = List(1, 2, 22, 222, 3)
* }}}
*/
def flatMapOrKeep[A, A1 >: A](fa: F[A])(pfa: PartialFunction[A, F[A1]]): F[A1] =
flatMap(fa)(a => pfa.applyOrElse(a, pure[A1]))

}

object Monad {
Expand Down
6 changes: 6 additions & 0 deletions laws/src/main/scala/cats/laws/MonadLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ trait MonadLaws[F[_]] extends ApplicativeLaws[F] with FlatMapLaws[F] {
def mapFlatMapCoherence[A, B](fa: F[A], f: A => B): IsEq[F[B]] =
fa.flatMap(a => F.pure(f(a))) <-> fa.map(f)

/**
* Make sure that flatMapOrKeep and flatMap are consistent.
*/
def flatMapOrKeepToFlatMapCoherence[A, A1 >: A](fa: F[A], pfa: PartialFunction[A, F[A1]]): IsEq[F[A1]] =
F.flatMapOrKeep[A, A1](fa)(pfa) <-> F.flatMap(fa)(a => pfa.applyOrElse(a, F.pure[A1]))

lazy val tailRecMStackSafety: IsEq[F[Int]] = {
val n = 50000
val res = F.tailRecM(0)(i => F.pure(if (i < n) Either.left(i + 1) else Either.right(i)))
Expand Down
6 changes: 4 additions & 2 deletions laws/src/main/scala/cats/laws/discipline/MonadTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ trait MonadTests[F[_]] extends ApplicativeTests[F] with FlatMapTests[F] {
Seq(
"monad left identity" -> forAll(laws.monadLeftIdentity[A, B] _),
"monad right identity" -> forAll(laws.monadRightIdentity[A] _),
"map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _)
"map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _),
"flatMapOrKeep flatMap coherence" -> forAll(laws.flatMapOrKeepToFlatMapCoherence[A, A] _)
) ++ (if (Platform.isJvm) Seq[(String, Prop)]("tailRecM stack safety" -> Prop.lzy(laws.tailRecMStackSafety))
else Seq.empty)
}
Expand Down Expand Up @@ -84,7 +85,8 @@ trait MonadTests[F[_]] extends ApplicativeTests[F] with FlatMapTests[F] {
Seq(
"monad left identity" -> forAll(laws.monadLeftIdentity[A, B] _),
"monad right identity" -> forAll(laws.monadRightIdentity[A] _),
"map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _)
"map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _),
"flatMapOrKeep flatMap coherence" -> forAll(laws.flatMapOrKeepToFlatMapCoherence[A, A] _)
)
}
}
Expand Down

0 comments on commit a1fab57

Please sign in to comment.