From 9cbc402a8501a4d41b024d0a6eda4e1615fa3a75 Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Thu, 10 Aug 2023 09:47:49 +0200 Subject: [PATCH] bugfix: suggest correct arg name completions for lambda expressions [Cherry-picked 77dcdb76d64fef320fcfe713d8806a4f8fc339bf] --- .../pc/completions/CompletionValue.scala | 4 +- .../pc/completions/NamedArgCompletions.scala | 59 +++++++++++++++++-- .../tests/completion/CompletionArgSuite.scala | 52 ++++++++++++++++ 3 files changed, 107 insertions(+), 8 deletions(-) diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/CompletionValue.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/CompletionValue.scala index fbd2ed432d63..ee9731fa5d34 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/CompletionValue.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/CompletionValue.scala @@ -247,10 +247,10 @@ object CompletionValue: description override def insertMode: Option[InsertTextMode] = Some(InsertTextMode.AsIs) - def namedArg(label: String, sym: Symbol)(using + def namedArg(label: String, sym: ParamSymbol)(using Context ): CompletionValue = - NamedArg(label, sym.info.widenTermRefExpr, sym) + NamedArg(label, sym.info.widenTermRefExpr, sym.symbol) def keyword(label: String, insertText: String): CompletionValue = Keyword(label, Some(insertText)) diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala index 3cd2ef79ccce..30346690bd18 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala @@ -16,7 +16,9 @@ import dotty.tools.dotc.core.Symbols import dotty.tools.dotc.core.Symbols.Symbol import dotty.tools.dotc.core.Types.AndType import dotty.tools.dotc.core.Types.AppliedType +import dotty.tools.dotc.core.Types.MethodType import dotty.tools.dotc.core.Types.OrType +import dotty.tools.dotc.core.Types.RefinedType import dotty.tools.dotc.core.Types.TermRef import dotty.tools.dotc.core.Types.Type import dotty.tools.dotc.core.Types.TypeBounds @@ -24,6 +26,7 @@ import dotty.tools.dotc.core.Types.WildcardType import dotty.tools.dotc.util.SourcePosition import dotty.tools.pc.IndexedContext import dotty.tools.pc.utils.MtagsEnrichments.* +import scala.annotation.tailrec object NamedArgCompletions: @@ -195,9 +198,40 @@ object NamedArgCompletions: // def curry(x: Int)(apple: String, banana: String) = ??? // curry(1)(apple = "test", b@@) // ``` - val (baseParams, baseArgs) = + val (baseParams0, baseArgs) = vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) + val baseParams: List[ParamSymbol] = + def defaultBaseParams = baseParams0.map(JustSymbol(_)) + @tailrec + def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] = + if level > 0 then + val resultTypeOpt = + refinedType match + case RefinedType(AppliedType(_, args), _, _) => args.lastOption + case AppliedType(_, args) => args.lastOption + case _ => None + resultTypeOpt match + case Some(resultType) => getRefinedParams(resultType, level - 1) + case _ => defaultBaseParams + else + refinedType match + case RefinedType(AppliedType(_, args), _, MethodType(ri)) => + baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) => + RefinedSymbol(sym, name, arg) + } + case _ => defaultBaseParams + // finds param refinements for lambda expressions + // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x) + @tailrec + def refineParams(method: Tree, level: Int): List[ParamSymbol] = + method match + case Select(Apply(f, _), _) => refineParams(f, level + 1) + case Select(h, v) => getRefinedParams(h.symbol.info, level) + case _ => defaultBaseParams + refineParams(method, 0) + end baseParams + val args = ident .map(i => baseArgs.filterNot(_ == i)) .getOrElse(baseArgs) @@ -221,7 +255,7 @@ object NamedArgCompletions: baseParams.filterNot(param => isNamed(param.name) || - param.denot.is( + param.symbol.denot.is( Flags.Synthetic ) // filter out synthesized param, like evidence ) @@ -232,7 +266,7 @@ object NamedArgCompletions: .map(_.name.toString) .getOrElse("") .replace(Cursor.value, "") - val params: List[Symbol] = + val params: List[ParamSymbol] = allParams .filter(param => param.name.startsWith(prefix)) .distinctBy(sym => (sym.name, sym.info)) @@ -249,7 +283,7 @@ object NamedArgCompletions: .filter(name => name != "Nil" && name != "None") .sorted - def findDefaultValue(param: Symbol): String = + def findDefaultValue(param: ParamSymbol): String = val matchingType = matchingTypesInScope(param.info) if matchingType.size == 1 then s":${matchingType.head}" else if matchingType.size > 1 then s"|???,${matchingType.mkString(",")}|" @@ -260,12 +294,12 @@ object NamedArgCompletions: def shouldShow = allParams.exists(param => param.name.startsWith(prefix)) def isExplicitlyCalled = suffix.startsWith(prefix) - def hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1 + def hasParamsToFill = allParams.count(!_.symbol.is(Flags.HasDefault)) > 1 if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill then val editText = allParams.zipWithIndex .collect { - case (param, index) if !param.is(Flags.HasDefault) => + case (param, index) if !param.symbol.is(Flags.HasDefault) => s"${param.nameBackticked.replace("$", "$$")} = $${${index + 1}${findDefaultValue(param)}}" } .mkString(", ") @@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context): case _ => t end FuzzyArgMatcher + +sealed trait ParamSymbol: + def name: Name + def info: Type + def symbol: Symbol + def nameBackticked(using Context) = name.decoded.backticked + +case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol: + def name: Name = symbol.name + def info: Type = symbol.info + +case class RefinedSymbol(symbol: Symbol, name: Name, info: Type) + extends ParamSymbol diff --git a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala index b376284aa6a6..6e91209d5788 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala @@ -881,3 +881,55 @@ class CompletionArgSuite extends BaseCompletionSuite: |""".stripMargin, topLines = Some(1), ) + + + @Test def `lambda` = + check( + """|val hello: (x: Int) => Unit = x => println(x) + |val k = hello(@@) + |""".stripMargin, + """|x = : Int + |""".stripMargin, + topLines = Some(1), + ) + + @Test def `lambda2` = + check( + """|object O: + | val hello: (x: Int, y: Int) => Unit = (x, _) => println(x) + |val k = O.hello(x = 1, @@) + |""".stripMargin, + """|y = : Int + |""".stripMargin, + topLines = Some(1), + ) + + @Test def `lambda3` = + check( + """|val hello: (x: Int) => (j: Int) => Unit = x => j => println(x) + |val k = hello(@@) + |""".stripMargin, + """|x = : Int + |""".stripMargin, + topLines = Some(1), + ) + + @Test def `lambda4` = + check( + """|val hello: (x: Int) => (j: Int) => (str: String) => Unit = x => j => str => println(str) + |val k = hello(x = 1)(2)(@@) + |""".stripMargin, + """|str = : String + |""".stripMargin, + topLines = Some(1), + ) + + @Test def `lambda5` = + check( + """|val hello: (x: Int) => Int => (str: String) => Unit = x => j => str => println(str) + |val k = hello(x = 1)(2)(@@) + |""".stripMargin, + """|str = : String + |""".stripMargin, + topLines = Some(1), + )