Skip to content

Commit

Permalink
Fail when a poly function value has a different number of type params…
Browse files Browse the repository at this point in the history
… than the expected poly function (#21248)
  • Loading branch information
smarter authored Jul 23, 2024
2 parents 0e93a38 + c973d9b commit fb983db
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
52 changes: 27 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1912,32 +1912,34 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
val dpt = pt.dealias

// If the expected type is a polymorphic function with the same number of
// type and value parameters, then infer the types of value parameters from the expected type.
val inferredVParams = dpt match
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
// it must be a valid method parameter type.
if vparam.tpt.isEmpty && isFullyDefined(formal, ForceDegree.failBottom) then
cpy.ValDef(vparam)(tpt = new untpd.InLambdaTypeTree(isResult = false, (tsyms, vsyms) =>
// We don't need to substitute `mt` by `vsyms` because we currently disallow
// dependencies between value parameters of a closure.
formal.substParams(poly, tsyms.map(_.typeRef)))
)
else vparam
case _ =>
vparams

val resultTpt = dpt match
dpt match
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
case _ => untpd.TypeTree()

val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
typed(desugared, pt)
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 then
// If the expected type is a polymorphic function with the same number of
// type and value parameters, then infer the types of value parameters from the expected type.
val inferredVParams = vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
// it must be a valid method parameter type.
if vparam.tpt.isEmpty && isFullyDefined(formal, ForceDegree.failBottom) then
cpy.ValDef(vparam)(tpt = new untpd.InLambdaTypeTree(isResult = false, (tsyms, vsyms) =>
// We don't need to substitute `mt` by `vsyms` because we currently disallow
// dependencies between value parameters of a closure.
formal.substParams(poly, tsyms.map(_.typeRef)))
)
else vparam
val resultTpt =
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
typed(desugared, pt)
else
val msg =
em"""|Provided polymorphic function value doesn't match the expected type $dpt.
|Expected type should be a polymorphic function with the same number of type and value parameters."""
errorTree(EmptyTree, msg, tree.srcPos)
case _ =>
val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
typed(desugared, pt)
end typedPolyFunctionValue

def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
Expand Down
5 changes: 5 additions & 0 deletions tests/neg/i20533.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Error: tests/neg/i20533.scala:5:8 -----------------------------------------------------------------------------------
5 | [X] => (x, y) => Map(x -> y) // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Provided polymorphic function value doesn't match the expected type [X, Y] => (x$1: X, x$2: Y) => Map[X, Y].
| Expected type should be a polymorphic function with the same number of type and value parameters.
6 changes: 6 additions & 0 deletions tests/neg/i20533.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def mapF(h: [X, Y] => (X, Y) => Map[X, Y]): Unit = ???

def test =
mapF(
[X] => (x, y) => Map(x -> y) // error
)

0 comments on commit fb983db

Please sign in to comment.