diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index bbe54f14b86c..85f9b39aee3b 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -254,19 +254,36 @@ extends tpd.TreeTraverser: val tp1 = mapInferred(tp) if boxed then box(tp1) else tp1 - /** Expand some aliases of function types to the underlying functions. - * Right now, these are only $throws aliases, but this could be generalized. - */ - private def expandThrowsAlias(tp: Type)(using Context) = tp match - case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias => - // hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->` - defn.FunctionOf( - AnnotatedType( + /** Recognizer for `res $throws exc`, returning `(res, exc)` in case of success */ + object throwsAlias: + def unapply(tp: Type)(using Context): Option[(Type, Type)] = tp match + case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias => + Some((res, exc)) + case _ => + None + + /** Expand $throws aliases. This is hard-coded here since $throws aliases in stdlib + * are defined with `?=>` rather than `?->`. + * We also have to add a capture set to the last expanded throws alias. I.e. + * T $throws E1 $throws E2 + * expands to + * (erased x$0: CanThrow[E1]) ?-> (erased x$1: CanThrow[E1]) ?->{x$0} T + */ + private def expandThrowsAlias(tp: Type, encl: List[MethodType] = Nil)(using Context): Type = tp match + case throwsAlias(res, exc) => + val paramType = AnnotatedType( defn.CanThrowClass.typeRef.appliedTo(exc), - Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil, - res, - isContextual = true - ) + Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) + val isLast = throwsAlias.unapply(res).isEmpty + val paramName = nme.syntheticParamName(encl.length) + val mt = ContextualMethodType(paramName :: Nil)( + _ => paramType :: Nil, + mt => if isLast then res else expandThrowsAlias(res, mt :: encl)) + val fntpe = RefinedType(defn.ErasedFunctionClass.typeRef, nme.apply, mt) + if !encl.isEmpty && isLast then + val cs = CaptureSet(encl.map(_.paramRefs.head)*) + CapturingType(fntpe, cs, boxed = false) + else fntpe case _ => tp private def expandThrowsAliases(using Context) = new TypeMap: @@ -283,70 +300,10 @@ extends tpd.TreeTraverser: case _ => mapOver(t) - /** Fill in capture sets of curried function types from left to right, using - * a combination of the following two rules: - * - * 1. Expand `{c} (x: A) -> (y: B) -> C` - * to `{c} (x: A) -> {c} (y: B) -> C` - * 2. Expand `(x: A) -> (y: B) -> C` where `x` is tracked - * to `(x: A) -> {x} (y: B) -> C` - * - * TODO: Should we also propagate capture sets to the left? - */ - private def expandAbbreviations(using Context) = new TypeMap: - - /** Propagate `outerCs` as well as all tracked parameters as capture set to the result type - * of the dependent function type `tp`. - */ - def propagateDepFunctionResult(tp: Type, outerCs: CaptureSet): Type = tp match - case RefinedType(parent, nme.apply, rinfo: MethodType) => - val localCs = CaptureSet(rinfo.paramRefs.filter(_.isTracked)*) - val rinfo1 = rinfo.derivedLambdaType( - resType = propagateEnclosing(rinfo.resType, CaptureSet.empty, outerCs ++ localCs)) - if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) - else tp - - /** If `tp` is a function type: - * - add `outerCs` as its capture set, - * - propagate `currentCs`, `outerCs`, and all tracked parameters of `tp` to the right. - */ - def propagateEnclosing(tp: Type, currentCs: CaptureSet, outerCs: CaptureSet): Type = tp match - case tp @ AppliedType(tycon, args) if defn.isFunctionClass(tycon.typeSymbol) => - val tycon1 = this(tycon) - val args1 = args.init.mapConserve(this) - val tp1 = - if args1.exists(!_.captureSet.isAlwaysEmpty) then - val propagated = propagateDepFunctionResult( - depFun(tycon, args1, args.last), currentCs ++ outerCs) - propagated match - case RefinedType(_, _, mt: MethodType) => - if mt.isCaptureDependent then propagated - else - // No need to introduce dependent type, switch back to generic function type - tp.derivedAppliedType(tycon1, args1 :+ mt.resType) - else - val resType1 = propagateEnclosing( - args.last, CaptureSet.empty, currentCs ++ outerCs) - tp.derivedAppliedType(tycon1, args1 :+ resType1) - tp1.capturing(outerCs) - case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) => - propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs) - .capturing(outerCs) - case _ => - mapOver(tp) - - def apply(tp: Type): Type = tp match - case CapturingType(parent, cs) => - tp.derivedCapturingType(propagateEnclosing(parent, cs, CaptureSet.empty), cs) - case _ => - propagateEnclosing(tp, CaptureSet.empty, CaptureSet.empty) - end expandAbbreviations - private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type = val tp1 = expandThrowsAliases(if boxed then box(tp) else tp) if tp1 ne tp then capt.println(i"expanded: $tp --> $tp1") - if ctx.settings.YccNoAbbrev.value then tp1 - else expandAbbreviations(tp1) + tp1 /** Transform type of type tree, and remember the transformed type as the type the tree */ private def transformTT(tree: TypeTree, boxed: Boolean, exact: Boolean)(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index e685d8664037..92ff08fea395 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -387,7 +387,6 @@ private sealed trait YSettings: val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.") val YrecheckTest: Setting[Boolean] = BooleanSetting("-Yrecheck-test", "Run basic rechecking (internal test only).") val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Used in conjunction with captureChecking language import, debug info for captured references.") - val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with captureChecking language import, suppress type abbreviations.") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/tests/neg-custom-args/captures/curried-simplified.check b/tests/neg-custom-args/captures/curried-simplified.check deleted file mode 100644 index 6a792314e4e3..000000000000 --- a/tests/neg-custom-args/captures/curried-simplified.check +++ /dev/null @@ -1,42 +0,0 @@ --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:7:28 ---------------------------- -7 | def y1: () -> () -> Int = x1 // error - | ^^ - | Found: () ->? () ->{x} Int - | Required: () -> () -> Int - | - | longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:9:28 ---------------------------- -9 | def y2: () -> () => Int = x2 // error - | ^^ - | Found: () ->{x} () => Int - | Required: () -> () => Int - | - | longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:11:39 --------------------------- -11 | def y3: Cap -> Protect[Int -> Int] = x3 // error - | ^^ - | Found: (x$0: Cap) ->? Int ->{x$0} Int - | Required: Cap -> Protect[Int -> Int] - | - | longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:15:32 --------------------------- -15 | def y5: Cap -> Int ->{} Int = x5 // error - | ^^ - | Found: Cap ->? Int ->{x} Int - | Required: Cap -> Int ->{} Int - | - | longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:17:48 --------------------------- -17 | def y6: Cap -> Cap ->{} Protect[Int -> Int] = x6 // error - | ^^ - | Found: (x$0: Cap) ->? (x$0: Cap) ->{x$0} Int ->{x$0, x$0} Int - | Required: Cap -> Cap ->{} Protect[Int -> Int] - | - | longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:19:48 --------------------------- -19 | def y7: Cap -> Protect[Cap -> Int ->{} Int] = x7 // error - | ^^ - | Found: (x$0: Cap) ->? (x: Cap) ->{x$0} Int ->{x$0, x} Int - | Required: Cap -> Protect[Cap -> Int ->{} Int] - | - | longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/curried-simplified.scala b/tests/neg-custom-args/captures/curried-simplified.scala deleted file mode 100644 index 988cf7c11c45..000000000000 --- a/tests/neg-custom-args/captures/curried-simplified.scala +++ /dev/null @@ -1,21 +0,0 @@ -@annotation.capability class Cap - -type Protect[T] = T - -def test(x: Cap, y: Cap) = - def x1: () -> () ->{x} Int = ??? - def y1: () -> () -> Int = x1 // error - def x2: () ->{x} () => Int = ??? - def y2: () -> () => Int = x2 // error - def x3: Cap -> Int -> Int = ??? - def y3: Cap -> Protect[Int -> Int] = x3 // error - def x4: Cap -> Protect[Int -> Int] = ??? - def y4: Cap -> Int ->{} Int = x4 // ok - def x5: Cap -> Int ->{x} Int = ??? - def y5: Cap -> Int ->{} Int = x5 // error - def x6: Cap -> Cap -> Int -> Int = ??? - def y6: Cap -> Cap ->{} Protect[Int -> Int] = x6 // error - def x7: Cap -> (x: Cap) -> Int -> Int = ??? - def y7: Cap -> Protect[Cap -> Int ->{} Int] = x7 // error - - diff --git a/tests/pos-custom-args/captures/curried-closures.scala b/tests/pos-custom-args/captures/curried-closures.scala index 7258670c295e..baea8b15075c 100644 --- a/tests/pos-custom-args/captures/curried-closures.scala +++ b/tests/pos-custom-args/captures/curried-closures.scala @@ -1,6 +1,26 @@ -import java.io.* -import annotation.capability +object Test: + def map2(xs: List[Int])(f: Int => Int): List[Int] = xs.map(f) + val f1 = map2 + val fc1: List[Int] -> (Int => Int) -> List[Int] = f1 + + def map3(f: Int => Int)(xs: List[Int]): List[Int] = xs.map(f) + private val f2 = map3 + val fc2: (f: Int => Int) -> List[Int] ->{f} List[Int] = f2 + + val f3 = (f: Int => Int) => + println(f(3)) + (xs: List[Int]) => xs.map(_ + 1) + val f3c: (Int => Int) -> List[Int] -> List[Int] = f3 + + class LL[A]: + def drop(n: Int): LL[A]^{this} = ??? + def test(ct: CanThrow[Exception]) = + def xs: LL[Int]^{ct} = ??? + val ys = xs.drop(_) + val ysc: Int -> LL[Int]^{ct} = ys + +import java.io.* def Test4(g: OutputStream^) = val xs: List[Int] = ??? val later = (f: OutputStream^) => (y: Int) => xs.foreach(x => f.write(x + y)) @@ -8,3 +28,4 @@ def Test4(g: OutputStream^) = val later2 = () => (y: Int) => xs.foreach(x => g.write(x + y)) val _: () ->{} Int ->{g} Unit = later2 + diff --git a/tests/pos-custom-args/captures/curried-shorthands.scala b/tests/pos-custom-args/captures/curried-shorthands.scala deleted file mode 100644 index c68dc4b5cdbf..000000000000 --- a/tests/pos-custom-args/captures/curried-shorthands.scala +++ /dev/null @@ -1,24 +0,0 @@ -object Test: - def map2(xs: List[Int])(f: Int => Int): List[Int] = xs.map(f) - val f1 = map2 - val fc1: List[Int] -> (Int => Int) -> List[Int] = f1 - - def map3(f: Int => Int)(xs: List[Int]): List[Int] = xs.map(f) - private val f2 = map3 - val fc2: (Int => Int) -> List[Int] -> List[Int] = f2 - - val f3 = (f: Int => Int) => - println(f(3)) - (xs: List[Int]) => xs.map(_ + 1) - val f3c: (Int => Int) -> List[Int] ->{} List[Int] = f3 - - class LL[A]: - def drop(n: Int): LL[A]^{this} = ??? - - def test(ct: CanThrow[Exception]) = - def xs: LL[Int]^{ct} = ??? - val ys = xs.drop(_) - val ysc: Int -> LL[Int]^{ct} = ys - - - diff --git a/tests/pos-custom-args/captures/i13816.scala b/tests/pos-custom-args/captures/i13816.scala index 235afef35f1c..9d897b0f4601 100644 --- a/tests/pos-custom-args/captures/i13816.scala +++ b/tests/pos-custom-args/captures/i13816.scala @@ -2,12 +2,16 @@ import language.experimental.saferExceptions class Ex1 extends Exception("Ex1") class Ex2 extends Exception("Ex2") +class Ex3 extends Exception("Ex3") def foo0(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit = if i > 0 then throw new Ex1 else throw new Ex2 -def foo01(i: Int): CanThrow[Ex1] ?-> CanThrow[Ex2] ?-> Unit = +/* Does not work yet since annotated CFTs are not recognized properly in typer + +def foo01(i: Int): (ct: CanThrow[Ex1]) ?-> CanThrow[Ex2] ?->{ct} Unit = if i > 0 then throw new Ex1 else throw new Ex2 +*/ def foo1(i: Int): Unit throws Ex1 throws Ex2 = if i > 0 then throw new Ex1 else throw new Ex1 @@ -33,6 +37,15 @@ def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex1 | Ex2 = def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex2 | Ex1 = if i > 0 then throw new Ex1 else throw new Ex2 +/** Does not work yet since the type of the rhs is not hygienic + +def foo9(i: Int): Unit throws Ex1 | Ex2 | Ex3 = + if i > 0 then throw new Ex1 + else if i < 0 then throw new Ex2 + else throw new Ex3 + +*/ + def test(): Unit = try foo1(1)