Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow wildcard types in constraints #12703

Merged
merged 14 commits into from
Jun 9, 2021
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ object Config {
*/
inline val checkConstraintsPropagated = false

/** Check that constraint bounds do not contain wildcard types */
inline val checkNoWildcardsInConstraint = false

/** If a constraint is over a type lambda `tl` and `tvar` is one of
* the type variables associated with `tl` in the constraint, check
* that the origin of `tvar` is a parameter of `tl`.
Expand Down
98 changes: 18 additions & 80 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Flags._
import config.Config
import config.Printers.typr
import reporting.trace
import typer.ProtoTypes.newTypeVar
import StdNames.tpnme

/** Methods for adding constraints and solving them.
Expand Down Expand Up @@ -78,22 +79,29 @@ trait ConstraintHandling {
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))

protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean =
/** If true, eliminate wildcards in bounds by avoidance, otherwise replace
* them by fresh variables.
*/
protected def approximateWildcards: Boolean = true

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
if !constraint.contains(param) then true
else if !isUpper && param.occursIn(bound) then
else if !isUpper && param.occursIn(rawBound) then
// We don't allow recursive lower bounds when defining a type,
// so we shouldn't allow them as constraints either.
false
else
val dropWildcards = new AvoidWildcardsMap:
if !isUpper then variance = -1
override def mapWild(t: WildcardType) =
if approximateWildcards then super.mapWild(t)
else newTypeVar(apply(t.effectiveBounds).toBounds)
val bound = dropWildcards(rawBound)
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
val equalBounds = (if isUpper then lo else hi) eq bound
if equalBounds
&& !bound.existsPart(bp => bp.isInstanceOf[WildcardType] || (bp eq param))
then
// The narrowed bounds are equal and do not contain wildcards,
if equalBounds && !bound.existsPart(_ eq param, stopAtStatic = true) then
// The narrowed bounds are equal and not recursive,
// so we can remove `param` from the constraint.
// (Handling wildcards requires choosing a bound, but we don't know which
// bound to choose here, this is handled in `ConstraintHandling#approximation`)
constraint = constraint.replace(param, bound)
true
else
Expand Down Expand Up @@ -245,81 +253,11 @@ trait ConstraintHandling {
* @pre `param` is in the constraint's domain.
*/
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =

/** Substitute wildcards with fresh TypeParamRefs, to be compared with
* other bound, so that they can be instantiated.
*/
object substWildcards extends TypeMap:
override def stopAtStatic = true

var trackedPolis: List[PolyType] = Nil
def apply(tp: Type) = tp match
case tp: WildcardType =>
val poly = PolyType(tpnme.EMPTY :: Nil)(pt => tp.bounds :: Nil, pt => defn.AnyType)
trackedPolis = poly :: trackedPolis
poly.paramRefs.head
case _ =>
mapOver(tp)
end substWildcards

/** Replace TypeParamRefs substituted for wildcards by `substWildCards`
* and any remaining wildcards by a safe approximation
*/
val replaceWildcards = new TypeMap:
override def stopAtStatic = true

/** Try to instantiate a wildcard or TypeParamRef representing a wildcard
* to a type that is known to conform to it.
* This means:
* If fromBelow is true, we minimize the type overall
* Hence, if variance < 0, pick the maximal safe type: bounds.lo
* (i.e. the whole bounds range is over the type).
* If variance > 0, pick the minimal safe type: bounds.hi
* (i.e. the whole bounds range is under the type).
* If variance == 0, pick bounds.lo anyway (this is arbitrary but in line with
* the principle that we pick the smaller type when in doubt).
* If fromBelow is false, we maximize the type overall and reverse the bounds
* If variance != 0. For variance == 0, we still minimize.
* In summary we pick the bound given by this table:
*
* variance | -1 0 1
* ------------------------
* from below | lo lo hi
* from above | hi lo lo
*/
def pickOneBound(bounds: TypeBounds) =
if variance == 0 || fromBelow == (variance < 0) then bounds.lo
else bounds.hi

def apply(tp: Type) = mapOver {
tp match
case tp: WildcardType =>
pickOneBound(tp.bounds)
case tp: TypeParamRef if substWildcards.trackedPolis.contains(tp.binder) =>
pickOneBound(fullBounds(tp))
case _ => tp
}
end replaceWildcards

constraint.entry(param) match
case entry: TypeBounds =>
val useLowerBound = fromBelow || param.occursIn(entry.hi)
val rawBound = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
val bound = substWildcards(rawBound)
val inst =
if bound eq rawBound then bound
else
// Get rid of wildcards by mapping them to fresh TypeParamRefs
// with constraints derived from comparing both bounds, and then
// instantiating. See pos/i10161.scala for a test where this matters.
val saved = constraint
try
for poly <- substWildcards.trackedPolis do addToConstraint(poly, Nil)
if useLowerBound then bound <:< fullUpperBound(param)
else fullLowerBound(param) <:< bound
replaceWildcards(bound)
finally constraint = saved
typr.println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
inst
case inst =>
assert(inst.exists, i"param = $param\nconstraint = $constraint")
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
var current = this
val todos = new mutable.ListBuffer[(OrderingConstraint, TypeParamRef) => OrderingConstraint]
var i = 0
val dropWildcards = AvoidWildcardsMap()
while (i < poly.paramNames.length) {
val param = poly.paramRefs(i)
val stripped = stripParams(nonParamBounds(param), todos, isUpper = true)
val bounds = dropWildcards(nonParamBounds(param))
val stripped = stripParams(bounds, todos, isUpper = true)
current = updateEntry(current, param, stripped)
while todos.nonEmpty do
current = todos.head(current, param)
Expand Down Expand Up @@ -376,6 +378,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
Nil

private def updateEntry(current: This, param: TypeParamRef, tp: Type)(using Context): This = {
if Config.checkNoWildcardsInConstraint then assert(!tp.containsWildcardTypes)
var current1 = boundsLens.update(this, current, param, tp)
tp match {
case TypeBounds(lo, hi) =>
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
try topLevelSubType(tp1, tp2)
finally useNecessaryEither = saved

/** Use avoidance to get rid of wildcards in constraint bounds if
* we are doing a necessary comparison, or the mode is TypeVarsMissContext.
* The idea is that under either of these conditions we are not interested
* in creating a fresh type variable to replace the wildcard. I verified
Comment on lines +144 to +145
Copy link
Member

@smarter smarter Jun 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me why useNecessaryEither is treated specially here, is there a good justification or test case that illustrates this? Also in practice we also call necessaryEither when inferring GADTs:
https://github.com/lampepfl/dotty/blob/a82af21ff48ea07dbae042239286e5d80ef0e92e/compiler/src/dotty/tools/dotc/core/TypeComparer.scala#L1564
Should that condition also appear here? /cc @abgruszecki

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it probably should. Perhaps no GADT tests failed, because the interaction of GADTs and wildcards is undertested.

As a side remark, I'm not sure when exactly we should consider useNecessaryEither in separation from the GADT mode. The name doesn't help either in reminding that we should do both. Perhaps we should hide this condition behind a definition and rename useNecessaryEither?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe 12677 is the test case that fails otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about GADT inference. Does it even come up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified that all tests pass with or without the added condition ctx.mode.is(Mode.GadtConstraintInference). Someone else could go to the bottom of this in a separate PR.

* that several tests break if one or the other part of the disjunction is dropped.
* (for instance, i12677.scala demands `useNecessaryEither` in the condition)
*/
override protected def approximateWildcards: Boolean =
useNecessaryEither || ctx.mode.is(Mode.TypevarsMissContext)

def testSubType(tp1: Type, tp2: Type): CompareResult =
GADTused = false
if !topLevelSubType(tp1, tp2) then CompareResult.Fail
Expand Down
16 changes: 15 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ object Types {

/** Does this type contain wildcard types? */
final def containsWildcardTypes(using Context) =
existsPart(_.isInstanceOf[WildcardType], stopAtStatic = true)
existsPart(_.isInstanceOf[WildcardType], stopAtStatic = true, forceLazy = false)

// ----- Higher-order combinators -----------------------------------

Expand Down Expand Up @@ -5053,6 +5053,11 @@ object Types {

/** Wildcard type, possibly with bounds */
abstract case class WildcardType(optBounds: Type) extends CachedGroundType with TermType {

def effectiveBounds(using Context): TypeBounds = optBounds match
case bounds: TypeBounds => bounds
case _ => TypeBounds.empty

def derivedWildcardType(optBounds: Type)(using Context): WildcardType =
if (optBounds eq this.optBounds) this
else if (!optBounds.exists) WildcardType
Expand Down Expand Up @@ -5696,6 +5701,15 @@ object Types {
lo.toText(printer) ~ ".." ~ hi.toText(printer)
}

/** Approximate wildcards by their bounds */
class AvoidWildcardsMap(using Context) extends ApproximatingTypeMap:
protected def mapWild(t: WildcardType) =
val bounds = t.effectiveBounds
range(atVariance(-variance)(apply(bounds.lo)), apply(bounds.hi))
def apply(t: Type): Type = t match
case t: WildcardType => mapWild(t)
case _ => mapOver(t)

// ----- TypeAccumulators ----------------------------------------------------

abstract class TypeAccumulator[T](implicit protected val accCtx: Context)
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,10 @@ object ProtoTypes {
if wildcardOnly
|| ctx.mode.is(Mode.TypevarsMissContext)
|| !ref.underlying.widenExpr.isValueTypeOrWildcard
then WildcardType
else newDepTypeVar(ref)
then
WildcardType(ref.underlying.substParams(mt, mt.paramRefs.map(_ => WildcardType)).toBounds)
else
newDepTypeVar(ref)
mt.resultType.substParams(mt, mt.paramRefs.map(replacement))
else mt.resultType

Expand Down
31 changes: 31 additions & 0 deletions tests/pos/i12677.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
class F[A]
object F {
def apply[A](a: => A) = new F[A]
}

trait TC[A] { type Out }
object TC {
implicit def tc[A]: TC[A] { type Out = String } = ???
}

// ====================================================================================
object Bug {
final class CustomHook[A] {
def blah(implicit tc: TC[A]): CustomHook[tc.Out] = ???
}

def i: CustomHook[Int] = ???
val f = F(i.blah)
f: F[CustomHook[String]] // error
}

// ====================================================================================
object Workaround {
final class CustomHook[A] {
def blah[B](implicit tc: TC[A] { type Out = B }): CustomHook[B] = ??? // raise type
}

def i: CustomHook[Int] = ???
val f = F(i.blah)
f: F[CustomHook[String]] // works
}