Skip to content

Commit

Permalink
Properly handle SAM types with wildcards
Browse files Browse the repository at this point in the history
When typing a closure with an expected type containing a wildcard, the closure
type itself should not contain wildcards, because it might be expanded to an
anonymous class extending the closure type (this happens on non-JVM backends as
well as on the JVM itself in situations where a SAM trait does not compile down
to a SAM interface).

We were already approximating wildcards in the method type returned by the
SAMType extractor, but to fix this issue we had to change the extractor to
perform the approximation on the expected type itself to generate a valid
parent type. The SAMType extractor now returns both the approximated parent
type and the type of the method itself.

The wildcard approximation analysis relies on a new `VarianceMap` opaque type
extracted from Inferencing#variances.

Fixes scala#16065.
Fixes scala#18096.
  • Loading branch information
smarter committed Jul 15, 2023
1 parent 18f90d9 commit 89735d0
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 130 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ class Definitions {
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)

@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)

Expand Down
210 changes: 128 additions & 82 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import CheckRealizable._
import Variances.{Variance, setStructuralVariances, Invariant}
import typer.Nullables
import util.Stats._
import util.SimpleIdentitySet
import util.{SimpleIdentityMap, SimpleIdentitySet}
import ast.tpd._
import ast.TreeTypeMap
import printing.Texts._
Expand Down Expand Up @@ -1751,7 +1751,7 @@ object Types {
t
case t if defn.isErasedFunctionType(t) =>
t
case t @ SAMType(_) =>
case t @ SAMType(_, _) =>
t
case _ =>
NoType
Expand Down Expand Up @@ -5520,104 +5520,119 @@ object Types {
* A type is a SAM type if it is a reference to a class or trait, which
*
* - has a single abstract method with a method type (ExprType
* and PolyType not allowed!) whose result type is not an implicit function type
* and which is not marked inline.
* and PolyType not allowed!) according to `possibleSamMethods`.
* - can be instantiated without arguments or with just () as argument.
*
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
* type of the single abstract method.
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
* type of the single abstract method and `samParent` is a subtype of the matched
* SAM type which has been stripped of wildcards to turn it into a valid parent
* type.
*/
object SAMType {
def zeroParamClass(tp: Type)(using Context): Type = tp match {
/** If possible, return a type which is both a subtype of `origTp` and a type
* application of `samClass` where none of the type arguments are
* wildcards (thus making it a valid parent type), otherwise return
* NoType.
*
* A wildcard in the original type will be replaced by its upper or lower bound in a way
* that maximizes the number of possible implementations of `samMeth`. For example,
* java.util.function defines an interface equivalent to:
*
* trait Function[T, R]:
* def apply(t: T): R
*
* and it usually appears with wildcards to compensate for the lack of
* definition-site variance in Java:
*
* (x => x.toInt): Function[? >: String, ? <: Int]
*
* When typechecking this lambda, we need to approximate the wildcards to find
* a valid parent type for our lambda to extend. We can see that in `apply`,
* `T` only appears contravariantly and `R` only appears covariantly, so by
* minimizing the first parameter and maximizing the second, we maximize the
* number of valid implementations of `apply` which lets us implement the lambda
* with a closure equivalent to:
*
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
*
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
* we arbitrarily pick the upper-bound.
*/
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
val tp = origTp.baseType(samClass)
if !(tp <:< origTp) then NoType
else tp match
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
vmap.recordLocalVariance(tp.symbol, variance)
case _ =>
foldOver(vmap, t)
val vmap = accu(VarianceMap.empty, samMeth.info)
val tparams = tycon.typeParamSymbols
val args1 = args.zipWithConserve(tparams):
case (arg @ TypeBounds(lo, hi), tparam) =>
val v = vmap.computedVariance(tparam)
if v.uncheckedNN < 0 then lo
else hi
case (arg, _) => arg
tp.derivedAppliedType(tycon, args1)
case _ =>
tp

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
def zeroParams(tp: Type): Boolean = tp.stripPoly match
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
case et: ExprType => true
case _ => false
}
// `ContextFunctionN` does not have constructors
val ctor = tp.cls.primaryConstructor
if (!ctor.exists || zeroParams(ctor.info)) tp
else NoType
val cls = tp.cls
val validCtor =
val ctor = cls.primaryConstructor
// `ContextFunctionN` does not have constructors
!ctor.exists || zeroParams(ctor.info)
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if validCtor && isInstantiable then tp.cls
else NoSymbol
case tp: AppliedType =>
zeroParamClass(tp.superType)
samClass(tp.superType)
case tp: TypeRef =>
zeroParamClass(tp.underlying)
samClass(tp.underlying)
case tp: RefinedType =>
zeroParamClass(tp.underlying)
samClass(tp.underlying)
case tp: TypeBounds =>
zeroParamClass(tp.underlying)
samClass(tp.underlying)
case tp: TypeVar =>
zeroParamClass(tp.underlying)
samClass(tp.underlying)
case tp: AnnotatedType =>
zeroParamClass(tp.underlying)
case _ =>
NoType
}
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
tp <:< selfType
samClass(tp.underlying)
case _ =>
false
}
def unapply(tp: Type)(using Context): Option[MethodType] =
if (isInstantiatable(tp)) {
val absMems = tp.possibleSamMethods
if (absMems.size == 1)
absMems.head.info match {
case mt: MethodType if !mt.isParamDependent &&
mt.resultType.isValueTypeOrWildcard =>
val cls = tp.classSymbol

// Given a SAM type such as:
//
// import java.util.function.Function
// Function[? >: String, ? <: Int]
//
// the single abstract method will have type:
//
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
//
// which is not implementable outside of the scope of Function.
//
// To avoid this kind of issue, we approximate references to
// parameters of the SAM type by their bounds, this way in the
// above example we get:
//
// (x: String): Int
val approxParams = new ApproximatingTypeMap {
def apply(tp: Type): Type = tp match {
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
tp.info match {
case info: AliasingBounds =>
mapOver(info.alias)
case TypeBounds(lo, hi) =>
range(atVariance(-variance)(apply(lo)), apply(hi))
case _ =>
range(defn.NothingType, defn.AnyType) // should happen only in error cases
}
case _ =>
mapOver(tp)
}
}
val approx =
if ctx.owner.isContainedIn(cls) then mt
else approxParams(mt).asInstanceOf[MethodType]
Some(approx)
NoSymbol

def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
val cls = samClass(tp)
if cls.exists then
val absMems =
if tp.isRef(defn.PartialFunctionClass) then
// To maintain compatibility with 2.x, we treat PartialFunction specially,
// pretending it is a SAM type. In the future it would be better to merge
// Function and PartialFunction, have Function1 contain a isDefinedAt method
// def isDefinedAt(x: T) = true
// and overwrite that method whenever the function body is a sequence of
// case clauses.
List(defn.PartialFunction_apply)
else
tp.possibleSamMethods.map(_.symbol)
if absMems.lengthCompare(1) == 0 then
val samMethSym = absMems.head
val parent = samParent(tp, cls, samMethSym)
samMethSym.asSeenFrom(parent).info match
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
Some(mt, parent)
case _ =>
None
}
else if (tp isRef defn.PartialFunctionClass)
// To maintain compatibility with 2.x, we treat PartialFunction specially,
// pretending it is a SAM type. In the future it would be better to merge
// Function and PartialFunction, have Function1 contain a isDefinedAt method
// def isDefinedAt(x: T) = true
// and overwrite that method whenever the function body is a sequence of
// case clauses.
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
else None
}
else None
}

