Skip to content

Commit

Permalink
First version of @use checking
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Oct 9, 2024
1 parent 5215945 commit 1746091
Show file tree
Hide file tree
Showing 25 changed files with 262 additions and 80 deletions.
62 changes: 47 additions & 15 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ class CCState:
*/
val approxWarnings: mutable.ListBuffer[Message] = mutable.ListBuffer()

/** The operation to perform `recordUse` is called. This is typically the case
* when a subtype check is performed between a part of a function argument
* and a corresponding part of a formal parameter that was labelled @use.
* Such annotations are mapped to `<use>[T]` applications, which are handled
* as compiletime ops.
* The parameter to the handler is typically the deep capture set of the argument.
* The result of the handler is the result to be returned from the subtype check.
*/
private var useHandler: CaptureSet => Boolean = Function.const(true)

private var curLevel: Level = outermostLevel
private val symLevel: mutable.Map[Symbol, Int] = mutable.Map()

Expand All @@ -108,6 +118,17 @@ object CCState:
*/
def currentLevel(using Context): Level = ccState.curLevel

/** Perform operation `op` with a given useHandler */
inline def withUseHandler[T](handler: CaptureSet => Boolean)(inline op: T)(using Context): T =
val ccs = ccState
val saved = ccs.useHandler
ccs.useHandler = handler
try op finally ccs.useHandler = saved

/** Record a deep capture set in the current `useSet` */
def recordUse(cs: CaptureSet)(using Context): Boolean =
ccState.useHandler(cs)

inline def inNestedLevel[T](inline op: T)(using Context): T =
val ccs = ccState
val saved = ccs.curLevel
Expand Down Expand Up @@ -146,7 +167,7 @@ extension (tree: Tree)
*/
def toCaptureRefs(using Context): List[CaptureRef] = tree match
case ReachCapabilityApply(arg) =>
arg.toCaptureRefs.map(_.reach)
arg.toCaptureRefs.map(_.reach())
case CapsOfApply(arg) =>
arg.toCaptureRefs
case _ => tree.tpe.dealiasKeepAnnots match
Expand Down Expand Up @@ -204,6 +225,7 @@ extension (tp: Type)
tp.derivesFrom(defn.Caps_CapSet)
case AnnotatedType(parent, annot) =>
(annot.symbol == defn.ReachCapabilityAnnot
|| annot.symbol == defn.ReachUnderUseCapabilityAnnot
|| annot.symbol == defn.MaybeCapabilityAnnot
) && parent.isTrackableRef
case _ =>
Expand Down Expand Up @@ -233,7 +255,8 @@ extension (tp: Type)
if dcs.isAlwaysEmpty then dcs
else tp match
case tp @ ReachCapability(_) => tp.singletonCaptureSet
case tp: SingletonCaptureRef => tp.reach.singletonCaptureSet
case tp @ ReachUnderUseCapability(_) => tp.singletonCaptureSet
case tp: SingletonCaptureRef => tp.reach().singletonCaptureSet
case _ => dcs

/** A type capturing `ref` */
Expand Down Expand Up @@ -396,9 +419,11 @@ extension (tp: Type)
* type of `x`. If `x` and `y` are different variables then `{x*}` and `{y*}`
* are unrelated.
*/
def reach(using Context): CaptureRef = tp match
def reach(underUse: Boolean = false)(using Context): CaptureRef = tp match
case tp: CaptureRef if tp.isTrackableRef =>
if tp.isReach then tp else ReachCapability(tp)
if tp.isReach then tp
else if underUse then ReachUnderUseCapability(tp)
else ReachCapability(tp)

/** If `x` is a capture ref, its maybe capability `x?`, represented internally
* as `x @maybeCapability`. `x?` stands for a capability `x` that might or might
Expand Down Expand Up @@ -471,26 +496,32 @@ extension (tp: Type)
object narrowCaps extends TypeMap:
/** Has the variance been flipped at this point? */
private var isFlipped: Boolean = false
private var underUse = false

def apply(t: Type) =
val saved = isFlipped
try
if variance <= 0 then isFlipped = true
t.dealias match
case t1 @ CapturingType(p, cs) if cs.isUniversal && !isFlipped =>
t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet)
case t1 @ FunctionOrMethod(args, res @ Existential(_, _))
t.dealiasKeepAnnots match
case t @ CapturingType(p, cs) if cs.isUniversal && !isFlipped =>
t.derivedCapturingType(apply(p), ref.reach(underUse).singletonCaptureSet)
case t @ AnnotatedType(parent, ann) =>
if ann.symbol == defn.UseAnnot then
val saved = underUse
underUse = true
try mapOver(t)
finally underUse = saved
else
t.derivedAnnotatedType(this(parent), ann)
case t @ FunctionOrMethod(args, res @ Existential(_, _))
if args.forall(_.isAlwaysPure) =>
// Also map existentials in results to reach capabilities if all
// preceding arguments are known to be always pure
apply(t1.derivedFunctionOrMethod(args, Existential.toCap(res)))
this(t.derivedFunctionOrMethod(args, Existential.toCap(res)))
case Existential(_, _) =>
t
case _ => t match
case t @ CapturingType(p, cs) =>
t.derivedCapturingType(apply(p), cs) // don't map capture set variables
case t =>
mapOver(t)
case _ =>
mapOver(t)
finally isFlipped = saved
end narrowCaps

