Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport "bugfix: suggest correct arg name completions for lambda expressions" to LTS #19159

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ object CompletionValue:
description
override def insertMode: Option[InsertTextMode] = Some(InsertTextMode.AsIs)

def namedArg(label: String, sym: Symbol)(using
def namedArg(label: String, sym: ParamSymbol)(using
Context
): CompletionValue =
NamedArg(label, sym.info.widenTermRefExpr, sym)
NamedArg(label, sym.info.widenTermRefExpr, sym.symbol)

def keyword(label: String, insertText: String): CompletionValue =
Keyword(label, Some(insertText))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ import dotty.tools.dotc.core.Symbols
import dotty.tools.dotc.core.Symbols.Symbol
import dotty.tools.dotc.core.Types.AndType
import dotty.tools.dotc.core.Types.AppliedType
import dotty.tools.dotc.core.Types.MethodType
import dotty.tools.dotc.core.Types.OrType
import dotty.tools.dotc.core.Types.RefinedType
import dotty.tools.dotc.core.Types.TermRef
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.core.Types.TypeBounds
import dotty.tools.dotc.core.Types.WildcardType
import dotty.tools.dotc.util.SourcePosition
import dotty.tools.pc.IndexedContext
import dotty.tools.pc.utils.MtagsEnrichments.*
import scala.annotation.tailrec

object NamedArgCompletions:

Expand Down Expand Up @@ -195,9 +198,40 @@ object NamedArgCompletions:
// def curry(x: Int)(apple: String, banana: String) = ???
// curry(1)(apple = "test", b@@)
// ```
val (baseParams, baseArgs) =
val (baseParams0, baseArgs) =
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))

val baseParams: List[ParamSymbol] =
def defaultBaseParams = baseParams0.map(JustSymbol(_))
@tailrec
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
if level > 0 then
val resultTypeOpt =
refinedType match
case RefinedType(AppliedType(_, args), _, _) => args.lastOption
case AppliedType(_, args) => args.lastOption
case _ => None
resultTypeOpt match
case Some(resultType) => getRefinedParams(resultType, level - 1)
case _ => defaultBaseParams
else
refinedType match
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
RefinedSymbol(sym, name, arg)
}
case _ => defaultBaseParams
// finds param refinements for lambda expressions
// val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
@tailrec
def refineParams(method: Tree, level: Int): List[ParamSymbol] =
method match
case Select(Apply(f, _), _) => refineParams(f, level + 1)
case Select(h, v) => getRefinedParams(h.symbol.info, level)
case _ => defaultBaseParams
refineParams(method, 0)
end baseParams

val args = ident
.map(i => baseArgs.filterNot(_ == i))
.getOrElse(baseArgs)
Expand All @@ -221,7 +255,7 @@ object NamedArgCompletions:

baseParams.filterNot(param =>
isNamed(param.name) ||
param.denot.is(
param.symbol.denot.is(
Flags.Synthetic
) // filter out synthesized param, like evidence
)
Expand All @@ -232,7 +266,7 @@ object NamedArgCompletions:
.map(_.name.toString)
.getOrElse("")
.replace(Cursor.value, "")
val params: List[Symbol] =
val params: List[ParamSymbol] =
allParams
.filter(param => param.name.startsWith(prefix))
.distinctBy(sym => (sym.name, sym.info))
Expand All @@ -249,7 +283,7 @@ object NamedArgCompletions:
.filter(name => name != "Nil" && name != "None")
.sorted

def findDefaultValue(param: Symbol): String =
def findDefaultValue(param: ParamSymbol): String =
val matchingType = matchingTypesInScope(param.info)
if matchingType.size == 1 then s":${matchingType.head}"
else if matchingType.size > 1 then s"|???,${matchingType.mkString(",")}|"
Expand All @@ -260,12 +294,12 @@ object NamedArgCompletions:
def shouldShow =
allParams.exists(param => param.name.startsWith(prefix))
def isExplicitlyCalled = suffix.startsWith(prefix)
def hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1
def hasParamsToFill = allParams.count(!_.symbol.is(Flags.HasDefault)) > 1
if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill
then
val editText = allParams.zipWithIndex
.collect {
case (param, index) if !param.is(Flags.HasDefault) =>
case (param, index) if !param.symbol.is(Flags.HasDefault) =>
s"${param.nameBackticked.replace("$", "$$")} = $${${index + 1}${findDefaultValue(param)}}"
}
.mkString(", ")
Expand Down Expand Up @@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
case _ => t

end FuzzyArgMatcher

sealed trait ParamSymbol:
def name: Name
def info: Type
def symbol: Symbol
def nameBackticked(using Context) = name.decoded.backticked

case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol:
def name: Name = symbol.name
def info: Type = symbol.info

case class RefinedSymbol(symbol: Symbol, name: Name, info: Type)
extends ParamSymbol
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,55 @@ class CompletionArgSuite extends BaseCompletionSuite:
|""".stripMargin,
topLines = Some(1),
)


@Test def `lambda` =
check(
"""|val hello: (x: Int) => Unit = x => println(x)
|val k = hello(@@)
|""".stripMargin,
"""|x = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda2` =
check(
"""|object O:
| val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
|val k = O.hello(x = 1, @@)
|""".stripMargin,
"""|y = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda3` =
check(
"""|val hello: (x: Int) => (j: Int) => Unit = x => j => println(x)
|val k = hello(@@)
|""".stripMargin,
"""|x = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda4` =
check(
"""|val hello: (x: Int) => (j: Int) => (str: String) => Unit = x => j => str => println(str)
|val k = hello(x = 1)(2)(@@)
|""".stripMargin,
"""|str = : String
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda5` =
check(
"""|val hello: (x: Int) => Int => (str: String) => Unit = x => j => str => println(str)
|val k = hello(x = 1)(2)(@@)
|""".stripMargin,
"""|str = : String
|""".stripMargin,
topLines = Some(1),
)