Skip to content

Commit

Permalink
Help implement Metals' infer expected type feature (scala#21390)
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Aug 22, 2024
2 parents afcb0ad + 43fc10c commit d490d13
Show file tree
Hide file tree
Showing 15 changed files with 666 additions and 61 deletions.
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3268,9 +3268,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

/** The trace of comparison operations when performing `op` */
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean)(using Context): String =
val cmp = explainingTypeComparer(short)
inSubComparer(cmp)(op)
cmp.lastTrace(header)
explaining(cmp => { op(cmp); cmp.lastTrace(header) }, short)

def explaining[T](op: ExplainingTypeComparer => T, short: Boolean)(using Context): T =
inSubComparer(explainingTypeComparer(short))(op)

def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
inSubComparer(matchReducer)(op)
Expand Down Expand Up @@ -3440,6 +3441,9 @@ object TypeComparer {
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
comparing(_.explained(op, header, short))

def explaining[T](op: ExplainingTypeComparer => T, short: Boolean = false)(using Context): T =
comparing(_.explaining(op, short))

def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
comparing(_.reduceMatchWith(op))

Expand Down Expand Up @@ -3871,7 +3875,7 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa
override def recur(tp1: Type, tp2: Type): Boolean =
def moreInfo =
if Config.verboseExplainSubtype || ctx.settings.verbose.value
then s" ${tp1.getClass} ${tp2.getClass}"
then s" ${tp1.className} ${tp2.className}"
else ""
val approx = approxState
def approxStr = if short then "" else approx.show
Expand Down
19 changes: 15 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,22 @@ object TypeOps:
val hiBound = instantiate(bounds.hi, skolemizedArgTypes)
val loBound = instantiate(bounds.lo, skolemizedArgTypes)

def check(using Context) = {
if (!(lo <:< hiBound)) violations += ((arg, "upper", hiBound))
if (!(loBound <:< hi)) violations += ((arg, "lower", loBound))
def check(tp1: Type, tp2: Type, which: String, bound: Type)(using Context) = {
val isSub = TypeComparer.explaining { cmp =>
val isSub = cmp.isSubType(tp1, tp2)
if !isSub then
if !ctx.typerState.constraint.domainLambdas.isEmpty then
typr.println(i"${ctx.typerState.constraint}")
if !ctx.gadt.symbols.isEmpty then
typr.println(i"${ctx.gadt}")
typr.println(cmp.lastTrace(i"checkOverlapsBounds($lo, $hi, $arg, $bounds)($which)"))
//trace.dumpStack()
isSub
}//(using ctx.fresh.setSetting(ctx.settings.verbose, true)) // uncomment to enable moreInfo in ExplainingTypeComparer recur
if !isSub then violations += ((arg, which, bound))
}
check(using checkCtx)
check(lo, hiBound, "upper", hiBound)(using checkCtx)
check(loBound, hi, "lower", loBound)(using checkCtx)
}

def loop(args: List[Tree], boundss: List[TypeBounds]): Unit = args match
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/trace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ object trace extends TraceSyntax:
object log extends TraceSyntax:
inline def isEnabled: true = true
protected val isForced = false

def dumpStack(limit: Int = -1): Unit = {
val out = Console.out
val exc = new Exception("Dump Stack")
var stack = exc.getStackTrace
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.TraceSyntax"))
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.trace"))
if limit >= 0 then
stack = stack.take(limit)
exc.setStackTrace(stack)
exc.printStackTrace(out)
}
end trace

/** This module is carefully optimized to give zero overhead if Config.tracingEnabled
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ trait Applications extends Compatibility {
fail(TypeMismatch(methType.resultType, resultType, None))

// match all arguments with corresponding formal parameters
matchArgs(orderedArgs, methType.paramInfos, 0)
if success then matchArgs(orderedArgs, methType.paramInfos, 0)
case _ =>
if (methType.isError) ok = false
else fail(em"$methString does not take parameters")
Expand Down Expand Up @@ -666,7 +666,7 @@ trait Applications extends Compatibility {
* @param n The position of the first parameter in formals in `methType`.
*/
def matchArgs(args: List[Arg], formals: List[Type], n: Int): Unit =
if (success) formals match {
formals match {
case formal :: formals1 =>

def checkNoVarArg(arg: Arg) =
Expand Down Expand Up @@ -878,7 +878,9 @@ trait Applications extends Compatibility {
init()

def addArg(arg: Tree, formal: Type): Unit =
typedArgBuf += adapt(arg, formal.widenExpr)
val typedArg = adapt(arg, formal.widenExpr)
typedArgBuf += typedArg
ok = ok & !typedArg.tpe.isError

def makeVarArg(n: Int, elemFormal: Type): Unit = {
val args = typedArgBuf.takeRight(n).toList
Expand Down Expand Up @@ -943,7 +945,7 @@ trait Applications extends Compatibility {
var typedArgs = typedArgBuf.toList
def app0 = cpy.Apply(app)(normalizedFun, typedArgs) // needs to be a `def` because typedArgs can change later
val app1 =
if (!success || typedArgs.exists(_.tpe.isError)) app0.withType(UnspecifiedErrorType)
if !success then app0.withType(UnspecifiedErrorType)
else {
if isJavaAnnotConstr(methRef.symbol) then
// #19951 Make sure all arguments are NamedArgs for Java annotations
Expand Down
55 changes: 35 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,12 @@ object Inferencing {
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
instDecision(tvar, variance, minimizeSelected, force.ifBottom) match
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
case Decision.Fail => fail = true
case Decision.ToMax => toMaximize ::= tvar
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
Expand Down Expand Up @@ -452,9 +439,32 @@ object Inferencing {
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
val approxAbove =
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
//println(i"instDirection($param) = $approxAbove - $approxBelow original=[$original] constrained=[$constrained]")
approxAbove - approxBelow
}

/** The instantiation decision for given poly param computed from the constraint. */
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
private def instDecision(tvar: TypeVar, v: Int, minimizeSelected: Boolean, ifBottom: IfBottom)(using Context): Decision =
import Decision.*
val direction = instDirection(tvar.origin)
val dec = if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then Min
else if direction >= 0 && tvar.hasUpperBound then Max
else Skip
else if direction != 0 then if direction < 0 then Min else Max
else if tvar.hasLowerBound then if v >= 0 then Min else ToMax
else ifBottom match
// What's left are unconstrained tvars with at most a non-Any param upperbound:
// * IfBottom.flip will always maximise to the param upperbound, for all variances
// * IfBottom.fail will fail the IFD check, for covariant or invariant tvars, maximise contravariant tvars
// * IfBottom.ok will minimise to Nothing covariant and unbounded invariant tvars, and max to Any the others
case IfBottom.ok => if v > 0 || v == 0 && !tvar.hasUpperBound then Min else ToMax // prefer upper bound if one is given
case IfBottom.fail => if v >= 0 then Fail else ToMax
case ifBottom_flip => ToMax
//println(i"instDecision($tvar, v=v, minimizedSelected=$minimizeSelected, $ifBottom) dir=$direction = $dec")
dec

/** Following type aliases and stripping refinements and annotations, if one arrives at a
* class type reference where the class has a companion module, a reference to
* that companion module. Otherwise NoType
Expand Down Expand Up @@ -651,7 +661,7 @@ trait Inferencing { this: Typer =>

val ownedVars = state.ownedVars
if (ownedVars ne locked) && !ownedVars.isEmpty then
val qualifying = ownedVars -- locked
val qualifying = (ownedVars -- locked).toList
if (!qualifying.isEmpty) {
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
val resultAlreadyConstrained =
Expand Down Expand Up @@ -687,6 +697,10 @@ trait Inferencing { this: Typer =>

def constraint = state.constraint

trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying\n$constraint\n${ctx.gadt}") {
//println(i"$constraint")
//println(i"${ctx.gadt}")

/** Values of this type report type variables to instantiate with variance indication:
* +1 variable appears covariantly, can be instantiated from lower bound
* -1 variable appears contravariantly, can be instantiated from upper bound
Expand Down Expand Up @@ -804,6 +818,7 @@ trait Inferencing { this: Typer =>
end doInstantiate

doInstantiate(filterByDeps(toInstantiate))
}
}
end if
tree
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ import config.Printers.typr
import Inferencing.*
import ErrorReporting.*
import util.SourceFile
import util.Spans.{NoSpan, Span}
import TypeComparer.necessarySubType
import reporting.*

import scala.annotation.internal.sharable
import dotty.tools.dotc.util.Spans.{NoSpan, Span}

object ProtoTypes {

Expand Down Expand Up @@ -83,6 +84,7 @@ object ProtoTypes {
* fits the given expected result type.
*/
def constrainResult(mt: Type, pt: Type)(using Context): Boolean =
trace(i"constrainResult($mt, $pt)", typr):
val savedConstraint = ctx.typerState.constraint
val res = pt.widenExpr match {
case pt: FunProto =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/util/Signatures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ object Signatures {
*
* @param err The error message to inspect.
* @param params The parameters that were given at the call site.
* @param alreadyCurried Index of paramss we are currently in.
* @param paramssIndex Index of paramss we are currently in.
*
* @return A pair composed of the index of the best alternative (0 if no alternatives
* were found), and the list of alternatives.
Expand Down
49 changes: 23 additions & 26 deletions compiler/test/dotty/tools/dotc/typer/InstantiateModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,16 @@ package typer

// Modelling the decision in IsFullyDefined
object InstantiateModel:
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
enum Var { case V; case NotV }; import Var.*
enum MSe { case M; case NotM }; import MSe.*
enum Bot { case Fail; case Ok; case Flip }; import Bot.*
enum Act { case Min; case Max; case ToMax; case Skip; case False }; import Act.*
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }; import Decision.*

// NN/AA = Nothing/Any
// LL/UU = the original bounds, on the type parameter
// L1/U1 = the constrained bounds, on the type variable
// V = variance >= 0 ("non-contravariant")
// MSe = minimisedSelected
// Bot = IfBottom
// ToMax = delayed maximisation, via addition to toMaximize
// Skip = minimisedSelected "hold off instantiating"
// False = return false
// Fail = IfBottom.fail's bail option

// there are 9 combinations:
// # | LB | UB | d | // d = direction
Expand All @@ -34,24 +28,27 @@ object InstantiateModel:
// 8 | NN | UU | 0 | T <: UU
// 9 | NN | AA | 0 | T

def decide(lb: LB, ub: UB, v: Var, bot: Bot, m: MSe): Act = (lb, ub) match
def instDecision(lb: LB, ub: UB, v: Int, ifBottom: IfBottom, min: Boolean) = (lb, ub) match
case (L1, AA) => Min
case (L1, UU) => Min
case (LL, U1) => Max
case (NN, U1) => Max

case (L1, U1) => if m==M || v==V then Min else ToMax
case (LL, UU) => if m==M || v==V then Min else ToMax
case (LL, AA) => if m==M || v==V then Min else ToMax

case (NN, UU) => bot match
case _ if m==M => Max
//case Ok if v==V => Min // removed, i14218 fix
case Fail if v==V => False
case _ => ToMax

case (NN, AA) => bot match
case _ if m==M => Skip
case Ok if v==V => Min
case Fail if v==V => False
case _ => ToMax
case (L1, U1) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, UU) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, AA) => if min then Min else pickVar(v, Min, Min, ToMax)

case (NN, UU) => ifBottom match
case _ if min => Max
case IfBottom.ok => pickVar(v, Min, ToMax, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => ToMax

case (NN, AA) => ifBottom match
case _ if min => Skip
case IfBottom.ok => pickVar(v, Min, Min, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => ToMax

def pickVar[A](v: Int, cov: A, inv: A, con: A) =
if v > 0 then cov else if v == 0 then inv else con
Loading

0 comments on commit d490d13

Please sign in to comment.