From 2ac7c1c560e358b3fa2cbe855e2a12b81e0526af Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 15 Nov 2023 22:36:06 +0000 Subject: [PATCH] Fixup and finish List optimisation --- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../tools/dotc/transform/ArrayApply.scala | 23 +++++-- .../tools/backend/jvm/ArrayApplyOptTest.scala | 65 +++++++++++++++++-- tests/run/list-apply-eval.scala | 14 +++- 4 files changed, 91 insertions(+), 12 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index d147e54fd005..928298077ee9 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -521,6 +521,7 @@ class Definitions { def ListType: TypeRef = ListClass.typeRef @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List) @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") def NilType: TermRef = NilModule.termRef @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index af17acdf3cbd..264e34da0e46 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -34,14 +34,15 @@ class ArrayApply extends MiniPhase { case _ => tree - else if isListOrSeqModuleApply(tree.symbol) then + else if isSeqApply(tree) then tree.args match // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && rest.elems.lengthIs < transformListApplyLimit => - rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + val consed = rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + consed.cast(tree.tpe) case _ => tree @@ -52,8 +53,22 @@ class ArrayApply extends MiniPhase { sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) - private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean = - sym == defn.ListModule_apply || sym == defn.SeqModule_apply + private def isListApply(tree: Tree)(using Context): Boolean = + (tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.ListModule + || sym == defn.ListModuleAlias + case _ => false + + private def isSeqApply(tree: Tree)(using Context): Boolean = + isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.SeqModule + || sym == defn.SeqModuleAlias + || sym == defn.CollectionSeqType.symbol.companionModule + case _ => false /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index 94f7682cda31..c39ad2602d4b 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -1,4 +1,5 @@ -package dotty.tools.backend.jvm +package dotty.tools +package backend.jvm import org.junit.Test import org.junit.Assert._ @@ -161,26 +162,76 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } @Test def testListApplyAvoidsIntermediateArray = { - val source = - """ + checkApplyAvoidsIntermediateArray("List"): + """import scala.collection.immutable.{ ::, Nil } |class Foo { | def meth1: List[String] = List("1", "2", "3") - | def meth2: List[String] = - | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + | def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray = { + checkApplyAvoidsIntermediateArray("Seq"): + """import scala.collection.immutable.{ ::, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) |} """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray2 = { + checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"): + """import scala.collection.immutable.{ ::, Seq, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray3 = { + checkApplyAvoidsIntermediateArray("scala.collection.Seq"): + """import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + def checkApplyAvoidsIntermediateArray(name: String)(source: String) = { checkBCode(source) { dir => val clsIn = dir.lookupName("Foo.class", directory = false).input val clsNode = loadClassNode(clsIn) val meth1 = getMethod(clsNode, "meth1") val meth2 = getMethod(clsNode, "meth2") - val instructions1 = instructionsFromMethod(meth1) + val instructions1 = instructionsFromMethod(meth1) match + case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => + instr :+ ret + case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => + // List.apply[?A] doesn't, strictly, return List[?A], + // because it cascades to its definition on IterableFactory + // where it returns CC[A]. The erasure of that is Object, + // which is why Erasure's Typer adds a cast to compensate. + // If we drop that cast while optimising (because using + // the constructor for :: doesn't require the cast like + // List.apply did) then then cons construction chain will + // be typed as ::. + // Unfortunately the LUB of :: and Nil.type is Product + // instead of List, so a cast remains necessary, + // across whatever causes the lub, like `if` or `try` branches. + // Therefore if we dropping the cast may cause a needed cast + // to be necessary, we shouldn't drop the cast, + // which was only motivated by the assert here. + instr :+ ret + case instr => instr val instructions2 = instructionsFromMethod(meth2) assert(instructions1 == instructions2, - "the List.apply method " + + s"the $name.apply method\n" + diffInstructions(instructions1, instructions2)) } } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala index 4e25444689cc..afc594e28101 100644 --- a/tests/run/list-apply-eval.scala +++ b/tests/run/list-apply-eval.scala @@ -6,7 +6,7 @@ object Test: counter += 1 counter.toString - def main(args: Array[String]): Unit = + def main(args: Array[String]): Unit = //List.apply is subject to an optimisation in cleanup //ensure that the arguments are evaluated in the currect order // Rewritten to: @@ -19,3 +19,15 @@ object Test: val emptyList = List[Int]() assert(emptyList == Nil) + + // just assert it doesn't throw CCE to List + val queue = scala.collection.mutable.Queue[String]() + + // test for the cast instruction described in checkApplyAvoidsIntermediateArray + def lub(b: Boolean): List[(String, String)] = + if b then List(("foo", "bar")) else Nil + + // from minimising CI failure in oslib + // again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce) + def lub2(b: Boolean): Unit = + Seq(1) ++ (if (b) Seq(2) else Nil)