Skip to content

Commit

Permalink
Remove curried function types abbreviations
Browse files Browse the repository at this point in the history
Remove automatic insertion of captured in curried function types from left to right.
They were sometimes confusing and with deep capture sets are counter-productive now.
  • Loading branch information
odersky committed Jul 4, 2023
1 parent 8030851 commit f6389d0
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 164 deletions.
103 changes: 30 additions & 73 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,36 @@ extends tpd.TreeTraverser:
val tp1 = mapInferred(tp)
if boxed then box(tp1) else tp1

/** Expand some aliases of function types to the underlying functions.
* Right now, these are only $throws aliases, but this could be generalized.
*/
private def expandThrowsAlias(tp: Type)(using Context) = tp match
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
defn.FunctionOf(
AnnotatedType(
/** Recognizer for `res $throws exc`, returning `(res, exc)` in case of success */
object throwsAlias:
def unapply(tp: Type)(using Context): Option[(Type, Type)] = tp match
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
Some((res, exc))
case _ =>
None

/** Expand $throws aliases. This is hard-coded here since $throws aliases in stdlib
* are defined with `?=>` rather than `?->`.
* We also have to add a capture set to the last expanded throws alias. I.e.
* T $throws E1 $throws E2
* expands to
* (erased x$0: CanThrow[E1]) ?-> (erased x$1: CanThrow[E1]) ?->{x$0} T
*/
private def expandThrowsAlias(tp: Type, encl: List[MethodType] = Nil)(using Context): Type = tp match
case throwsAlias(res, exc) =>
val paramType = AnnotatedType(
defn.CanThrowClass.typeRef.appliedTo(exc),
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
res,
isContextual = true
)
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span))
val isLast = throwsAlias.unapply(res).isEmpty
val paramName = nme.syntheticParamName(encl.length)
val mt = ContextualMethodType(paramName :: Nil)(
_ => paramType :: Nil,
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
val fntpe = RefinedType(defn.ErasedFunctionClass.typeRef, nme.apply, mt)
if !encl.isEmpty && isLast then
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
CapturingType(fntpe, cs, boxed = false)
else fntpe
case _ => tp

private def expandThrowsAliases(using Context) = new TypeMap:
Expand All @@ -283,70 +300,10 @@ extends tpd.TreeTraverser:
case _ =>
mapOver(t)

/** Fill in capture sets of curried function types from left to right, using
* a combination of the following two rules:
*
* 1. Expand `{c} (x: A) -> (y: B) -> C`
* to `{c} (x: A) -> {c} (y: B) -> C`
* 2. Expand `(x: A) -> (y: B) -> C` where `x` is tracked
* to `(x: A) -> {x} (y: B) -> C`
*
* TODO: Should we also propagate capture sets to the left?
*/
private def expandAbbreviations(using Context) = new TypeMap:

/** Propagate `outerCs` as well as all tracked parameters as capture set to the result type
* of the dependent function type `tp`.
*/
def propagateDepFunctionResult(tp: Type, outerCs: CaptureSet): Type = tp match
case RefinedType(parent, nme.apply, rinfo: MethodType) =>
val localCs = CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)
val rinfo1 = rinfo.derivedLambdaType(
resType = propagateEnclosing(rinfo.resType, CaptureSet.empty, outerCs ++ localCs))
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
else tp

/** If `tp` is a function type:
* - add `outerCs` as its capture set,
* - propagate `currentCs`, `outerCs`, and all tracked parameters of `tp` to the right.
*/
def propagateEnclosing(tp: Type, currentCs: CaptureSet, outerCs: CaptureSet): Type = tp match
case tp @ AppliedType(tycon, args) if defn.isFunctionClass(tycon.typeSymbol) =>
val tycon1 = this(tycon)
val args1 = args.init.mapConserve(this)
val tp1 =
if args1.exists(!_.captureSet.isAlwaysEmpty) then
val propagated = propagateDepFunctionResult(
depFun(tycon, args1, args.last), currentCs ++ outerCs)
propagated match
case RefinedType(_, _, mt: MethodType) =>
if mt.isCaptureDependent then propagated
else
// No need to introduce dependent type, switch back to generic function type
tp.derivedAppliedType(tycon1, args1 :+ mt.resType)
else
val resType1 = propagateEnclosing(
args.last, CaptureSet.empty, currentCs ++ outerCs)
tp.derivedAppliedType(tycon1, args1 :+ resType1)
tp1.capturing(outerCs)
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
.capturing(outerCs)
case _ =>
mapOver(tp)

