Skip to content

Commit

Permalink
Fix HK quoted pattern type variables (scala#16980)
Browse files Browse the repository at this point in the history
The issue was in the encoding into
`{ExprMatchModule,TypeMatchModule}.unapply`. Specifically with the
`TypeBindings` argument. This arguments holds the list of type variable
definitions (`tpd.Bind` trees). We used a `Tuple` to list all the types
inside. The problem is that higher-kinded type variables do not conform
with the upper bounds of the tuple elements. The solution is to use an
HList with any-kinded elements.

Backport of scala#16907
  • Loading branch information
Kordyjan authored Feb 21, 2023
2 parents b3c1c98 + 014be6f commit aecbfa7
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 11 deletions.
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

/** Creates the tuple type tree repesentation of the type trees in `ts` */
/** Creates the tuple type tree representation of the type trees in `ts` */
def tupleTypeTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
if arity <= Definitions.MaxTupleArity then
Expand All @@ -1509,10 +1509,14 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else nestedPairsTypeTree(elems)
}

/** Creates the nested pairs type tree repesentation of the type trees in `ts` */
/** Creates the nested pairs type tree representation of the type trees in `ts` */
def nestedPairsTypeTree(ts: List[Tree])(using Context): Tree =
ts.foldRight[Tree](TypeTree(defn.EmptyTupleModule.termRef))((x, acc) => AppliedTypeTree(TypeTree(defn.PairClass.typeRef), x :: acc :: Nil))

/** Creates the nested higher-kinded pairs type tree representation of the type trees in `ts` */
def hkNestedPairsTypeTree(ts: List[Tree])(using Context): Tree =
ts.foldRight[Tree](TypeTree(defn.QuoteMatching_KNil.typeRef))((x, acc) => AppliedTypeTree(TypeTree(defn.QuoteMatching_KCons.typeRef), x :: acc :: Nil))

