diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 117a28842d..79a62c73c8 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -106,6 +106,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) private def onLeftParen(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Replacement = { val rt = ft.right @@ -137,7 +138,15 @@ class RedundantBraces(implicit val ftoks: FormatTokens) case _ => None } - lpFunction.orElse(lpPartialFunction).orNull + val repl = lpFunction.orElse(lpPartialFunction).orNull + (rtOwner match { + case ac: Term.ArgClause + if repl != null && repl.ft.right.is[Token.LeftBrace] => + session.rule[RemoveScala3OptionalBraces].flatMap { r => + Option(r.onLeftForArgClause(ac, None)) + } + case _ => None + }).getOrElse(repl) } private def onRightParen( @@ -168,7 +177,8 @@ class RedundantBraces(implicit val ftoks: FormatTokens) owner match { case t: Term.FunctionTerm if t.tokens.last.is[Token.RightBrace] => if (!okToRemoveFunctionInApplyOrInit(t)) null - else if (okToReplaceFunctionInSingleArgApply(t)) replaceWithLeftParen + else if (okToReplaceFunctionInSingleArgApply(t)) + handleFuncInSingleArgApply(t) else removeToken case t: Term.PartialFunction if t.parent.exists { p => SingleArgInBraces.orBlock(p).contains(t) && @@ -179,7 +189,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) t.parent match { case Some(f: Term.FunctionTerm) if okToReplaceFunctionInSingleArgApply(f) => - replaceWithLeftParen + handleFuncInSingleArgApply(f) case Some(_: Term.Interpolate) => handleInterpolation case _ => if (processBlock(t)) removeToken else null @@ -553,4 +563,20 @@ class RedundantBraces(implicit val ftoks: FormatTokens) ftoks(t.name.tokens.head, -1).left.is[Token.Dot] } + private def handleFuncInSingleArgApply( + f: Term.FunctionTerm + )(implicit ft: FormatToken, session: Session): Replacement = + f.parent match { + case Some(ac: Term.ArgClause) if { + val acFt = ftoks.tokenJustBefore(ac) + acFt.right.is[Token.LeftParen] && + session.claimedRule(acFt).exists { x => + x.ft.right.is[Token.Colon] && + x.how == ReplacementType.Replace + } + } => + removeToken + case _ => replaceWithLeftParen + } + } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala index 2cafb8c028..57f1a5b7cc 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala @@ -1,5 +1,9 @@ package org.scalafmt.rewrite +import scala.annotation.tailrec +import scala.collection.mutable +import scala.reflect.ClassTag + import scala.meta._ import scala.meta.tokens.Token @@ -47,6 +51,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) if (t.parent.exists(_.is[Defn.Given])) removeToken else replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start)) + case t: Term.ArgClause => onLeftForArgClause(t, Some(false)) + case t: Term.PartialFunction => + t.parent match { + case Some(p: Term.ArgClause) if (p.tokens.head match { + case px: Token.LeftBrace => px eq x + case px: Token.LeftParen => + shouldRewriteArgClauseWithLeftParen[RedundantBraces](px) + case _ => false + }) => + onLeftForArgClause(p, sharesBracesWithArg = Some(true)) + case _ => null + } case _: Term.For if allowOldSyntax || { val rbFt = ftoks(ftoks.matching(ft.right)) ftoks.nextNonComment(rbFt).right.is[Token.KwDo] @@ -83,6 +99,8 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) }) ft.right match { case _ if notOkToRewrite => None + case _: Token.RightParen => + Some((left, removeToken)) case x: Token.RightBrace => val replacement = ft.meta.rightOwner match { case _: Term.For if allowOldSyntax && !nextFt.right.is[Token.KwDo] => @@ -94,9 +112,11 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) } } - private def onLeftForBlock( - tree: Term.Block - )(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = + private def onLeftForBlock(tree: Term.Block)(implicit + ft: FormatToken, + session: Session, + style: ScalafmtConfig + ): Replacement = tree.parent.fold(null: Replacement) { case t: Term.If => val ok = ftoks.prevNonComment(ft).left match { @@ -134,7 +154,73 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) else if (ftoks.prevNonComment(ft).left.is[Token.Equals]) removeToken else null case p: Tree.WithBody => if (p.body eq tree) removeToken else null + case p: Term.ArgClause => + p.tokens.head match { + case px: Token.LeftBrace => + onLeftForArgClause(p, Some(px eq ft.right)) + case px: Token.LeftParen + if shouldRewriteArgClauseWithLeftParen[RedundantParens](px) => + onLeftForArgClause(p, Some(true)) + case _ => null + } case _ => null } + private def shouldRewriteArgClauseWithLeftParen[A <: Rule]( + lp: Token + )(implicit ft: FormatToken, session: Session, tag: ClassTag[A]) = { + val prevFt = ftoks.prevNonComment(ft) + prevFt.left.eq(lp) && session + .claimedRule(prevFt.meta.idx - 1) + .exists(x => tag.runtimeClass.isInstance(x.rule)) + } + + def onLeftForArgClause( + tree: Term.ArgClause, + sharesBracesWithArg: Option[Boolean] + )(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = { + val rob = style.rewrite.scala3.removeOptionalBraces + val maxStats = rob.fewerBracesMaxStats + if (maxStats == 0) return null + @tailrec + def checkCountInRange(queue: mutable.Queue[Tree], cnt: Int): Boolean = + if (cnt > maxStats) false + else if (queue.isEmpty) cnt >= rob.fewerBracesMinStats + else { + val next = queue.dequeue() + def enqueue(trees: Iterable[Tree]*): Int = { + val len = queue.length + trees.foreach(queue ++= _) + queue.length - len + } + val delta = next match { + case x: Term.Block => enqueue(x.stats) + case x: Tree.WithBody => enqueue(x.body :: Nil) - 1 + case x: Tree.WithCases => enqueue(x.cases) + case x: Tree.WithEnums => enqueue(x.enums) + case x: Stat.WithTemplate => enqueue(x.templ.early, x.templ.stats) + case _ => 0 + } + checkCountInRange(queue, cnt + delta) + } + def ok = tree.values match { + case arg :: Nil => + val queue = mutable.Queue.empty[Tree] + queue += arg + val ignoreArg = sharesBracesWithArg + .getOrElse(arg.tokens.headOption.contains(tree.tokens.head)) + checkCountInRange(queue, if (ignoreArg) 0 else 1) + case _ => false + } + tree.parent match { + case Some(t: Term.Apply) if (t.parent match { + case Some(pp: Term.Apply) if pp.fun eq t => false + case _ => style.dialect.allowFewerBraces && ok + }) => + val x = ft.right // `{` or `(` + replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start)) + case _ => null + } + } + } diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat index 333a0b8f69..257f0a987b 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat @@ -1827,10 +1827,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1861,10 +1860,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1902,10 +1900,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1943,10 +1940,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1977,10 +1973,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2018,10 +2013,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2059,11 +2053,10 @@ foo bar match case x => x + 1 } - .mtd2 { - bar match - case x => x + 1 - case y => y + 1 - } + .mtd2: + bar match + case x => x + 1 + case y => y + 1 .mtd3 { bar match case x => x + 1 @@ -2106,12 +2099,11 @@ foo def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x = diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat index fe036468f5..a868064bff 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat @@ -1598,10 +1598,10 @@ foo x + 3 } >>> -foo.mtd1 { x + 1 }.mtd2 { - x + 1 - x + 2 -}.mtd3 { +foo.mtd1 { x + 1 }.mtd2: + x + 1 + x + 2 +.mtd3 { x + 1 x + 2 x + 3 @@ -1627,10 +1627,10 @@ foo x + 3 } >>> -foo.mtd1 { x => x + 1 }.mtd2 { x => +foo.mtd1 { x => x + 1 }.mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1663,10 +1663,10 @@ foo } ) >>> -foo.mtd1(x => x + 1).mtd2 { x => +foo.mtd1(x => x + 1).mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1699,10 +1699,10 @@ foo } ) >>> -foo.mtd1(x => x + 1).mtd2 { x => +foo.mtd1(x => x + 1).mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1728,10 +1728,10 @@ foo case z => z + 1 } >>> -foo.mtd1 { case x => x + 1 }.mtd2 { - case x => x + 1 - case y => y + 1 -}.mtd3 { +foo.mtd1 { case x => x + 1 }.mtd2: + case x => x + 1 + case y => y + 1 +.mtd3 { case x => x + 1 case y => y + 1 case z => z + 1 @@ -1764,10 +1764,10 @@ foo } ) >>> -foo.mtd1 { case x => x + 1 }.mtd2 { - case x => x + 1 - case y => y + 1 -}.mtd3 { +foo.mtd1 { case x => x + 1 }.mtd2: + case x => x + 1 + case y => y + 1 +.mtd3 { case x => x + 1 case y => y + 1 case z => z + 1 @@ -1802,11 +1802,11 @@ foo foo.mtd1 { bar match case x => x + 1 -}.mtd2 { - bar match - case x => x + 1 - case y => y + 1 -}.mtd3 { +}.mtd2: + bar match + case x => x + 1 + case y => y + 1 +.mtd3 { bar match case x => x + 1 case y => y + 1 @@ -1845,12 +1845,12 @@ foo foo.mtd1 { x + 1 def x = x + 3 -}.mtd2 { - x + 1 - def x = - x + 3 - x + 4 -}.mtd3 { +}.mtd2: + x + 1 + def x = + x + 3 + x + 4 +.mtd3 { x + 1 def x = x + 3 diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat index 48c3ac3d5f..24d4bb7490 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat @@ -1795,10 +1795,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1829,10 +1828,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1871,11 +1869,9 @@ foo x => x + 1 } - .mtd2 { - x => - x + 1 - x + 2 - } + .mtd2: x => + x + 1 + x + 2 .mtd3 { x => x + 1 @@ -1914,10 +1910,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1948,10 +1943,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -1989,10 +1983,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2030,11 +2023,10 @@ foo bar match case x => x + 1 } - .mtd2 { - bar match - case x => x + 1 - case y => y + 1 - } + .mtd2: + bar match + case x => x + 1 + case y => y + 1 .mtd3 { bar match case x => x + 1 @@ -2077,12 +2069,11 @@ foo def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x = diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat index 44d6a53b4b..2d8f8ced51 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat @@ -1818,10 +1818,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1852,10 +1851,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1893,10 +1891,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1934,10 +1931,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1968,12 +1964,11 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + case x => + x + 1 + case y => + y + 1 .mtd3 { case x => x + 1 @@ -2014,12 +2009,11 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + case x => + x + 1 + case y => + y + 1 .mtd3 { case x => x + 1 @@ -2061,13 +2055,12 @@ foo case x => x + 1 } - .mtd2 { - bar match - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + bar match + case x => + x + 1 + case y => + y + 1 .mtd3 { bar match case x => @@ -2112,12 +2105,11 @@ foo x + 1 def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x =