From 71b983b13ddfc6b14a24f7a30c3d9813e78d2fd3 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 20 Feb 2024 09:31:34 +0100 Subject: [PATCH] Lift all non trivial prefixes for default parameters Checking if the prefix is pure is not enough to know if we need to list the prefix. In the case of default parameters, the prefix tree might be used several times to compute the default values. This expression should only be computed once and therefore it should be lifted if there is some computation/allocation involved. Furthermore, if the prefix contains a local definition, it must be lifted to avoid duplicating the definition. A similar situation could happen with dependent default parameters. This currently works as expected. Fixes #15315 --- .../dotty/tools/dotc/typer/EtaExpansion.scala | 25 ++++++++- .../backend/jvm/DottyBytecodeTests.scala | 52 +++++++++++++++++++ tests/run/i15315.scala | 5 ++ 3 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 tests/run/i15315.scala diff --git a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala index 2c441c2f915e..b09580d51943 100644 --- a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala +++ b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala @@ -122,7 +122,10 @@ abstract class Lifter { case TypeApply(fn, targs) => cpy.TypeApply(tree)(liftApp(defs, fn), targs) case Select(pre, name) if isPureRef(tree) => - cpy.Select(tree)(liftPrefix(defs, pre), name) + val liftedPrefix = + if tree.symbol.is(HasDefaultParams) then liftPrefix(defs, pre) + else liftNonIdempotentPrefix(defs, pre) + cpy.Select(tree)(liftedPrefix, name) case Block(stats, expr) => liftApp(defs ++= stats, expr) case New(tpt) => @@ -138,8 +141,26 @@ abstract class Lifter { * * unless `pre` is idempotent. */ - def liftPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree = + def liftNonIdempotentPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree = if (isIdempotentExpr(tree)) tree else lift(defs, tree) + + /** Lift prefix `pre` of an application `pre.f(...)` to + * + * val x0 = pre + * x0.f(...) + * + * unless `pre` is idempotent reference, a `this` reference, a literal value, or a or the prefix of an `init` (`New` tree). + * + * Note that default arguments will refer to the prefix, we do not want + * to re-evaluate a complex expression each time we access a getter. + */ + def liftPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree = + tree match + case tree: Literal => tree + case tree: This => tree + case tree: New => tree // prefix of call + case tree: RefTree if isIdempotentExpr(tree) => tree + case _ => lift(defs, tree) } /** No lifting at all */ diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 94d42952a6eb..51390e35b527 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -1733,6 +1733,58 @@ class DottyBytecodeTests extends DottyBytecodeTest { assertSameCode(instructions, expected) } } + + @Test def newInPrefixesOfDefaultParam = { + val source = + s"""class A: + | def f(x: Int = 1): Int = x + | + |class Test: + | def meth1() = (new A).f() + | def meth2() = { val a = new A; a.f() } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "`assert` was not properly inlined in `meth1`\n" + + diffInstructions(instructions1, instructions2)) + } + } + + @Test def newInDependentOfDefaultParam = { + val source = + s"""class A: + | def i: Int = 1 + | + |class Test: + | def f(a: A)(x: Int = a.i): Int = x + | def meth1() = f(new A)() + | def meth2() = { val a = new A; f(a)() } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "`assert` was not properly inlined in `meth1`\n" + + diffInstructions(instructions1, instructions2)) + } + } + } object invocationReceiversTestCode { diff --git a/tests/run/i15315.scala b/tests/run/i15315.scala new file mode 100644 index 000000000000..d9cab7b87b81 --- /dev/null +++ b/tests/run/i15315.scala @@ -0,0 +1,5 @@ +class A: + def f(x: Int = 1): Int = x + +@main def Test() = + (new A{}).f()