diff --git a/core/src/main/scala-2/cats/syntax/MonadOps.scala b/core/src/main/scala-2/cats/syntax/MonadOps.scala index 04897be36c..cc2b782378 100644 --- a/core/src/main/scala-2/cats/syntax/MonadOps.scala +++ b/core/src/main/scala-2/cats/syntax/MonadOps.scala @@ -30,4 +30,6 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal { def untilM_(p: F[Boolean])(implicit M: Monad[F]): F[Unit] = M.untilM_(fa)(p) def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p) def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) + def flatMapOrKeep[A1 >: A](pfa: PartialFunction[A, F[A1]])(implicit M: Monad[F]): F[A1] = + M.flatMapOrKeep[A, A1](fa)(pfa) } diff --git a/core/src/main/scala-3/cats/syntax/MonadOps.scala b/core/src/main/scala-3/cats/syntax/MonadOps.scala index ef3e285bf4..7924b01b34 100644 --- a/core/src/main/scala-3/cats/syntax/MonadOps.scala +++ b/core/src/main/scala-3/cats/syntax/MonadOps.scala @@ -30,4 +30,6 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal { def untilM_(using M: Monad[F])(p: F[Boolean]): F[Unit] = M.untilM_(fa)(p) def iterateWhile(using M: Monad[F])(p: A => Boolean): F[A] = M.iterateWhile(fa)(p) def iterateUntil(using M: Monad[F])(p: A => Boolean): F[A] = M.iterateUntil(fa)(p) + def flatMapOrKeep[A1 >: A](using M: Monad[F])(pfa: PartialFunction[A, F[A1]]): F[A1] = + M.flatMapOrKeep[A, A1](fa)(pfa) } diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index c212fc7038..c694f854e3 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -161,6 +161,20 @@ trait Monad[F[_]] extends FlatMap[F] with Applicative[F] { tailRecM(branches.toList)(step) } + + /** + * Modifies the `A` value in `F[A]` with the supplied function, if the function is defined for the value. + * Example: + * {{{ + * scala> import cats.syntax.all._ + * + * scala> List(1, 2, 3).flatMapOrKeep{ case 2 => List(2, 22, 222) } + * res0: List[Int] = List(1, 2, 22, 222, 3) + * }}} + */ + def flatMapOrKeep[A, A1 >: A](fa: F[A])(pfa: PartialFunction[A, F[A1]]): F[A1] = + flatMap(fa)(a => pfa.applyOrElse(a, pure[A1])) + } object Monad { diff --git a/laws/src/main/scala/cats/laws/MonadLaws.scala b/laws/src/main/scala/cats/laws/MonadLaws.scala index 3ed6ad15ed..38ca928247 100644 --- a/laws/src/main/scala/cats/laws/MonadLaws.scala +++ b/laws/src/main/scala/cats/laws/MonadLaws.scala @@ -57,6 +57,12 @@ trait MonadLaws[F[_]] extends ApplicativeLaws[F] with FlatMapLaws[F] { def mapFlatMapCoherence[A, B](fa: F[A], f: A => B): IsEq[F[B]] = fa.flatMap(a => F.pure(f(a))) <-> fa.map(f) + /** + * Make sure that flatMapOrKeep and flatMap are consistent. + */ + def flatMapOrKeepToFlatMapCoherence[A, A1 >: A](fa: F[A], pfa: PartialFunction[A, F[A1]]): IsEq[F[A1]] = + F.flatMapOrKeep[A, A1](fa)(pfa) <-> F.flatMap(fa)(a => pfa.applyOrElse(a, F.pure[A1])) + lazy val tailRecMStackSafety: IsEq[F[Int]] = { val n = 50000 val res = F.tailRecM(0)(i => F.pure(if (i < n) Either.left(i + 1) else Either.right(i))) diff --git a/laws/src/main/scala/cats/laws/discipline/MonadTests.scala b/laws/src/main/scala/cats/laws/discipline/MonadTests.scala index bd7a2451e5..0bb727bd0d 100644 --- a/laws/src/main/scala/cats/laws/discipline/MonadTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/MonadTests.scala @@ -55,7 +55,8 @@ trait MonadTests[F[_]] extends ApplicativeTests[F] with FlatMapTests[F] { Seq( "monad left identity" -> forAll(laws.monadLeftIdentity[A, B] _), "monad right identity" -> forAll(laws.monadRightIdentity[A] _), - "map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _) + "map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _), + "flatMapOrKeep flatMap coherence" -> forAll(laws.flatMapOrKeepToFlatMapCoherence[A, A] _) ) ++ (if (Platform.isJvm) Seq[(String, Prop)]("tailRecM stack safety" -> Prop.lzy(laws.tailRecMStackSafety)) else Seq.empty) } @@ -84,7 +85,8 @@ trait MonadTests[F[_]] extends ApplicativeTests[F] with FlatMapTests[F] { Seq( "monad left identity" -> forAll(laws.monadLeftIdentity[A, B] _), "monad right identity" -> forAll(laws.monadRightIdentity[A] _), - "map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _) + "map flatMap coherence" -> forAll(laws.mapFlatMapCoherence[A, B] _), + "flatMapOrKeep flatMap coherence" -> forAll(laws.flatMapOrKeepToFlatMapCoherence[A, A] _) ) } }