/** Replaces all positions in `tree` with zero-extent positions */
private def focusPositions(tree: Tree)(using Context): Tree = {
val transformer = new tpd.TreeMap {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,9 @@ class Definitions {
@tu lazy val QuoteMatching_ExprMatchModule: Symbol = QuoteMatchingClass.requiredClass("ExprMatchModule")
@tu lazy val QuoteMatching_TypeMatch: Symbol = QuoteMatchingClass.requiredMethod("TypeMatch")
@tu lazy val QuoteMatching_TypeMatchModule: Symbol = QuoteMatchingClass.requiredClass("TypeMatchModule")
@tu lazy val QuoteMatchingModule: Symbol = requiredModule("scala.quoted.runtime.QuoteMatching")
@tu lazy val QuoteMatching_KNil: Symbol = QuoteMatchingModule.requiredType("KNil")
@tu lazy val QuoteMatching_KCons: Symbol = QuoteMatchingModule.requiredType("KCons")

@tu lazy val ToExprModule: Symbol = requiredModule("scala.quoted.ToExpr")
@tu lazy val ToExprModule_BooleanToExpr: Symbol = ToExprModule.requiredMethod("BooleanToExpr")
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ trait QuotesAndSplices {
*
* ```
* case scala.internal.quoted.Expr.unapply[
* Tuple1[t @ _], // Type binging definition
* KList[t @ _, KNil], // Type binging definition
* Tuple2[Type[t], Expr[List[t]]] // Typing the result of the pattern match
* ](
* Tuple2.unapply
Expand Down Expand Up @@ -411,7 +411,7 @@ trait QuotesAndSplices {
val replaceBindings = new ReplaceBindings
val patType = defn.tupleType(splices.tpes.map(tpe => replaceBindings(tpe.widen)))

val typeBindingsTuple = tpd.tupleTypeTree(typeBindings.values.toList)
val typeBindingsTuple = tpd.hkNestedPairsTypeTree(typeBindings.values.toList)

val replaceBindingsInTree = new TreeMap {
private var bindMap = Map.empty[Symbol, Symbol]
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ object QuoteMatcher {

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)

def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Tuple] =
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] =
given Env = Map.empty
scrutineeTerm =?= patternTerm
scrutineeTree =?= patternTree

/** Check that all trees match with `mtch` and concatenate the results with &&& */
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3093,14 +3093,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
new TypeImpl(tree, SpliceScope.getCurrent).asInstanceOf[scala.quoted.Type[T]]

object ExprMatch extends ExprMatchModule:
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: scala.quoted.Expr[Any])(using pattern: scala.quoted.Expr[Any]): Option[Tup] =
def unapply[TypeBindings, Tup <: Tuple](scrutinee: scala.quoted.Expr[Any])(using pattern: scala.quoted.Expr[Any]): Option[Tup] =
val scrutineeTree = reflect.asTerm(scrutinee)
val patternTree = reflect.asTerm(pattern)
treeMatch(scrutineeTree, patternTree).asInstanceOf[Option[Tup]]
end ExprMatch

object TypeMatch extends TypeMatchModule:
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: scala.quoted.Type[?])(using pattern: scala.quoted.Type[?]): Option[Tup] =
def unapply[TypeBindings, Tup <: Tuple](scrutinee: scala.quoted.Type[?])(using pattern: scala.quoted.Type[?]): Option[Tup] =
val scrutineeTree = reflect.TypeTree.of(using scrutinee)
val patternTree = reflect.TypeTree.of(using pattern)
treeMatch(scrutineeTree, patternTree).asInstanceOf[Option[Tup]]
Expand Down
11 changes: 8 additions & 3 deletions library/src/scala/quoted/runtime/QuoteMatching.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ trait QuoteMatching:
* - `ExprMatch.unapply('{ f(0, myInt) })('{ f(patternHole[Int], patternHole[Int]) }, _)`
* will return `Some(Tuple2('{0}, '{ myInt }))`
* - `ExprMatch.unapply('{ f(0, "abc") })('{ f(0, patternHole[Int]) }, _)`
* will return `None` due to the missmatch of types in the hole
* will return `None` due to the mismatch of types in the hole
*
* Holes:
* - scala.quoted.runtime.Patterns.patternHole[T]: hole that matches an expression `x` of type `Expr[U]`
Expand All @@ -27,7 +27,7 @@ trait QuoteMatching:
* @param pattern `Expr[Any]` containing the pattern tree
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
*/
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: Expr[Any])(using pattern: Expr[Any]): Option[Tup]
def unapply[TypeBindings, Tup <: Tuple](scrutinee: Expr[Any])(using pattern: Expr[Any]): Option[Tup]
}

val TypeMatch: TypeMatchModule
Expand All @@ -40,5 +40,10 @@ trait QuoteMatching:
* @param pattern `Type[?]` containing the pattern tree
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Type[Ti]``
*/
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: Type[?])(using pattern: Type[?]): Option[Tup]
def unapply[TypeBindings, Tup <: Tuple](scrutinee: Type[?])(using pattern: Type[?]): Option[Tup]
}

object QuoteMatching:
type KList
type KCons[+H <: AnyKind, +T <: KList] <: KList
type KNil <: KList
17 changes: 17 additions & 0 deletions tests/pos-macros/hk-quoted-type-patterns/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import scala.quoted._

private def impl(x: Expr[Any])(using Quotes): Expr[Unit] = {
x match
case '{ foo[x] } =>
assert(Type.show[x] == "scala.Int", Type.show[x])
case '{ type f[X]; foo[`f`] } =>
assert(Type.show[f] == "[A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A]", Type.show[f])
case '{ type f <: AnyKind; foo[`f`] } =>
assert(Type.show[f] == "[K >: scala.Nothing <: scala.Any, V >: scala.Nothing <: scala.Any] => scala.collection.immutable.Map[K, V]", Type.show[f])
case x => throw MatchError(x.show)
'{}
}

inline def test(inline x: Any): Unit = ${ impl('x) }

def foo[T <: AnyKind]: Any = ???
5 changes: 5 additions & 0 deletions tests/pos-macros/hk-quoted-type-patterns/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@main
def Test =
test(foo[Int])
test(foo[List])
test(foo[Map])

0 comments on commit aecbfa7

Please sign in to comment.