Skip to content

Commit

Permalink
Implement individual erased parameters
Browse files Browse the repository at this point in the history
Breaking change for erasedDefinitions: this effectively makes the
current `erased` marker in parameter list apply to only the first
parameter.

    def f(erased a: int, b: int)

should now be written as

    def f(erased a: int, erased b: int)

    type Function1 = (x: Int, erased y: Int) => Int
    type Function2 = (Int, erased Int) => Int

Use refined traits for erased functions

- Function types with erased parameters are now always `ErasedFunction` refined with the correct `apply` definition,
  for example:

    scala.runtime.ErasedFunction {
        def apply(x1: Int, erased x2: Int): Int
    }
  where ErasedFunction is an @experimental empty trait.
- Polymorphic functions cannot take erased parameters.
- By-name parameters cannot be erased.
- Internally, use the @ErasedParam annotation as a marker for an erased parameter.
- Parameters that are erased classes are now marked `erased` at Typer phase (with an annotation),
  and in later phases, they are not taken into account when considering erasedness.
- Erased parameters/functions quotes API are changed:
    - `isErased` => `erasedArgs`/`erasedParams` and `hasErasedArgs`/`hasErasedParams`
    - `FunctionClass` now fails when `isErased = true`. Add `ErasedFunctionClass`.
- Added tests and test-fixes for `erasedDefinitions` feature
- Updated specs and internal syntax
- Aside, reject normal tuples with ValDefs in them. This comes up when trying to parse parameters out of tuples.
  In practice they don't show up.

Co-authored-by: Nicolas Stucki <[email protected]>
  • Loading branch information
natsukagami and nicolasstucki committed Mar 1, 2023
1 parent 120edd2 commit 0f7c3ab
Show file tree
Hide file tree
Showing 71 changed files with 879 additions and 403 deletions.
29 changes: 26 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1498,10 +1498,10 @@ object desugar {
case vd: ValDef => vd
}

def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = {
val mods = if (isErased) Given | Erased else Given
def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
val mods = Given
val params = makeImplicitParameters(formals, mods)
FunctionWithMods(params, body, Modifiers(mods))
FunctionWithMods(params, body, Modifiers(mods), erasedParams)
}

private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = {
Expand Down Expand Up @@ -1834,6 +1834,7 @@ object desugar {
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
case _ =>
annotate(tpnme.retains, parent)
case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt)
}
desugared.withSpan(tree.span)
}
Expand Down Expand Up @@ -1909,6 +1910,28 @@ object desugar {
TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait)
}

/** Ensure the given function tree use only ValDefs for parameters.
* For example,
* FunctionWithMods(List(TypeTree(A), TypeTree(B)), body, mods, erasedParams)
* gets converted to
* FunctionWithMods(List(ValDef(x$1, A), ValDef(x$2, B)), body, mods, erasedParams)
*/
def makeFunctionWithValDefs(tree: Function, pt: Type)(using Context): Function = {
val Function(args, result) = tree
args match {
case (_ : ValDef) :: _ => tree // ValDef case can be easily handled
case _ if !ctx.mode.is(Mode.Type) => tree
case _ =>
val applyVParams = args.zipWithIndex.map {
case (p, n) => makeSyntheticParameter(n + 1, p)
}
tree match
case tree: FunctionWithMods =>
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams)
case _ => untpd.Function(applyVParams, result)
}
}

/** Returns list of all pattern variables, possibly with their types,
* without duplicates
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
&& tree.isTerm
&& {
val qualType = tree.qualifier.tpe
hasRefinement(qualType) && !qualType.derivesFrom(defn.PolyFunctionClass)
hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType)
}
def loop(tree: Tree): Boolean = tree match
case TypeApply(fun, _) =>
Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
// If `isParamDependent == false`, the value of `previousParamRefs` is not used.
if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN

def valueParam(name: TermName, origInfo: Type): TermSymbol =
def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol =
val maybeImplicit =
if tp.isContextualMethod then Given
else if tp.isImplicitMethod then Implicit
else EmptyFlags
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags
val maybeErased = if isErased then Erased else EmptyFlags

def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)

Expand All @@ -283,7 +283,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
case nil =>
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
(rtp, vparams :: paramss)
case _ =>
Expand Down Expand Up @@ -1140,10 +1140,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, isErased) =>
case defn.ContextFunctionType(argTypes, resType, _) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
MethodType.companion(isContextual = true)(argTypes, resType),
coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
Expand Down
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
override def isType: Boolean = body.isType
}

/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
extends Function(args, body)
/** A function type or closure with `implicit` or `given` modifiers and information on which parameters are `erased` */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val erasedParams: List[Boolean])(implicit @constructorOnly src: SourceFile)
extends Function(args, body) {
assert(args.length == erasedParams.length)

def hasErasedParams = erasedParams.contains(true)
}

/** A polymorphic function type */
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ extension (tp: Type)
defn.FunctionType(
fname.functionArity,
isContextual = fname.isContextFunction,
isErased = fname.isErasedFunction,
isImpure = true).appliedTo(args)
case _ =>
tp
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ class CheckCaptures extends Recheck, SymTransformer:
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing {
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual, isErased) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual, isErased)
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual)
}
else
super.recheckApply(tree, pt) match
Expand Down Expand Up @@ -430,7 +430,7 @@ class CheckCaptures extends Recheck, SymTransformer:
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _, _)
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `*`
Expand Down Expand Up @@ -598,18 +598,18 @@ class CheckCaptures extends Recheck, SymTransformer:
//println(i"check conforms $actual1 <<< $expected1")
super.checkConformsExpr(actual1, expected1, tree)

private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType)
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual)(args, resultType)
.toFunctionType(isJava = false, alwaysDependent = true)

/** Turn `expected` into a dependent function when `actual` is dependent. */
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
def recur(expected: Type): Type = expected.dealias match
case expected @ CapturingType(eparent, refs) =>
CapturingType(recur(eparent), refs, boxed = expected.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual, isErased)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual, isErased)
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case _ =>
expected
Expand Down Expand Up @@ -675,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
Expand Down Expand Up @@ -739,7 +739,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
Expand Down Expand Up @@ -962,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import transform.Recheck.*
import CaptureSet.IdentityCaptRefMap
import Synthetics.isExcluded
import util.Property
import dotty.tools.dotc.core.Annotations.Annotation

/** A tree traverser that prepares a compilation unit to be capture checked.
* It does the following:
Expand All @@ -38,7 +39,6 @@ extends tpd.TreeTraverser:
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
MethodType.companion(
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
)(argTypes, resType)
.toFunctionType(isJava = false, alwaysDependent = true)

Expand All @@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo) if defn.isFunctionType(tp1) =>
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
Expand Down Expand Up @@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
else tp
Expand Down Expand Up @@ -260,7 +260,13 @@ extends tpd.TreeTraverser:
private def expandThrowsAlias(tp: Type)(using Context) = tp match
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
defn.FunctionOf(
AnnotatedType(
defn.CanThrowClass.typeRef.appliedTo(exc),
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
res,
isContextual = true
)
case _ => tp

private def expandThrowsAliases(using Context) = new TypeMap:
Expand Down Expand Up @@ -323,7 +329,7 @@ extends tpd.TreeTraverser:
args.last, CaptureSet.empty, currentCs ++ outerCs)
tp.derivedAppliedType(tycon1, args1 :+ resType1)
tp1.capturing(outerCs)
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
.capturing(outerCs)
case _ =>
Expand Down
Loading

0 comments on commit 0f7c3ab

Please sign in to comment.