Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HK quoted pattern type variables #16980

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

/** Creates the tuple type tree repesentation of the type trees in `ts` */
/** Creates the tuple type tree representation of the type trees in `ts` */
def tupleTypeTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
if arity <= Definitions.MaxTupleArity then
Expand All @@ -1509,10 +1509,14 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else nestedPairsTypeTree(elems)
}

/** Creates the nested pairs type tree repesentation of the type trees in `ts` */
/** Creates the nested pairs type tree representation of the type trees in `ts` */
def nestedPairsTypeTree(ts: List[Tree])(using Context): Tree =
ts.foldRight[Tree](TypeTree(defn.EmptyTupleModule.termRef))((x, acc) => AppliedTypeTree(TypeTree(defn.PairClass.typeRef), x :: acc :: Nil))

/** Creates the nested higher-kinded pairs type tree representation of the type trees in `ts` */
def hkNestedPairsTypeTree(ts: List[Tree])(using Context): Tree =
ts.foldRight[Tree](TypeTree(defn.QuoteMatching_KNil.typeRef))((x, acc) => AppliedTypeTree(TypeTree(defn.QuoteMatching_KCons.typeRef), x :: acc :: Nil))

/** Replaces all positions in `tree` with zero-extent positions */
private def focusPositions(tree: Tree)(using Context): Tree = {
val transformer = new tpd.TreeMap {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,9 @@ class Definitions {
@tu lazy val QuoteMatching_ExprMatchModule: Symbol = QuoteMatchingClass.requiredClass("ExprMatchModule")
@tu lazy val QuoteMatching_TypeMatch: Symbol = QuoteMatchingClass.requiredMethod("TypeMatch")
@tu lazy val QuoteMatching_TypeMatchModule: Symbol = QuoteMatchingClass.requiredClass("TypeMatchModule")
@tu lazy val QuoteMatchingModule: Symbol = requiredModule("scala.quoted.runtime.QuoteMatching")
@tu lazy val QuoteMatching_KNil: Symbol = QuoteMatchingModule.requiredType("KNil")
@tu lazy val QuoteMatching_KCons: Symbol = QuoteMatchingModule.requiredType("KCons")

@tu lazy val ToExprModule: Symbol = requiredModule("scala.quoted.ToExpr")
@tu lazy val ToExprModule_BooleanToExpr: Symbol = ToExprModule.requiredMethod("BooleanToExpr")
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ trait QuotesAndSplices {
*
* ```
* case scala.internal.quoted.Expr.unapply[
* Tuple1[t @ _], // Type binging definition
* KList[t @ _, KNil], // Type binging definition
* Tuple2[Type[t], Expr[List[t]]] // Typing the result of the pattern match
* ](
* Tuple2.unapply
Expand Down Expand Up @@ -411,7 +411,7 @@ trait QuotesAndSplices {
val replaceBindings = new ReplaceBindings
val patType = defn.tupleType(splices.tpes.map(tpe => replaceBindings(tpe.widen)))

val typeBindingsTuple = tpd.tupleTypeTree(typeBindings.values.toList)
val typeBindingsTuple = tpd.hkNestedPairsTypeTree(typeBindings.values.toList)

val replaceBindingsInTree = new TreeMap {
private var bindMap = Map.empty[Symbol, Symbol]
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ object QuoteMatcher {

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)

def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Tuple] =
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] =
given Env = Map.empty
scrutineeTerm =?= patternTerm
scrutineeTree =?= patternTree

/** Check that all trees match with `mtch` and concatenate the results with &&& */
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3093,14 +3093,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
new TypeImpl(tree, SpliceScope.getCurrent).asInstanceOf[scala.quoted.Type[T]]

object ExprMatch extends ExprMatchModule:
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: scala.quoted.Expr[Any])(using pattern: scala.quoted.Expr[Any]): Option[Tup] =
def unapply[TypeBindings, Tup <: Tuple](scrutinee: scala.quoted.Expr[Any])(using pattern: scala.quoted.Expr[Any]): Option[Tup] =
val scrutineeTree = reflect.asTerm(scrutinee)
val patternTree = reflect.asTerm(pattern)
treeMatch(scrutineeTree, patternTree).asInstanceOf[Option[Tup]]
end ExprMatch

object TypeMatch extends TypeMatchModule:
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: scala.quoted.Type[?])(using pattern: scala.quoted.Type[?]): Option[Tup] =
def unapply[TypeBindings, Tup <: Tuple](scrutinee: scala.quoted.Type[?])(using pattern: scala.quoted.Type[?]): Option[Tup] =
val scrutineeTree = reflect.TypeTree.of(using scrutinee)
val patternTree = reflect.TypeTree.of(using pattern)
treeMatch(scrutineeTree, patternTree).asInstanceOf[Option[Tup]]
Expand Down
11 changes: 8 additions & 3 deletions library/src/scala/quoted/runtime/QuoteMatching.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ trait QuoteMatching:
* - `ExprMatch.unapply('{ f(0, myInt) })('{ f(patternHole[Int], patternHole[Int]) }, _)`
* will return `Some(Tuple2('{0}, '{ myInt }))`
* - `ExprMatch.unapply('{ f(0, "abc") })('{ f(0, patternHole[Int]) }, _)`
* will return `None` due to the missmatch of types in the hole
* will return `None` due to the mismatch of types in the hole
*
* Holes:
* - scala.quoted.runtime.Patterns.patternHole[T]: hole that matches an expression `x` of type `Expr[U]`
Expand All @@ -27,7 +27,7 @@ trait QuoteMatching:
* @param pattern `Expr[Any]` containing the pattern tree
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
*/
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: Expr[Any])(using pattern: Expr[Any]): Option[Tup]
def unapply[TypeBindings, Tup <: Tuple](scrutinee: Expr[Any])(using pattern: Expr[Any]): Option[Tup]
}

val TypeMatch: TypeMatchModule
Expand All @@ -40,5 +40,10 @@ trait QuoteMatching:
* @param pattern `Type[?]` containing the pattern tree
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Type[Ti]``
*/
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutinee: Type[?])(using pattern: Type[?]): Option[Tup]
def unapply[TypeBindings, Tup <: Tuple](scrutinee: Type[?])(using pattern: Type[?]): Option[Tup]
}

object QuoteMatching:
type KList
type KCons[+H <: AnyKind, +T <: KList] <: KList
type KNil <: KList
17 changes: 17 additions & 0 deletions tests/pos-macros/hk-quoted-type-patterns/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import scala.quoted._

private def impl(x: Expr[Any])(using Quotes): Expr[Unit] = {
x match
case '{ foo[x] } =>
assert(Type.show[x] == "scala.Int", Type.show[x])
case '{ type f[X]; foo[`f`] } =>
assert(Type.show[f] == "[A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A]", Type.show[f])
case '{ type f <: AnyKind; foo[`f`] } =>
assert(Type.show[f] == "[K >: scala.Nothing <: scala.Any, V >: scala.Nothing <: scala.Any] => scala.collection.immutable.Map[K, V]", Type.show[f])
case x => throw MatchError(x.show)
'{}
}

inline def test(inline x: Any): Unit = ${ impl('x) }

def foo[T <: AnyKind]: Any = ???
5 changes: 5 additions & 0 deletions tests/pos-macros/hk-quoted-type-patterns/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@main
def Test =
test(foo[Int])
test(foo[List])
test(foo[Map])