diff --git a/compiler/src/dotty/tools/dotc/interactive/Completion.scala b/compiler/src/dotty/tools/dotc/interactive/Completion.scala index 7ee3e48f0b68..0b9265ae2635 100644 --- a/compiler/src/dotty/tools/dotc/interactive/Completion.scala +++ b/compiler/src/dotty/tools/dotc/interactive/Completion.scala @@ -136,11 +136,9 @@ object Completion { */ def pathBeforeDesugaring(path: List[Tree], pos: SourcePosition)(using Context): List[Tree] = val hasUntypedTree = path.headOption.forall(NavigateAST.untypedPath(_, exactMatch = true).nonEmpty) - if hasUntypedTree then - path - else - NavigateAST.untypedPath(pos.span).collect: - case tree: untpd.Tree => tree + if hasUntypedTree then path + else NavigateAST.untypedPath(pos.span).collect: + case tree: untpd.Tree => tree private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = { val path0 = pathBeforeDesugaring(path, pos) diff --git a/compiler/src/dotty/tools/repl/ReplCompiler.scala b/compiler/src/dotty/tools/repl/ReplCompiler.scala index 764695e8479b..d3a5561b6080 100644 --- a/compiler/src/dotty/tools/repl/ReplCompiler.scala +++ b/compiler/src/dotty/tools/repl/ReplCompiler.scala @@ -93,9 +93,9 @@ class ReplCompiler extends Compiler: end compile final def typeOf(expr: String)(using state: State): Result[String] = - typeCheck(expr).map { tree => + typeCheck(expr).map { (_, tpdTree) => given Context = state.context - tree.rhs match { + tpdTree.rhs match { case Block(xs, _) => xs.last.tpe.widen.show case _ => """Couldn't compute the type of your expression, so sorry :( @@ -129,7 +129,7 @@ class ReplCompiler extends Compiler: Iterator(sym) ++ sym.allOverriddenSymbols } - typeCheck(expr).map { + typeCheck(expr).map { (_, tpdTree) => tpdTree match case ValDef(_, _, Block(stats, _)) if stats.nonEmpty => val stat = stats.last.asInstanceOf[tpd.Tree] if (stat.tpe.isError) stat.tpe.show @@ -152,7 +152,7 @@ class ReplCompiler extends Compiler: } } - final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = { + final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = { def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = { def wrap(trees: List[untpd.Tree]): untpd.PackageDef = { @@ -181,22 +181,32 @@ class ReplCompiler extends Compiler: } } - def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = { - def error: Result[tpd.ValDef] = - List(new Diagnostic.Error(s"Invalid scala expression", - sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors + def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] = + List(new Diagnostic.Error(s"Invalid scala expression", + sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors + def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = { import tpd._ tree match { case PackageDef(_, List(TypeDef(_, tmpl: Template))) => tmpl.body .collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result } - .getOrElse(error) + .getOrElse(error[tpd.ValDef](sourceFile0)) case _ => - error + error[tpd.ValDef](sourceFile0) } } + def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] = + import untpd._ + tree match { + case PackageDef(_, List(TypeDef(_, tmpl: Template))) => + tmpl.body + .collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result } + .getOrElse(error[untpd.ValDef](sourceFile0)) + case _ => + error[untpd.ValDef](sourceFile0) + } val src = SourceFile.virtual("", expr) inContext(state.context.fresh @@ -209,7 +219,10 @@ class ReplCompiler extends Compiler: ctx.run.nn.compileUnits(unit :: Nil, ctx) if (errorsAllowed || !ctx.reporter.hasErrors) - unwrapped(unit.tpdTree, src) + for + tpdTree <- unwrappedTypeTree(unit.tpdTree, src) + untpdTree <- unwrappedUntypedTree(unit.untpdTree, src) + yield untpdTree -> tpdTree else ctx.reporter.removeBufferedMessages.errors } diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index 905f4f06de08..2471f6bece42 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String], given state: State = newRun(state0) compiler .typeCheck(expr, errorsAllowed = true) - .map { tree => + .map { (untpdTree, tpdTree) => val file = SourceFile.virtual("", expr, maybeIncomplete = true) val unit = CompilationUnit(file)(using state.context) - unit.tpdTree = tree + unit.untpdTree = untpdTree + unit.tpdTree = tpdTree given Context = state.context.fresh.setCompilationUnit(unit) val srcPos = SourcePosition(file, Span(cursor)) val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil diff --git a/compiler/test/dotty/tools/repl/TabcompleteTests.scala b/compiler/test/dotty/tools/repl/TabcompleteTests.scala index 910584a9b5e7..0bce525e1469 100644 --- a/compiler/test/dotty/tools/repl/TabcompleteTests.scala +++ b/compiler/test/dotty/tools/repl/TabcompleteTests.scala @@ -32,6 +32,11 @@ class TabcompleteTests extends ReplTest { assertEquals(List("apply"), comp) } + @Test def tabCompleteInExtensionDefinition = initially { + val comp = tabComplete("extension (x: Lis") + assertEquals(List("List"), comp) + } + @Test def tabCompleteTwiceIn = { val src1 = "class Foo { def bar(xs: List[Int]) = xs.map" val src2 = "class Foo { def bar(xs: List[Int]) = xs.mapC" diff --git a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala index 61bb60ff40fb..440044b959a8 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala @@ -1253,6 +1253,16 @@ class CompletionSuite extends BaseCompletionSuite: ) @Test def `extension-definition-scope` = + check( + """|trait Foo + |object T: + | extension (x: Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-symbol-search` = check( """|object T: | extension (x: ListBuffe@@) @@ -1262,11 +1272,102 @@ class CompletionSuite extends BaseCompletionSuite: |""".stripMargin, ) - @Test def `extension-definition-symbol-search` = + @Test def `extension-definition-type-parameter` = check( """|trait Foo |object T: - | extension (x: Fo@@) + | extension [A <: Fo@@] + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-type-parameter-symbol-search` = + check( + """|object T: + | extension [A <: ListBuffe@@] + |""".stripMargin, + """|ListBuffer[T] - scala.collection.mutable + |ListBuffer - scala.collection.mutable + |""".stripMargin + ) + + @Test def `extension-definition-using-param-clause` = + check( + """|trait Foo + |object T: + | extension (using Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + + @Test def `extension-definition-mix-1` = + check( + """|trait Foo + |object T: + | extension (x: Int)(using Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-2` = + check( + """|trait Foo + |object T: + | extension (using Fo@@)(x: Int)(using Foo) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-3` = + check( + """|trait Foo + |object T: + | extension (using Foo)(x: Int)(using Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-4` = + check( + """|trait Foo + |object T: + | extension [A](x: Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-5` = + check( + """|trait Foo + |object T: + | extension [A](using Fo@@)(x: Int) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-6` = + check( + """|trait Foo + |object T: + | extension [A](using Foo)(x: Fo@@) + |""".stripMargin, + """|Foo test + |""".stripMargin + ) + + @Test def `extension-definition-mix-7` = + check( + """|trait Foo + |object T: + | extension [A](using Foo)(x: Fo@@)(using Fo@@) |""".stripMargin, """|Foo test |""".stripMargin