Skip to content

Commit

Permalink
A PoC with one working case
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Sep 18, 2024
1 parent 687deb1 commit cc767f0
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
70 changes: 61 additions & 9 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -877,16 +877,16 @@ class Namer { typer: Typer =>
protected def addAnnotations(sym: Symbol): Unit = original match {
case original: untpd.MemberDef =>
lazy val annotCtx = annotContext(original, sym)
original.setMods:
original.setMods:
original.mods.withAnnotations :
original.mods.annotations.mapConserve: annotTree =>
original.mods.annotations.mapConserve: annotTree =>
val cls = typedAheadAnnotationClass(annotTree)(using annotCtx)
if (cls eq sym)
report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos)
annotTree
else
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
else annotTree
val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx))
sym.addAnnotation(ann1)
Expand Down Expand Up @@ -1545,6 +1545,8 @@ class Namer { typer: Typer =>
case completer: Completer => completer.indexConstructor(constr, constrSym)
case _ =>

// constrSym.info = typeSig(constrSym)

tempInfo = denot.asClass.classInfo.integrateOpaqueMembers.asInstanceOf[TempClassInfo]
denot.info = savedInfo
}
Expand Down Expand Up @@ -1646,14 +1648,17 @@ class Namer { typer: Typer =>
* as an attachment on the ClassDef tree.
*/
def enterParentRefinementSyms(refinements: List[(Name, Type)]) =
println(s"For class $cls, entering parent refinements: $refinements")
val refinedSyms = mutable.ListBuffer[Symbol]()
for (name, tp) <- refinements do
if decls.lookupEntry(name) == null then
val flags = tp match
case tp: MethodOrPoly => Method | Synthetic | Deferred | Tracked
case _ if name.isTermName => Synthetic | Deferred | Tracked
case _ => Synthetic | Deferred
refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
refinedSyms += s
println(s" entered $s")
if refinedSyms.nonEmpty then
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
original.pushAttachment(ParentRefinements, refinedSyms.toList)
Expand Down Expand Up @@ -1695,6 +1700,7 @@ class Namer { typer: Typer =>
end addUsingTraits

completeConstructor(denot)
val constrSym = symbolOfTree(constr)
denot.info = tempInfo.nn

val parentTypes = defn.adjustForTuple(cls, cls.typeParams,
Expand Down Expand Up @@ -1928,7 +1934,7 @@ class Namer { typer: Typer =>
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then
else if Feature.enabled(modularity) then
// set every context bound evidence parameter of a given companion method
// to be tracked, provided it has a type that has an abstract type member.
// Add refinements for all tracked parameters to the result type.
Expand Down Expand Up @@ -1986,14 +1992,60 @@ class Namer { typer: Typer =>
cls.srcPos)
case _ =>

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
/** Try to infer if the parameter needs a `tracked` modifier
*/
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
!sym.is(Tracked)
&& param.hasAttachment(ContextBoundParam)
&& (
isContextBoundWitnessWithAbstractMembers(sym, param)
|| isReferencedInPublicSignatures(sym)
// || isPassedToTrackedParentParameter(sym, param)
)

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
*/
def isContextBoundWitnessWithAbstractMembers(sym: Symbol, param: ValDef)(using Context): Boolean =
param.hasAttachment(ContextBoundParam)
&& sym.info.memberNames(abstractTypeNameFilter).nonEmpty

/** Under x.modularity, we add `tracked` to term parameters whose types are referenced
* in public signatures of the defining class
*/
def isReferencedInPublicSignatures(sym: Symbol)(using Context): Boolean =
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.filter(d => !d.isConstructor).exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
case _ => false
checkOwnerMemberSignatures(owner)

def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean =
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
owner.infoOrCompleter match
// case info: ClassInfo =>
// info.parents.foreach(println)
// info.parents.exists(tpeContainsSymbolRef(_, accessorSyms))
case _ => false

private def namedTypeWithPrefixContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean = tpe match
case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms)
case _ => false

private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean =
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)

private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.lookupAll(sym.name).filter(d => d.is(ParamAccessor)).toList
case _ => List.empty

/** Under x.modularity, set every context bound evidence parameter of a class to be tracked,
* provided it has a type that has an abstract type member. Reset private and local flags
* so that the parameter becomes a `val`.
Expand Down
22 changes: 22 additions & 0 deletions tests/pos/infer-tracked.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import scala.language.experimental.modularity
import scala.language.future

abstract class C:
type T
def foo: T

class F(val x: C):
val result: x.T = x.foo

class G(override val x: C) extends F(x)

def Test =
val c = new C:
type T = Int
def foo = 42

val f = new F(c)
val i: Int = f.result

// val g = new G(c)
// val j: Int = g.result

0 comments on commit cc767f0

Please sign in to comment.