Skip to content

Commit

Permalink
Set untpdTree in repl compilation unit for completions
Browse files Browse the repository at this point in the history
  • Loading branch information
rochala committed Aug 3, 2023
1 parent ae1b409 commit d64894b
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 20 deletions.
8 changes: 3 additions & 5 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,9 @@ object Completion {
*/
def pathBeforeDesugaring(path: List[Tree], pos: SourcePosition)(using Context): List[Tree] =
val hasUntypedTree = path.headOption.forall(NavigateAST.untypedPath(_, exactMatch = true).nonEmpty)
if hasUntypedTree then
path
else
NavigateAST.untypedPath(pos.span).collect:
case tree: untpd.Tree => tree
if hasUntypedTree then path
else NavigateAST.untypedPath(pos.span).collect:
case tree: untpd.Tree => tree

private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
val path0 = pathBeforeDesugaring(path, pos)
Expand Down
35 changes: 24 additions & 11 deletions compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
end compile

final def typeOf(expr: String)(using state: State): Result[String] =
typeCheck(expr).map { tree =>
typeCheck(expr).map { (_, tpdTree) =>
given Context = state.context
tree.rhs match {
tpdTree.rhs match {
case Block(xs, _) => xs.last.tpe.widen.show
case _ =>
"""Couldn't compute the type of your expression, so sorry :(
Expand Down Expand Up @@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
Iterator(sym) ++ sym.allOverriddenSymbols
}

typeCheck(expr).map {
typeCheck(expr).map { (_, tpdTree) => tpdTree match
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
val stat = stats.last.asInstanceOf[tpd.Tree]
if (stat.tpe.isError) stat.tpe.show
Expand All @@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
}
}

final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {

def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
Expand Down Expand Up @@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
}
}

def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
def error: Result[tpd.ValDef] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors

def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
import tpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error)
.getOrElse(error[tpd.ValDef](sourceFile0))
case _ =>
error
error[tpd.ValDef](sourceFile0)
}
}

def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
import untpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error[untpd.ValDef](sourceFile0))
case _ =>
error[untpd.ValDef](sourceFile0)
}

val src = SourceFile.virtual("<typecheck>", expr)
inContext(state.context.fresh
Expand All @@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
ctx.run.nn.compileUnits(unit :: Nil, ctx)

if (errorsAllowed || !ctx.reporter.hasErrors)
unwrapped(unit.tpdTree, src)
for
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
yield untpdTree -> tpdTree
else
ctx.reporter.removeBufferedMessages.errors
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
given state: State = newRun(state0)
compiler
.typeCheck(expr, errorsAllowed = true)
.map { tree =>
.map { (untpdTree, tpdTree) =>
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
val unit = CompilationUnit(file)(using state.context)
unit.tpdTree = tree
unit.untpdTree = untpdTree
unit.tpdTree = tpdTree
given Context = state.context.fresh.setCompilationUnit(unit)
val srcPos = SourcePosition(file, Span(cursor))
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil
Expand Down
5 changes: 5 additions & 0 deletions compiler/test/dotty/tools/repl/TabcompleteTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class TabcompleteTests extends ReplTest {
assertEquals(List("apply"), comp)
}

@Test def tabCompleteInExtensionDefinition = initially {
val comp = tabComplete("extension (x: Lis")
assertEquals(List("List"), comp)
}

@Test def tabCompleteTwiceIn = {
val src1 = "class Foo { def bar(xs: List[Int]) = xs.map"
val src2 = "class Foo { def bar(xs: List[Int]) = xs.mapC"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,16 @@ class CompletionSuite extends BaseCompletionSuite:
)

@Test def `extension-definition-scope` =
check(
"""|trait Foo
|object T:
| extension (x: Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-symbol-search` =
check(
"""|object T:
| extension (x: ListBuffe@@)
Expand All @@ -1262,11 +1272,102 @@ class CompletionSuite extends BaseCompletionSuite:
|""".stripMargin,
)

@Test def `extension-definition-symbol-search` =
@Test def `extension-definition-type-parameter` =
check(
"""|trait Foo
|object T:
| extension (x: Fo@@)
| extension [A <: Fo@@]
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-type-parameter-symbol-search` =
check(
"""|object T:
| extension [A <: ListBuffe@@]
|""".stripMargin,
"""|ListBuffer[T] - scala.collection.mutable
|ListBuffer - scala.collection.mutable
|""".stripMargin
)

@Test def `extension-definition-using-param-clause` =
check(
"""|trait Foo
|object T:
| extension (using Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)


@Test def `extension-definition-mix-1` =
check(
"""|trait Foo
|object T:
| extension (x: Int)(using Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-2` =
check(
"""|trait Foo
|object T:
| extension (using Fo@@)(x: Int)(using Foo)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-3` =
check(
"""|trait Foo
|object T:
| extension (using Foo)(x: Int)(using Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-4` =
check(
"""|trait Foo
|object T:
| extension [A](x: Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-5` =
check(
"""|trait Foo
|object T:
| extension [A](using Fo@@)(x: Int)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-6` =
check(
"""|trait Foo
|object T:
| extension [A](using Foo)(x: Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
)

@Test def `extension-definition-mix-7` =
check(
"""|trait Foo
|object T:
| extension [A](using Foo)(x: Fo@@)(using Fo@@)
|""".stripMargin,
"""|Foo test
|""".stripMargin
Expand Down

0 comments on commit d64894b

Please sign in to comment.