Skip to content

Commit

Permalink
Quotes type printing: take infix type modifier into account (#21726)
Browse files Browse the repository at this point in the history
This is similar to how the regular compiler `.show` handles `infix` but
using explicit parens everywhere to not have to reimplement the
precedence logic (maybe quote type printing should just use `.show`
eventually).
  • Loading branch information
smarter authored Oct 8, 2024
2 parents 2023c5d + 936c009 commit 35c7d74
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 2 deletions.
15 changes: 13 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1150,8 +1150,19 @@ object SourceCode {
case tp: TypeRef if tp.typeSymbol == Symbol.requiredClass("scala.<repeated>") =>
this += "_*"
case _ =>
printType(tp)
inSquare(printTypesOrBounds(args, ", "))
if !fullNames && args.lengthCompare(2) == 0 && tp.typeSymbol.flags.is(Flags.Infix) then
val lhs = args(0)
val rhs = args(1)
this += "("
printType(lhs)
this += " "
printType(tp)
this += " "
printType(rhs)
this += ")"
else
printType(tp)
inSquare(printTypesOrBounds(args, ", "))
}

case AnnotatedType(tp, annot) =>
Expand Down
12 changes: 12 additions & 0 deletions tests/run-macros/type-print.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
List[Int]
scala.collection.immutable.List[scala.Int]
scala.collection.immutable.List[scala.Int]
AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "immutable")), "List"), List(TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int")))
(3 + (a * b))
scala.compiletime.ops.int.+[3, scala.compiletime.ops.int.*[a, b]]
scala.compiletime.ops.int.+[3, scala.compiletime.ops.int.*[a, b]]
AppliedType(TypeRef(TermRef(TermRef(TermRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "compiletime"), "ops"), "int"), "+"), List(ConstantType(IntConstant(3)), AppliedType(TypeRef(TermRef(TermRef(TermRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "compiletime"), "ops"), "int"), "*"), List(TermRef(NoPrefix(), "a"), TermRef(NoPrefix(), "b")))))
((3 + a) * b)
scala.compiletime.ops.int.*[scala.compiletime.ops.int.+[3, a], b]
scala.compiletime.ops.int.*[scala.compiletime.ops.int.+[3, a], b]
AppliedType(TypeRef(TermRef(TermRef(TermRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "compiletime"), "ops"), "int"), "*"), List(AppliedType(TypeRef(TermRef(TermRef(TermRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "compiletime"), "ops"), "int"), "+"), List(ConstantType(IntConstant(3)), TermRef(NoPrefix(), "a"))), TermRef(NoPrefix(), "b")))
29 changes: 29 additions & 0 deletions tests/run-macros/type-print/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import scala.quoted.*

inline def printTypeShort[T]: String =
${ printTypeShortImpl[T] }

inline def printType[T]: String =
${ printTypeImpl[T] }

inline def printTypeAnsi[T]: String =
${ printTypeAnsiImpl[T] }

inline def printTypeStructure[T]: String =
${ printTypeStructureImpl[T] }

def printTypeShortImpl[T: Type](using Quotes): Expr[String] =
import quotes.reflect.*
Expr(Printer.TypeReprShortCode.show(TypeRepr.of[T]))

def printTypeImpl[T: Type](using Quotes): Expr[String] =
import quotes.reflect.*
Expr(Printer.TypeReprCode.show(TypeRepr.of[T]))

def printTypeAnsiImpl[T: Type](using Quotes): Expr[String] =
import quotes.reflect.*
Expr(Printer.TypeReprAnsiCode.show(TypeRepr.of[T]))

def printTypeStructureImpl[T: Type](using Quotes): Expr[String] =
import quotes.reflect.*
Expr(Printer.TypeReprStructure.show(TypeRepr.of[T]))
15 changes: 15 additions & 0 deletions tests/run-macros/type-print/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.compiletime.ops.int.*

inline def printAll[T]: Unit =
println(printTypeShort[T])
println(printType[T])
println(printTypeAnsi[T])
println(printTypeStructure[T])

@main
def Test: Unit =
printAll[List[Int]]
val a = 1
val b = 2
printAll[3 + a.type * b.type]
printAll[(3 + a.type) * b.type]

0 comments on commit 35c7d74

Please sign in to comment.