Skip to content

Commit

Permalink
Review: require final, remove special treatment of abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Oct 3, 2024
1 parent c04c727 commit 6b3c8f1
Show file tree
Hide file tree
Showing 57 changed files with 118 additions and 675 deletions.
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,6 @@ class Definitions {
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
@tu lazy val UnusedAnnot: ClassSymbol = requiredClass("scala.annotation.unused")
@tu lazy val UnrollAnnot: ClassSymbol = requiredClass("scala.annotation.unroll")
@tu lazy val AbstractUnrollAnnot: ClassSymbol = requiredClass("scala.annotation.internal.AbstractUnroll")
@tu lazy val UnrollForwarderAnnot: ClassSymbol = requiredClass("scala.annotation.internal.UnrollForwarder")
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")
@tu lazy val NativeAnnot: ClassSymbol = requiredClass("scala.native")
Expand Down
24 changes: 23 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,30 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>

private var inJavaAnnot: Boolean = false

private var seenUnrolledMethods: util.EqHashMap[Symbol, Boolean] | Null = null

private var noCheckNews: Set[New] = Set()

def isValidUnrolledMethod(method: Symbol)(using Context): Boolean =
val seenMethods =
val local = seenUnrolledMethods
if local == null then
val map = new util.EqHashMap[Symbol, Boolean]
seenUnrolledMethods = map
map
else
local
seenMethods.getOrElseUpdate(method, {
var res = true
if method.is(Deferred) then
report.error("Unrolled method must be final and concrete", method.srcPos)
res = false
if !method.isConstructor && !method.is(Final) then
report.error("Unrolled method must be final", method.srcPos)
res = false
res
})

def withNoCheckNews[T](ts: List[New])(op: => T): T = {
val saved = noCheckNews
noCheckNews ++= ts
Expand Down Expand Up @@ -174,7 +196,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
}

private def registerIfUnrolledParam(sym: Symbol)(using Context): Unit =
if sym.getAnnotation(defn.UnrollAnnot).isDefined then
if sym.hasAnnotation(defn.UnrollAnnot) && isValidUnrolledMethod(sym.owner) then
val cls = sym.enclosingClass
val classes = ctx.compilationUnit.unrolledClasses
val additions = Array(cls, cls.linkedClass).filter(_ != NoSymbol)
Expand Down
145 changes: 48 additions & 97 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import dotty.tools.dotc.printing.Formatting.hl
import scala.collection.mutable
import scala.util.boundary, boundary.break
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.unreachable

/**Implementation of SIP-61.
* Runs when `@unroll` annotations are found in a compilation unit, installing new definitions
Expand All @@ -30,11 +31,11 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

import tpd.*

private var _unrolledDefs: util.HashMap[Symbol, ComputedIndicies] | Null = null
private def initializeUnrolledDefs(): util.HashMap[Symbol, ComputedIndicies] =
private var _unrolledDefs: util.EqHashMap[Symbol, ComputedIndicies] | Null = null
private def initializeUnrolledDefs(): util.EqHashMap[Symbol, ComputedIndicies] =
val local = _unrolledDefs
if local == null then
val map = new util.HashMap[Symbol, ComputedIndicies]
val map = new util.EqHashMap[Symbol, ComputedIndicies]
_unrolledDefs = map
map
else
Expand All @@ -54,7 +55,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
def newTransformer(using Context): Transformer =
UnrollingTransformer(ctx.compilationUnit.unrolledClasses.nn)

type ComputedIndicies = Seq[(Int, List[Int])]
type ComputedIndicies = List[(Int, List[Int])]
type ComputeIndicies = Context ?=> Symbol => ComputedIndicies

private class UnrollingTransformer(classes: Set[Symbol]) extends Transformer {
Expand All @@ -68,7 +69,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
.flatMap { (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
else
require(annotated.is(Final, butNot = Deferred) || annotated.isConstructor,
i"${annotated} is not final&concrete, or a constructor")
Some((paramClauseIndex, annotationIndices))
}
})
end computeIndices
Expand Down Expand Up @@ -109,8 +113,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
nextParamIndex: Int,
nextSymbol: Symbol,
annotatedParamListIndex: Int,
isCaseApply: Boolean,
inferOverride: Boolean)(using Context) = {
isCaseApply: Boolean)(using Context) = {

def initNewForwarder()(using Context): (TermSymbol, List[List[Symbol]]) = {
val forwarderDefSymbol0 = Symbols.newSymbol(
Expand All @@ -119,8 +122,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
(defdef.symbol.flags &~
HasDefaultParams &~
(if nextParamIndex == -1 then EmptyFlags else Deferred)) |
Invisible | Synthetic |
(if inferOverride then Override else EmptyFlags),
Invisible | Synthetic,
NoType, // fill in later
coord = nextSymbol.span.shift(1) // shift by 1 to avoid "secondary constructor must call preceding" error
).entered
Expand Down Expand Up @@ -151,20 +153,8 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
}

val paramCount = defdef.symbol.paramSymss(annotatedParamListIndex).size
val isDeferredInitial = paramCount == paramIndex && defdef.symbol.is(Deferred)

val (forwarderDefSymbol, newParamSymLists) =
if isDeferredInitial then
val existing = defdef.symbol.asTerm
existing.addAnnotation(defn.AbstractUnrollAnnot) // mark as previously abstract
existing.flags = (existing.flags &~ Deferred) // going to implement its rhs
existing -> extractParamSymss(identity)
else
initNewForwarder()

if inferOverride then
// in this case we will not replace the source method, but we will add the override flag
defdef.symbol.flags_=(defdef.symbol.flags | Override)
val (forwarderDefSymbol, newParamSymLists) = initNewForwarder()

def forwarderRhs(): tpd.Tree = {
val defaultOffset = defdef.paramss
Expand Down Expand Up @@ -233,7 +223,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
tpd.DefDef(forwarderDefSymbol,
rhs = if nextParamIndex == -1 then EmptyTree else forwarderRhs())

forwarderDef.withSpan(if isDeferredInitial then defdef.span else nextSymbol.span.shift(1))
forwarderDef.withSpan(nextSymbol.span.shift(1))
}

def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
Expand Down Expand Up @@ -268,25 +258,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
).setDefTree
}

def generateSyntheticDefs(tree: Tree, compute: ComputeIndicies)(using Context): (Option[Symbol], Seq[(Symbol, Tree)]) = tree match {
def generateSyntheticDefs(tree: Tree, compute: ComputeIndicies)(using Context): Option[(Symbol, Option[Symbol], Seq[DefDef])] = tree match {
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

// infer an override when we are implementing a method that matches the signature and has unroll annotations
// in the same positions
lazy val inferOverride = {
def unrollIndices(sym: Symbol): List[Int] =
sym.paramSymss.flatten.zipWithIndex.collect({
case (p, i) if p.hasAnnotation(defn.UnrollAnnot) => i
})

val candidate = defdef.symbol.nextOverriddenSymbol
candidate.exists && !candidate.is(Deferred) && candidate.hasAnnotation(defn.AbstractUnrollAnnot) && {
// check unroll indices match
unrollIndices(candidate) == unrollIndices(defdef.symbol)
}
}

val isCaseCopy =
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)

Expand All @@ -302,81 +277,54 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
else defdef.symbol

compute(annotated) match {
case Nil => (None, Nil)
case Nil => None
case Seq((paramClauseIndex, annotationIndices)) =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if (isCaseFromProduct) {
val newDef = generateFromProduct(annotationIndices, paramCount, defdef)
(Some(defdef.symbol), Seq(defdef.symbol -> newDef))
} else {
if (defdef.symbol.is(Deferred)){
(
Some(defdef.symbol),
(-1 +: annotationIndices :+ paramCount).sliding(2).toList.foldLeft((Seq.empty[(Symbol, DefDef)], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
nextParamIndex,
paramIndex,
nextSymbol,
paramClauseIndex,
isCaseApply,
inferOverride
)
// replacements += forwarder.symbol
((defdef.symbol -> forwarder) +: defdefs, forwarder.symbol)
})._1
)

}else{
(
None,
(annotationIndices :+ paramCount).sliding(2).toList.reverse.foldLeft((Seq.empty[(Symbol, DefDef)], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
paramIndex,
nextParamIndex,
nextSymbol,
paramClauseIndex,
isCaseApply,
inferOverride
)
((defdef.symbol -> forwarder) +: defdefs, forwarder.symbol)
})._1
)
}
}
if isCaseFromProduct then
Some((defdef.symbol, Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef))))
else
val (generatedDefs, _) =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft((Seq.empty[DefDef], defdef.symbol)):
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
paramIndex,
nextParamIndex,
nextSymbol,
paramClauseIndex,
isCaseApply
)
(forwarder +: defdefs, forwarder.symbol)
case _ => unreachable("sliding with at least 2 elements")
Some((defdef.symbol, None, generatedDefs))

case multiple => sys.error("Cannot have multiple parameter lists containing `@unroll` annotation")
}

case _ => (None, Nil)
case _ => None
}

def unrollTemplate(tmpl: tpd.Template, compute: ComputeIndicies)(using Context): tpd.Tree = {

val (removed0, generatedDefs0) = tmpl.body.map(generateSyntheticDefs(_, compute)).unzip
val (removedCtor, generatedConstr0) = generateSyntheticDefs(tmpl.constr, compute)
val removedSymsBody = removed0.flatten
val allRemoved = removedSymsBody ++ removedCtor

val generatedDefOrigins = generatedDefs0.flatten
val generatedDefs = generatedDefOrigins.map(_(1))
val generatedConstr = generatedConstr0.map(_(1))

val otherDecls = tmpl.body.filter(t => !removedSymsBody.contains(t.symbol))
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
val allGenerated = generatedBody ++ generatedConstr0
val bodySubs = generatedBody.flatMap((_, maybeSub, _) => maybeSub).toSet
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))

/** inlined from compiler/src/dotty/tools/dotc/typer/Checking.scala */
def checkClash(decl: Symbol, other: Symbol) =
def staticNonStaticPair = decl.isScalaStatic != other.isScalaStatic
decl.matches(other) && !staticNonStaticPair

if generatedDefOrigins.nonEmpty then
if allGenerated.nonEmpty then
val byName = otherDecls.groupMap(_.symbol.name.toString)(_.symbol)
for case (src, dcl: NamedDefTree) <- generatedDefOrigins do
for
(src, _, dcls) <- allGenerated
dcl <- dcls
do
val replaced = dcl.symbol
byName.get(dcl.name.toString).foreach { syms =>
val clashes = syms.filter(checkClash(replaced, _))
Expand All @@ -387,6 +335,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
}
end if

val generatedDefs = generatedBody.flatMap((_, _, gens) => gens)
val generatedConstr = generatedConstr0.toList.flatMap((_, _, gens) => gens)

cpy.Template(tmpl)(
tmpl.constr,
tmpl.parents,
Expand Down
10 changes: 0 additions & 10 deletions library/src/scala/annotation/internal/AbstractUnroll.scala

This file was deleted.

69 changes: 0 additions & 69 deletions sbt-test/unroll-annot/abstractClassMethod/build.sbt

This file was deleted.

This file was deleted.

Loading

0 comments on commit 6b3c8f1

Please sign in to comment.