Expand Down Expand Up @@ -640,14 +671,15 @@ object CapsOfApply:
class AnnotatedCapability(annot: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context) =
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))
def unapply(tree: AnnotatedType)(using Context): Option[CaptureRef] = tree match
def unapply(tp: AnnotatedType)(using Context): Option[CaptureRef] = tp match
case AnnotatedType(parent: CaptureRef, ann) if ann.symbol == annot => Some(parent)
case _ => None

/** An extractor for `ref @annotation.internal.reachCapability`, which is used to express
* the reach capability `ref*` as a type.
*/
object ReachCapability extends AnnotatedCapability(defn.ReachCapabilityAnnot)
object ReachUnderUseCapability extends AnnotatedCapability(defn.ReachUnderUseCapabilityAnnot)

/** An extractor for `ref @maybeCapability`, which is used to express
* the maybe capability `ref?` as a type.
Expand Down
15 changes: 14 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,21 @@ trait CaptureRef extends TypeProxy, ValueType:

/** Is this a reach reference of the form `x*`? */
final def isReach(using Context): Boolean = this match
case AnnotatedType(_, annot) => annot.symbol == defn.ReachCapabilityAnnot
case AnnotatedType(_, annot) =>
annot.symbol == defn.ReachCapabilityAnnot || annot.symbol == defn.ReachUnderUseCapabilityAnnot
case _ => false

final def isUnderUse(using Context): Boolean = this match
case AnnotatedType(_, annot) => annot.symbol == defn.ReachUnderUseCapabilityAnnot
case _ => false

def toUnderUse(using Context): CaptureRef =
if isUnderUse then
this match
case _: AnnotatedType => stripReach.reach(underUse = true)
// TODO: Handle capture set variables here
else this

/** Is this a maybe reference of the form `x?`? */
final def isMaybe(using Context): Boolean = this match
case AnnotatedType(_, annot) => annot.symbol == defn.MaybeCapabilityAnnot
Expand Down Expand Up @@ -132,6 +144,7 @@ trait CaptureRef extends TypeProxy, ValueType:
case _ => false
|| this.match
case ReachCapability(x1) => x1.subsumes(y.stripReach)
case ReachUnderUseCapability(x1) => x1.subsumes(y.stripReach)
case x: TermRef => viaInfo(x.info)(subsumingRefs(_, y))
case x: TermParamRef => subsumesExistentially(x, y)
case x: TypeRef => assumedContainsOf(x).contains(y)
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ object CaptureSet:
elem.cls.ccLevel.nextInner <= level
case ReachCapability(elem1) =>
levelOK(elem1)
case ReachUnderUseCapability(elem1) =>
levelOK(elem1)
case MaybeCapability(elem1) =>
levelOK(elem1)
case _ =>
Expand Down Expand Up @@ -1066,6 +1068,8 @@ object CaptureSet:
else CaptureSet.universal
case ReachCapability(ref1) => deepCaptureSet(ref1.widen)
.showing(i"Deep capture set of $ref: ${ref1.widen} = $result", capt)
case ReachUnderUseCapability(ref1) => deepCaptureSet(ref1.widen)
.showing(i"Deep capture set of $ref: ${ref1.widen} = $result", capt)
case _ => ofType(ref.underlying, followResult = true)

