Skip to content

Commit

Permalink
Lift all non trivial prefixes for default parameters (#19739)
Browse files Browse the repository at this point in the history
Checking if the prefix is pure is not enough to know if we need to list
the prefix. In the case of default parameters, the prefix tree might be
used several times to compute the default values. This expression should
only be computed once and therefore it should be lifted if there is some
computation/allocation involved. Furthermore, if the prefix contains a
local definition, it must be lifted to avoid duplicating the definition.

A similar situation could happen with dependent default parameters. This
currently works as expected.

Fixes #15315
  • Loading branch information
nicolasstucki authored Feb 29, 2024
2 parents 55df3f3 + 71b983b commit 1191671
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
25 changes: 23 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ abstract class Lifter {
case TypeApply(fn, targs) =>
cpy.TypeApply(tree)(liftApp(defs, fn), targs)
case Select(pre, name) if isPureRef(tree) =>
cpy.Select(tree)(liftPrefix(defs, pre), name)
val liftedPrefix =
if tree.symbol.is(HasDefaultParams) then liftPrefix(defs, pre)
else liftNonIdempotentPrefix(defs, pre)
cpy.Select(tree)(liftedPrefix, name)
case Block(stats, expr) =>
liftApp(defs ++= stats, expr)
case New(tpt) =>
Expand All @@ -138,8 +141,26 @@ abstract class Lifter {
*
* unless `pre` is idempotent.
*/
def liftPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree =
def liftNonIdempotentPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree =
if (isIdempotentExpr(tree)) tree else lift(defs, tree)

/** Lift prefix `pre` of an application `pre.f(...)` to
*
* val x0 = pre
* x0.f(...)
*
* unless `pre` is idempotent reference, a `this` reference, a literal value, or a or the prefix of an `init` (`New` tree).
*
* Note that default arguments will refer to the prefix, we do not want
* to re-evaluate a complex expression each time we access a getter.
*/
def liftPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree =
tree match
case tree: Literal => tree
case tree: This => tree
case tree: New => tree // prefix of <init> call
case tree: RefTree if isIdempotentExpr(tree) => tree
case _ => lift(defs, tree)
}

/** No lifting at all */
Expand Down
52 changes: 52 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,58 @@ class DottyBytecodeTests extends DottyBytecodeTest {
assertSameCode(instructions, expected)
}
}

@Test def newInPrefixesOfDefaultParam = {
val source =
s"""class A:
| def f(x: Int = 1): Int = x
|
|class Test:
| def meth1() = (new A).f()
| def meth2() = { val a = new A; a.f() }
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1)
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
"`assert` was not properly inlined in `meth1`\n" +
diffInstructions(instructions1, instructions2))
}
}

@Test def newInDependentOfDefaultParam = {
val source =
s"""class A:
| def i: Int = 1
|
|class Test:
| def f(a: A)(x: Int = a.i): Int = x
| def meth1() = f(new A)()
| def meth2() = { val a = new A; f(a)() }
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1)
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
"`assert` was not properly inlined in `meth1`\n" +
diffInstructions(instructions1, instructions2))
}
}

}

object invocationReceiversTestCode {
Expand Down
5 changes: 5 additions & 0 deletions tests/run/i15315.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class A:
def f(x: Int = 1): Int = x

@main def Test() =
(new A{}).f()

0 comments on commit 1191671

Please sign in to comment.