Skip to content

Commit

Permalink
bugfix: suggest correct arg name completions for lambda expressions
Browse files Browse the repository at this point in the history
[Cherry-picked 77dcdb7]
  • Loading branch information
kasiaMarek authored and Kordyjan committed Dec 7, 2023
1 parent 43d421a commit 9cbc402
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ object CompletionValue:
description
override def insertMode: Option[InsertTextMode] = Some(InsertTextMode.AsIs)

def namedArg(label: String, sym: Symbol)(using
def namedArg(label: String, sym: ParamSymbol)(using
Context
): CompletionValue =
NamedArg(label, sym.info.widenTermRefExpr, sym)
NamedArg(label, sym.info.widenTermRefExpr, sym.symbol)

def keyword(label: String, insertText: String): CompletionValue =
Keyword(label, Some(insertText))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ import dotty.tools.dotc.core.Symbols
import dotty.tools.dotc.core.Symbols.Symbol
import dotty.tools.dotc.core.Types.AndType
import dotty.tools.dotc.core.Types.AppliedType
import dotty.tools.dotc.core.Types.MethodType
import dotty.tools.dotc.core.Types.OrType
import dotty.tools.dotc.core.Types.RefinedType
import dotty.tools.dotc.core.Types.TermRef
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.core.Types.TypeBounds
import dotty.tools.dotc.core.Types.WildcardType
import dotty.tools.dotc.util.SourcePosition
import dotty.tools.pc.IndexedContext
import dotty.tools.pc.utils.MtagsEnrichments.*
import scala.annotation.tailrec

object NamedArgCompletions:

Expand Down Expand Up @@ -195,9 +198,40 @@ object NamedArgCompletions:
// def curry(x: Int)(apple: String, banana: String) = ???
// curry(1)(apple = "test", b@@)
// ```
val (baseParams, baseArgs) =
val (baseParams0, baseArgs) =
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))

val baseParams: List[ParamSymbol] =
def defaultBaseParams = baseParams0.map(JustSymbol(_))
@tailrec
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
if level > 0 then
val resultTypeOpt =
refinedType match
case RefinedType(AppliedType(_, args), _, _) => args.lastOption
case AppliedType(_, args) => args.lastOption
case _ => None
resultTypeOpt match
case Some(resultType) => getRefinedParams(resultType, level - 1)
case _ => defaultBaseParams
else
refinedType match
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
RefinedSymbol(sym, name, arg)
}
case _ => defaultBaseParams
// finds param refinements for lambda expressions
// val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
@tailrec
def refineParams(method: Tree, level: Int): List[ParamSymbol] =
method match
case Select(Apply(f, _), _) => refineParams(f, level + 1)
case Select(h, v) => getRefinedParams(h.symbol.info, level)
case _ => defaultBaseParams
refineParams(method, 0)
end baseParams

val args = ident
.map(i => baseArgs.filterNot(_ == i))
.getOrElse(baseArgs)
Expand All @@ -221,7 +255,7 @@ object NamedArgCompletions:

baseParams.filterNot(param =>
isNamed(param.name) ||
param.denot.is(
param.symbol.denot.is(
Flags.Synthetic
) // filter out synthesized param, like evidence
)
Expand All @@ -232,7 +266,7 @@ object NamedArgCompletions:
.map(_.name.toString)
.getOrElse("")
.replace(Cursor.value, "")
val params: List[Symbol] =
val params: List[ParamSymbol] =
allParams
.filter(param => param.name.startsWith(prefix))
.distinctBy(sym => (sym.name, sym.info))
Expand All @@ -249,7 +283,7 @@ object NamedArgCompletions:
.filter(name => name != "Nil" && name != "None")
.sorted

def findDefaultValue(param: Symbol): String =
def findDefaultValue(param: ParamSymbol): String =
val matchingType = matchingTypesInScope(param.info)
if matchingType.size == 1 then s":${matchingType.head}"
else if matchingType.size > 1 then s"|???,${matchingType.mkString(",")}|"
Expand All @@ -260,12 +294,12 @@ object NamedArgCompletions:
def shouldShow =
allParams.exists(param => param.name.startsWith(prefix))
def isExplicitlyCalled = suffix.startsWith(prefix)
def hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1
def hasParamsToFill = allParams.count(!_.symbol.is(Flags.HasDefault)) > 1
if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill
then
val editText = allParams.zipWithIndex
.collect {
case (param, index) if !param.is(Flags.HasDefault) =>
case (param, index) if !param.symbol.is(Flags.HasDefault) =>
s"${param.nameBackticked.replace("$", "$$")} = $${${index + 1}${findDefaultValue(param)}}"
}
.mkString(", ")
Expand Down Expand Up @@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
case _ => t

end FuzzyArgMatcher

sealed trait ParamSymbol:
def name: Name
def info: Type
def symbol: Symbol
def nameBackticked(using Context) = name.decoded.backticked

case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol:
def name: Name = symbol.name
def info: Type = symbol.info

case class RefinedSymbol(symbol: Symbol, name: Name, info: Type)
extends ParamSymbol
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,55 @@ class CompletionArgSuite extends BaseCompletionSuite:
|""".stripMargin,
topLines = Some(1),
)


@Test def `lambda` =
check(
"""|val hello: (x: Int) => Unit = x => println(x)
|val k = hello(@@)
|""".stripMargin,
"""|x = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda2` =
check(
"""|object O:
| val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
|val k = O.hello(x = 1, @@)
|""".stripMargin,
"""|y = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda3` =
check(
"""|val hello: (x: Int) => (j: Int) => Unit = x => j => println(x)
|val k = hello(@@)
|""".stripMargin,
"""|x = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda4` =
check(
"""|val hello: (x: Int) => (j: Int) => (str: String) => Unit = x => j => str => println(str)
|val k = hello(x = 1)(2)(@@)
|""".stripMargin,
"""|str = : String
|""".stripMargin,
topLines = Some(1),
)

@Test def `lambda5` =
check(
"""|val hello: (x: Int) => Int => (str: String) => Unit = x => j => str => println(str)
|val k = hello(x = 1)(2)(@@)
|""".stripMargin,
"""|str = : String
|""".stripMargin,
topLines = Some(1),
)

0 comments on commit 9cbc402

Please sign in to comment.