/** Capture set of a type */
Expand All @@ -1082,7 +1086,9 @@ object CaptureSet:
empty
case CapturingType(parent, refs) =>
recur(parent) ++ refs
case tp @ AnnotatedType(parent, ann) if ann.hasSymbol(defn.ReachCapabilityAnnot) =>
case tp @ AnnotatedType(parent, ann)
if ann.hasSymbol(defn.ReachCapabilityAnnot)
|| ann.hasSymbol(defn.ReachUnderUseCapabilityAnnot) =>
parent match
case parent: SingletonCaptureRef if parent.isTrackableRef =>
tp.singletonCaptureSet
Expand Down
45 changes: 40 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,19 @@ class CheckCaptures extends Recheck, SymTransformer:
&& (!ccConfig.useSealed || refSym.is(Param))
&& refOwner == env.owner
then
if refSym.hasAnnotation(defn.UnboxAnnot) then
if refSym.hasAnnotation(defn.UnboxAnnot)
|| ref.info.hasAnnotation(defn.UseAnnot)
|| c.isUnderUse
then
capt.println(i"exempt: $ref in $refOwner")
else
// Reach capabilities that go out of scope have to be approximated
// by their underlying capture set, which cannot be universal.
// Reach capabilities of @unboxed parameters are exempted.
val cs = CaptureSet.ofInfo(c)
cs.disallowRootCapability: () =>
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
def kind = if c.isReach then "reach capability" else "capture set variable"
report.error(em"Local $kind $c leaks into capture scope of ${env.ownerString}", pos)
checkSubset(cs, env.captured, pos, provenance(env))
isVisible
case ref: ThisType => isVisibleFromEnv(ref.cls, env)
Expand Down Expand Up @@ -576,11 +580,38 @@ class CheckCaptures extends Recheck, SymTransformer:
protected override
def recheckArg(arg: Tree, formal: Type)(using Context): Type =
val argType = recheck(arg, formal)
accountForUses(arg, argType, formal)
if unboxedArgs.contains(arg) then
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
markFree(argType.deepCaptureSet, arg.srcPos)
argType

class MapUses(deep: Boolean)(using Context) extends TypeMap:
var usesFound = false
def apply(t: Type) = t match
case t @ AnnotatedType(parent, ann) =>
if ann.symbol == defn.UseAnnot then
usesFound = true
defn.UseType.typeRef.appliedTo(apply(parent))
else
t.derivedAnnotatedType(this(parent), ann)
case Existential(_, _) if !deep =>
t
case _ =>
mapOver(t)

def accountForUses(arg: Tree, argType: Type, formal: Type)(using Context): Unit =
val mapper = MapUses(deep = false)
val formal1 = mapper(formal)
if mapper.usesFound then
def markUsesAsFree(cs: CaptureSet): Boolean =
capt.println(i"actual uses for $arg: $argType vs $formal = $cs")
markFree(cs, arg.srcPos)
true
CCState.withUseHandler(markUsesAsFree):
checkConformsExpr(argType, formal1, arg, NothingToAdd)
end accountForUses

/** A specialized implementation of the apply rule.
*
* E |- q: Tq^Cq
Expand Down Expand Up @@ -1366,13 +1397,15 @@ class CheckCaptures extends Recheck, SymTransformer:
* @param sym symbol of the field definition that is being checked
*/
override def checkSubType(actual: Type, expected: Type)(using Context): Boolean =
val expected1 = alignDependentFunction(addOuterRefs(expected, actual, srcPos), actual.stripCapturing)
val mapUses = MapUses(deep = true)
val expected1 = alignDependentFunction(
addOuterRefs(mapUses(expected), actual, srcPos), actual.stripCapturing)
val actual1 =
val saved = curEnv
try
curEnv = Env(clazz, EnvKind.NestedInOwner, capturedVars(clazz), outer0 = curEnv)
val adapted =
adaptBoxed(actual, expected1, srcPos, covariant = true, alwaysConst = true, null)
adaptBoxed(mapUses(actual), expected1, srcPos, covariant = true, alwaysConst = true, null)
actual match
case _: MethodType =>
// We remove the capture set resulted from box adaptation for method types,
Expand All @@ -1382,7 +1415,9 @@ class CheckCaptures extends Recheck, SymTransformer:
adapted.stripCapturing
case _ => adapted
finally curEnv = saved
actual1 frozen_<:< expected1
CCState.withUseHandler(Function.const(false)):
TypeComparer.usingContravarianceForMethods:
actual1 frozen_<:< expected1

override def needsCheck(overriding: Symbol, overridden: Symbol)(using Context): Boolean =
!setup.isPreCC(overriding) && !setup.isPreCC(overridden)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

