Skip to content

Commit

Permalink
Backport "Support completions for extension definition parameter" to …
Browse files Browse the repository at this point in the history
…LTS (#20688)

Backports #18331 to the LTS branch.

PR submitted by the release tooling.
[skip ci]
  • Loading branch information
WojciechMazur authored Jun 20, 2024
2 parents e2a9516 + e8a8428 commit c10275d
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 209 deletions.
128 changes: 83 additions & 45 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package dotty.tools.dotc.interactive

import scala.language.unsafeNulls

import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.ast.NavigateAST
import dotty.tools.dotc.config.Printers.interactiv
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
Expand All @@ -25,6 +24,10 @@ import dotty.tools.dotc.util.SourcePosition

import scala.collection.mutable
import scala.util.control.NonFatal
import dotty.tools.dotc.core.ContextOps.localContext
import dotty.tools.dotc.core.Names
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Symbols

/**
* One of the results of a completion query.
Expand All @@ -37,18 +40,17 @@ import scala.util.control.NonFatal
*/
case class Completion(label: String, description: String, symbols: List[Symbol])

object Completion {
object Completion:

import dotty.tools.dotc.ast.tpd._

/** Get possible completions from tree at `pos`
*
* @return offset and list of symbols for possible completions
*/
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) = {
val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) =
val path: List[Tree] = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
computeCompletions(pos, path)(using Interactive.contextOfPath(path).withPhase(Phases.typerPhase))
}