Expand Down Expand Up @@ -6450,6 +6465,37 @@ object Types {
}
}

object VarianceMap:
/** An immutable map representing the variance of keys of type `K` */
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
extension [K <: AnyRef](vmap: VarianceMap[K])
/** The backing map used to implement this VarianceMap. */
inline def underlying: SimpleIdentityMap[K, Integer] = vmap

/** Return a new map taking into account that K appears in a
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
*/
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
val previousVariance = vmap(k)
if previousVariance == null then
vmap.updated(k, localVariance)
else if previousVariance == localVariance || previousVariance == 0 then
vmap
else
vmap.updated(k, 0)

/** Return the variance of `k`:
* - A positive value means that `k` appears only covariantly.
* - A negative value means that `k` appears only contravariantly.
* - A zero value means that `k` appears both covariantly and
* contravariantly, or appears invariantly.
* - A null value means that `k` does not appear at all.
*/
def computedVariance(k: K): Integer | Null =
vmap(k)
export VarianceMap.VarianceMap

// ----- Name Filters --------------------------------------------------

/** A name filter selects or discards a member name of a type `pre`.
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
tree // it's a plain function
case tpe if defn.isContextFunctionType(tpe) =>
tree
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
val tpe1 = checkRefinements(tpe, fn)
toPartialFunction(tree, tpe1)
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
checkRefinements(tpe, fn)
tree
case tpe =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ trait Applications extends Compatibility {

def SAMargOK =
defn.isFunctionNType(argtpe1) && formal.match
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
case _ => false

isCompatible(argtpe, formal)
Expand Down Expand Up @@ -2074,7 +2074,7 @@ trait Applications extends Compatibility {
* new java.io.ObjectOutputStream(f)
*/
pt match {
case SAMType(mtp) =>
case SAMType(mtp, _) =>
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
case _ =>
// pick any alternatives that are not methods since these might be convertible
Expand Down
25 changes: 9 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ object Inferencing {
val vs = variances(tp)
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
vs foreachBinding { (tvar, v) =>
vs.underlying foreachBinding { (tvar, v) =>
if !tvar.isInstantiated then
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
Expand Down Expand Up @@ -440,8 +440,6 @@ object Inferencing {
res
}

type VarianceMap = SimpleIdentityMap[TypeVar, Integer]

/** All occurrences of type vars in `tp` that satisfy predicate
* `include` mapped to their variances (-1/0/1) in both `tp` and
* `pt.finalResultType`, where
Expand All @@ -465,23 +463,18 @@ object Inferencing {
*
* we want to instantiate U to x.type right away. No need to wait further.
*/
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
Stats.record("variances")
val constraint = ctx.typerState.constraint

object accu extends TypeAccumulator[VarianceMap] {
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
def setVariance(v: Int) = variance = v
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
case t: TypeVar
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
val v = vmap(t)
if (v == null) vmap.updated(t, variance)
else if (v == variance || v == 0) vmap
else vmap.updated(t, 0)
vmap.recordLocalVariance(t, variance)
case _ =>
foldOver(vmap, t)
}
}

/** Include in `vmap` type variables occurring in the constraints of type variables
* already in `vmap`. Specifically:
Expand All @@ -493,10 +486,10 @@ object Inferencing {
* bounds as non-variant.
* Do this in a fixpoint iteration until `vmap` stabilizes.
*/
def propagate(vmap: VarianceMap): VarianceMap = {
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
var vmap1 = vmap
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
vmap.foreachBinding { (tvar, v) =>
vmap.underlying.foreachBinding { (tvar, v) =>
val param = tvar.origin
constraint.entry(param) match
case TypeBounds(lo, hi) =>
Expand All @@ -512,7 +505,7 @@ object Inferencing {
if (vmap1 eq vmap) vmap else propagate(vmap1)
}

propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
}

/** Run the transformation after dealiasing but return the original type if it was a no-op. */
Expand Down Expand Up @@ -638,7 +631,7 @@ trait Inferencing { this: Typer =>
if !tvar.isInstantiated then
// isInstantiated needs to be checked again, since previous interpolations could already have
// instantiated `tvar` through unification.
val v = vs(tvar)
val v = vs.computedVariance(tvar)
if v == null then buf += ((tvar, 0))
else if v.intValue != 0 then buf += ((tvar, v.intValue))
else comparing(cmp =>
Expand Down
Loading

0 comments on commit 89735d0

Please sign in to comment.