def apply(tp: Type) =
val tp1 = tp match
case AnnotatedType(parent, annot) if annot.symbol.isRetains =>
// Drop explicit retains annotations
case AnnotatedType(parent, annot) if annot.symbol.isRetains || annot.symbol == defn.UseAnnot =>
// Drop inferred retains and @use annotations
apply(parent)
case tp @ AppliedType(tycon, args) =>
val tycon1 = this(tycon)
Expand Down
24 changes: 21 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,17 @@ class Definitions {
completeClass(enterCompleteClassSymbol(
ScalaPackageClass, tpnme.maybeCapability, Final, List(StaticAnnotationClass.typeRef)))

/** A type `type <use>[+T] <: T` used locally in capture checking. At certain points
* `T @use` types are converted to `<use>[T]` types. These types are handled as
* compile-time applied types by TypeComparer.
*/
@tu lazy val UseType: TypeSymbol =
enterPermanentSymbol(
tpnme.USE,
TypeBounds.upper(
HKTypeLambda(HKTypeLambda.syntheticParamNames(1), Covariant :: Nil)
(_ => TypeBounds.empty :: Nil, _.paramRefs.head))).asType

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
Expand Down Expand Up @@ -1057,12 +1068,13 @@ class Definitions {
@tu lazy val ExperimentalAnnot: ClassSymbol = requiredClass("scala.annotation.experimental")
@tu lazy val ThrowsAnnot: ClassSymbol = requiredClass("scala.throws")
@tu lazy val TransientAnnot: ClassSymbol = requiredClass("scala.transient")
@tu lazy val UnboxAnnot: ClassSymbol = requiredClass("scala.caps.unbox")
@tu lazy val UnboxAnnot: ClassSymbol = requiredClass("scala.caps.unbox")
@tu lazy val UncheckedAnnot: ClassSymbol = requiredClass("scala.unchecked")
@tu lazy val UncheckedStableAnnot: ClassSymbol = requiredClass("scala.annotation.unchecked.uncheckedStable")
@tu lazy val UncheckedVarianceAnnot: ClassSymbol = requiredClass("scala.annotation.unchecked.uncheckedVariance")
@tu lazy val UncheckedCapturesAnnot: ClassSymbol = requiredClass("scala.annotation.unchecked.uncheckedCaptures")
@tu lazy val UntrackedCapturesAnnot: ClassSymbol = requiredClass("scala.caps.untrackedCaptures")
@tu lazy val UseAnnot: ClassSymbol = requiredClass("scala.caps.use")
@tu lazy val VolatileAnnot: ClassSymbol = requiredClass("scala.volatile")
@tu lazy val BeanGetterMetaAnnot: ClassSymbol = requiredClass("scala.annotation.meta.beanGetter")
@tu lazy val BeanSetterMetaAnnot: ClassSymbol = requiredClass("scala.annotation.meta.beanSetter")
Expand All @@ -1077,6 +1089,7 @@ class Definitions {
@tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName")
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
@tu lazy val ReachCapabilityAnnot = requiredClass("scala.annotation.internal.reachCapability")
@tu lazy val ReachUnderUseCapabilityAnnot = requiredClass("scala.annotation.internal.reachUnderUseCapability")
@tu lazy val RequiresCapabilityAnnot: ClassSymbol = requiredClass("scala.annotation.internal.requiresCapability")
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.annotation.retains")
@tu lazy val RetainsCapAnnot: ClassSymbol = requiredClass("scala.annotation.retainsCap")
Expand Down Expand Up @@ -1352,6 +1365,9 @@ class Definitions {
final def isNamedTuple_From(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.From && sym.owner == NamedTupleModule.moduleClass

final def isUse(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.USE && sym.owner == ScalaPackageClass

private val compiletimePackageAnyTypes: Set[Name] = Set(
tpnme.Equals, tpnme.NotEquals, tpnme.IsConst, tpnme.ToString
)
Expand Down Expand Up @@ -1380,7 +1396,7 @@ class Definitions {
tpnme.Plus, tpnme.Length, tpnme.Substring, tpnme.Matches, tpnme.CharAt
)
private val compiletimePackageOpTypes: Set[Name] =
Set(tpnme.S, tpnme.From)
Set(tpnme.S, tpnme.From, tpnme.USE)
++ compiletimePackageAnyTypes
++ compiletimePackageIntTypes
++ compiletimePackageLongTypes
Expand All @@ -1394,6 +1410,7 @@ class Definitions {
&& (
isCompiletime_S(sym)
|| isNamedTuple_From(sym)
|| isUse(sym)
|| sym.owner == CompiletimeOpsAnyModuleClass && compiletimePackageAnyTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsIntModuleClass && compiletimePackageIntTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsLongModuleClass && compiletimePackageLongTypes.contains(sym.name)
Expand Down Expand Up @@ -2195,7 +2212,8 @@ class Definitions {
NothingClass,
SingletonClass,
CBCompanion,
MaybeCapabilityAnnot)
MaybeCapabilityAnnot,
UseType)

@tu lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
EmptyPackageVal,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ object StdNames {
final val WILDCARD_STAR: N = "_*"
final val REIFY_TREECREATOR_PREFIX: N = "$treecreator"
final val REIFY_TYPECREATOR_PREFIX: N = "$typecreator"
final val USE: N = "<use>"

final val Any: N = "Any"
final val AnyKind: N = "AnyKind"
Expand Down
Loading

0 comments on commit 1746091

Please sign in to comment.