/**
* Inspect `path` to determine what kinds of symbols should be considered.
Expand All @@ -60,10 +62,11 @@ object Completion {
*
* Otherwise, provide no completion suggestion.
*/
def completionMode(path: List[Tree], pos: SourcePosition): Mode =
path match {
case Ident(_) :: Import(_, _) :: _ => Mode.ImportOrExport
case (ref: RefTree) :: _ =>
def completionMode(path: List[untpd.Tree], pos: SourcePosition): Mode =
path match
case untpd.Ident(_) :: untpd.Import(_, _) :: _ => Mode.ImportOrExport
case untpd.Ident(_) :: (_: untpd.ImportSelector) :: _ => Mode.ImportOrExport
case (ref: untpd.RefTree) :: _ =>
if (ref.name.isTermName) Mode.Term
else if (ref.name.isTypeName) Mode.Type
else Mode.None
Expand All @@ -72,9 +75,8 @@ object Completion {
if sel.imported.span.contains(pos.span) then Mode.ImportOrExport
else Mode.None // Can't help completing the renaming

case (_: ImportOrExport) :: _ => Mode.ImportOrExport
case (_: untpd.ImportOrExport) :: _ => Mode.ImportOrExport
case _ => Mode.None
}

/** When dealing with <errors> in varios palces we check to see if they are
* due to incomplete backticks. If so, we ensure we get the full prefix
Expand All @@ -101,10 +103,13 @@ object Completion {
case (sel: untpd.ImportSelector) :: _ =>
completionPrefix(sel.imported :: Nil, pos)

case untpd.Ident(_) :: (sel: untpd.ImportSelector) :: _ if !sel.isGiven =>
completionPrefix(sel.imported :: Nil, pos)

case (tree: untpd.ImportOrExport) :: _ =>
tree.selectors.find(_.span.contains(pos.span)).map { selector =>
tree.selectors.find(_.span.contains(pos.span)).map: selector =>
completionPrefix(selector :: Nil, pos)
}.getOrElse("")
.getOrElse("")

// Foo.`se<TAB> will result in Select(Ident(Foo), <error>)
case (select: untpd.Select) :: _ if select.name == nme.ERROR =>
Expand All @@ -118,27 +123,65 @@ object Completion {
if (ref.name == nme.ERROR) ""
else ref.name.toString.take(pos.span.point - ref.span.point)

case _ =>
""
case _ => ""

end completionPrefix

/** Inspect `path` to determine the offset where the completion result should be inserted. */
def completionOffset(path: List[Tree]): Int =
path match {
case (ref: RefTree) :: _ => ref.span.point
def completionOffset(untpdPath: List[untpd.Tree]): Int =
untpdPath match {
case (ref: untpd.RefTree) :: _ => ref.span.point
case _ => 0
}

private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
val mode = completionMode(path, pos)
val rawPrefix = completionPrefix(path, pos)
/** Some information about the trees is lost after Typer such as Extension method construct
* is expanded into methods. In order to support completions in those cases
* we have to rely on untyped trees and only when types are necessary use typed trees.
*/
def resolveTypedOrUntypedPath(tpdPath: List[Tree], pos: SourcePosition)(using Context): List[untpd.Tree] =
lazy val untpdPath: List[untpd.Tree] = NavigateAST
.pathTo(pos.span, List(ctx.compilationUnit.untpdTree), true).collect:
case untpdTree: untpd.Tree => untpdTree

tpdPath match
case (_: Bind) :: _ => tpdPath
case (_: untpd.TypTree) :: _ => tpdPath
case _ => untpdPath

/** Handle case when cursor position is inside extension method construct.
* The extension method construct is then desugared into methods, and consturct parameters
* are no longer a part of a typed tree, but instead are prepended to method parameters.
*
* @param untpdPath The typed or untyped path to the tree that is being completed
* @param tpdPath The typed path that will be returned if no extension method construct is found
* @param pos The cursor position
*
* @return Typed path to the parameter of the extension construct if found or tpdPath
*/
private def typeCheckExtensionConstructPath(
untpdPath: List[untpd.Tree], tpdPath: List[Tree], pos: SourcePosition
)(using Context): List[Tree] =
untpdPath.collectFirst:
case untpd.ExtMethods(paramss, _) =>
val enclosingParam = paramss.flatten.find(_.span.contains(pos.span))
enclosingParam.map: param =>
ctx.typer.index(paramss.flatten)
val typedEnclosingParam = ctx.typer.typed(param)
Interactive.pathTo(typedEnclosingParam, pos.span)
.flatten.getOrElse(tpdPath)

private def computeCompletions(pos: SourcePosition, tpdPath: List[Tree])(using Context): (Int, List[Completion]) =
val path0 = resolveTypedOrUntypedPath(tpdPath, pos)
val mode = completionMode(path0, pos)
val rawPrefix = completionPrefix(path0, pos)

val hasBackTick = rawPrefix.headOption.contains('`')
val prefix = if hasBackTick then rawPrefix.drop(1) else rawPrefix

val completer = new Completer(mode, prefix, pos)

val completions = path match {
val adjustedPath = typeCheckExtensionConstructPath(path0, tpdPath, pos)
val completions = adjustedPath match
// Ignore synthetic select from `This` because in code it was `Ident`
// See example in dotty.tools.languageserver.CompletionTest.syntheticThis
case Select(qual @ This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions
Expand All @@ -147,21 +190,19 @@ object Completion {
case (tree: ImportOrExport) :: _ => completer.directMemberCompletions(tree.expr)
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => completer.directMemberCompletions(expr)
case _ => completer.scopeCompletions
}

val describedCompletions = describeCompletions(completions)
val backtickedCompletions =
describedCompletions.map(completion => backtickCompletions(completion, hasBackTick))

val offset = completionOffset(path)
val offset = completionOffset(path0)

interactiv.println(i"""completion with pos = $pos,
| prefix = ${completer.prefix},
| term = ${completer.mode.is(Mode.Term)},
| type = ${completer.mode.is(Mode.Type)}
| results = $backtickedCompletions%, %""")
(offset, backtickedCompletions)
}

def backtickCompletions(completion: Completion, hasBackTick: Boolean) =
if hasBackTick || needsBacktick(completion.label) then
Expand All @@ -174,17 +215,17 @@ object Completion {
// https://github.com/scalameta/metals/blob/main/mtags/src/main/scala/scala/meta/internal/mtags/KeywordWrapper.scala
// https://github.com/com-lihaoyi/Ammonite/blob/73a874173cd337f953a3edc9fb8cb96556638fdd/amm/util/src/main/scala/ammonite/util/Model.scala
private def needsBacktick(s: String) =
val chunks = s.split("_", -1)
val chunks = s.split("_", -1).nn

val validChunks = chunks.zipWithIndex.forall { case (chunk, index) =>
chunk.forall(Chars.isIdentifierPart) ||
(chunk.forall(Chars.isOperatorPart) &&
chunk.nn.forall(Chars.isIdentifierPart) ||
(chunk.nn.forall(Chars.isOperatorPart) &&
index == chunks.length - 1 &&
!(chunks.lift(index - 1).contains("") && index - 1 == 0))
}

val validStart =
Chars.isIdentifierStart(s(0)) || chunks(0).forall(Chars.isOperatorPart)
Chars.isIdentifierStart(s(0)) || chunks(0).nn.forall(Chars.isOperatorPart)

val valid = validChunks && validStart && !keywords.contains(s)

Expand Down Expand Up @@ -216,7 +257,7 @@ object Completion {
* For the results of all `xyzCompletions` methods term names and type names are always treated as different keys in the same map
* and they never conflict with each other.
*/
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition):
/** Completions for terms and types that are currently in scope:
* the members of the current class, local definitions and the symbols that have been imported,
* recursively adding completions from outer scopes.
Expand All @@ -230,7 +271,7 @@ object Completion {
* (even if the import follows it syntactically)
* - a more deeply nested import shadowing a member or a local definition causes an ambiguity
*/
def scopeCompletions(using context: Context): CompletionMap = {
def scopeCompletions(using context: Context): CompletionMap =
val mappings = collection.mutable.Map.empty[Name, List[ScopedDenotations]].withDefaultValue(List.empty)
def addMapping(name: Name, denots: ScopedDenotations) =
mappings(name) = mappings(name) :+ denots
Expand Down Expand Up @@ -302,7 +343,7 @@ object Completion {
}

resultMappings
}
end scopeCompletions

/** Widen only those types which are applied or are exactly nothing
*/
Expand Down Expand Up @@ -335,16 +376,16 @@ object Completion {
/** Completions introduced by imports directly in this context.
* Completions from outer contexts are not included.
*/
private def importedCompletions(using Context): CompletionMap = {
private def importedCompletions(using Context): CompletionMap =
val imp = ctx.importInfo

def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
imp.site.member(name).alternatives
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }

if imp == null then
Map.empty
else
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
imp.site.member(name).alternatives
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }

val givenImports = imp.importedImplicits
.map { ref => (ref.implicitName: Name, ref.underlyingRef.denot.asSingleDenotation) }
.filter((name, denot) => include(denot, name))
Expand All @@ -370,7 +411,7 @@ object Completion {
}.toSeq.groupByName

givenImports ++ wildcardMembers ++ explicitMembers
}
end importedCompletions

/** Completions from implicit conversions including old style extensions using implicit classes */
private def implicitConversionMemberCompletions(qual: Tree)(using Context): CompletionMap =
Expand Down Expand Up @@ -532,7 +573,6 @@ object Completion {
extension [N <: Name](namedDenotations: Seq[(N, SingleDenotation)])
@annotation.targetName("groupByNameTupled")
def groupByName: CompletionMap = namedDenotations.groupMap((name, denot) => name)((name, denot) => denot)
}

private type CompletionMap = Map[Name, Seq[SingleDenotation]]

Expand All @@ -545,11 +585,11 @@ object Completion {
* The completion mode: defines what kinds of symbols should be included in the completion
* results.
*/
class Mode(val bits: Int) extends AnyVal {
class Mode(val bits: Int) extends AnyVal:
def is(other: Mode): Boolean = (bits & other.bits) == other.bits
def |(other: Mode): Mode = new Mode(bits | other.bits)
}
object Mode {

object Mode:
/** No symbol should be included */
val None: Mode = new Mode(0)

Expand All @@ -561,6 +601,4 @@ object Completion {

/** Both term and type symbols are allowed */
val ImportOrExport: Mode = new Mode(4) | Term | Type
}
}

35 changes: 24 additions & 11 deletions compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
end compile

final def typeOf(expr: String)(using state: State): Result[String] =
typeCheck(expr).map { tree =>
typeCheck(expr).map { (_, tpdTree) =>
given Context = state.context
tree.rhs match {
tpdTree.rhs match {
case Block(xs, _) => xs.last.tpe.widen.show
case _ =>
"""Couldn't compute the type of your expression, so sorry :(
Expand Down Expand Up @@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
Iterator(sym) ++ sym.allOverriddenSymbols
}

typeCheck(expr).map {
typeCheck(expr).map { (_, tpdTree) => tpdTree match
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
val stat = stats.last.asInstanceOf[tpd.Tree]
if (stat.tpe.isError) stat.tpe.show
Expand All @@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
}
}

final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {

def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
Expand Down Expand Up @@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
}
}

def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
def error: Result[tpd.ValDef] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors

def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
import tpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error)
.getOrElse(error[tpd.ValDef](sourceFile0))
case _ =>
error
error[tpd.ValDef](sourceFile0)
}
}

def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
import untpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error[untpd.ValDef](sourceFile0))
case _ =>
error[untpd.ValDef](sourceFile0)
}

val src = SourceFile.virtual("<typecheck>", expr)
inContext(state.context.fresh
Expand All @@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
ctx.run.nn.compileUnits(unit :: Nil, ctx)

if (errorsAllowed || !ctx.reporter.hasErrors)
unwrapped(unit.tpdTree, src)
for
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
yield untpdTree -> tpdTree
else
ctx.reporter.removeBufferedMessages.errors
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
given state: State = newRun(state0)
compiler
.typeCheck(expr, errorsAllowed = true)
.map { tree =>
.map { (untpdTree, tpdTree) =>
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
val unit = CompilationUnit(file)(using state.context)
unit.tpdTree = tree
unit.untpdTree = untpdTree
unit.tpdTree = tpdTree
given Context = state.context.fresh.setCompilationUnit(unit)
val srcPos = SourcePosition(file, Span(cursor))
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil
Expand Down
Loading

0 comments on commit c10275d

Please sign in to comment.