def apply(tp: Type): Type = tp match
case CapturingType(parent, cs) =>
tp.derivedCapturingType(propagateEnclosing(parent, cs, CaptureSet.empty), cs)
case _ =>
propagateEnclosing(tp, CaptureSet.empty, CaptureSet.empty)
end expandAbbreviations

private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type =
val tp1 = expandThrowsAliases(if boxed then box(tp) else tp)
if tp1 ne tp then capt.println(i"expanded: $tp --> $tp1")
if ctx.settings.YccNoAbbrev.value then tp1
else expandAbbreviations(tp1)
tp1

/** Transform type of type tree, and remember the transformed type as the type the tree */
private def transformTT(tree: TypeTree, boxed: Boolean, exact: Boolean)(using Context): Unit =
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ private sealed trait YSettings:
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.")
val YrecheckTest: Setting[Boolean] = BooleanSetting("-Yrecheck-test", "Run basic rechecking (internal test only).")
val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Used in conjunction with captureChecking language import, debug info for captured references.")
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with captureChecking language import, suppress type abbreviations.")

/** Area-specific debug output */
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")
Expand Down
42 changes: 0 additions & 42 deletions tests/neg-custom-args/captures/curried-simplified.check

This file was deleted.

21 changes: 0 additions & 21 deletions tests/neg-custom-args/captures/curried-simplified.scala

This file was deleted.

25 changes: 23 additions & 2 deletions tests/pos-custom-args/captures/curried-closures.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
import java.io.*
import annotation.capability
object Test:
def map2(xs: List[Int])(f: Int => Int): List[Int] = xs.map(f)
val f1 = map2
val fc1: List[Int] -> (Int => Int) -> List[Int] = f1

def map3(f: Int => Int)(xs: List[Int]): List[Int] = xs.map(f)
private val f2 = map3
val fc2: (f: Int => Int) -> List[Int] ->{f} List[Int] = f2

val f3 = (f: Int => Int) =>
println(f(3))
(xs: List[Int]) => xs.map(_ + 1)
val f3c: (Int => Int) -> List[Int] -> List[Int] = f3

class LL[A]:
def drop(n: Int): LL[A]^{this} = ???

def test(ct: CanThrow[Exception]) =
def xs: LL[Int]^{ct} = ???
val ys = xs.drop(_)
val ysc: Int -> LL[Int]^{ct} = ys

import java.io.*
def Test4(g: OutputStream^) =
val xs: List[Int] = ???
val later = (f: OutputStream^) => (y: Int) => xs.foreach(x => f.write(x + y))
val _: (f: OutputStream^) ->{} Int ->{f} Unit = later

val later2 = () => (y: Int) => xs.foreach(x => g.write(x + y))
val _: () ->{} Int ->{g} Unit = later2

24 changes: 0 additions & 24 deletions tests/pos-custom-args/captures/curried-shorthands.scala

This file was deleted.

15 changes: 14 additions & 1 deletion tests/pos-custom-args/captures/i13816.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ import language.experimental.saferExceptions

class Ex1 extends Exception("Ex1")
class Ex2 extends Exception("Ex2")
class Ex3 extends Exception("Ex3")

def foo0(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit =
if i > 0 then throw new Ex1 else throw new Ex2

def foo01(i: Int): CanThrow[Ex1] ?-> CanThrow[Ex2] ?-> Unit =
/* Does not work yet since annotated CFTs are not recognized properly in typer
def foo01(i: Int): (ct: CanThrow[Ex1]) ?-> CanThrow[Ex2] ?->{ct} Unit =
if i > 0 then throw new Ex1 else throw new Ex2
*/

def foo1(i: Int): Unit throws Ex1 throws Ex2 =
if i > 0 then throw new Ex1 else throw new Ex1
Expand All @@ -33,6 +37,15 @@ def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex1 | Ex2 =
def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex2 | Ex1 =
if i > 0 then throw new Ex1 else throw new Ex2

/** Does not work yet since the type of the rhs is not hygienic
def foo9(i: Int): Unit throws Ex1 | Ex2 | Ex3 =
if i > 0 then throw new Ex1
else if i < 0 then throw new Ex2
else throw new Ex3
*/

def test(): Unit =
try
foo1(1)
Expand Down

0 comments on commit f6389d0

Please sign in to comment.