From 9a8fd46fa2f6a0fb6c2e017066641a940b578995 Mon Sep 17 00:00:00 2001 From: Martin Kucera <3159068+KuceraMartin@users.noreply.github.com> Date: Tue, 28 Mar 2023 19:08:58 +0200 Subject: [PATCH 1/7] List(...) optimization to avoid intermediate array (closes https://github.com/lampepfl/dotty/issues/17035) [Cherry-picked aaf9ec7910e787ac06bedb4da482a68ec60c4826] --- .../dotty/tools/dotc/core/Definitions.scala | 20 +++++++------- .../tools/dotc/transform/ArrayApply.scala | 27 +++++++++++++++++-- .../tools/backend/jvm/ArrayApplyOptTest.scala | 25 +++++++++++++++++ tests/run/list-apply-eval.scala | 21 +++++++++++++++ 4 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 tests/run/list-apply-eval.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 2084cd4b04b1..6296a77f3bc3 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -513,14 +513,15 @@ class Definitions { methodNames.map(getWrapVarargsArrayModule.requiredMethod(_)) }) - @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") - def ListType: TypeRef = ListClass.typeRef - @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") - @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") - def NilType: TermRef = NilModule.termRef - @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") - def ConsType: TypeRef = ConsClass.typeRef - @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") + @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") + def ListType: TypeRef = ListClass.typeRef + @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") + @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") + def NilType: TermRef = NilModule.termRef + @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") + def ConsType: TypeRef = ConsClass.typeRef + @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") @tu lazy val SingletonClass: ClassSymbol = // needed as a synthetic class because Scala 2.x refers to it in classfiles @@ -539,7 +540,8 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply) @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 6ece8ad63808..3446931cad4f 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -22,9 +22,18 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description + private var transformListApplyLimit = 8 + + private def reducingTransformListApply[A](depth: Int)(body: => A): A = { + val saved = transformListApplyLimit + transformListApplyLimit -= depth + try body + finally transformListApplyLimit = saved + } + override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = if isArrayModuleApply(tree.symbol) then - tree.args match { + tree.args match case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => seqLit @@ -35,7 +44,18 @@ class ArrayApply extends MiniPhase { case _ => tree - } + + else if isListOrSeqModuleApply(tree.symbol) then + tree.args match + // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil + if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && + rest.elems.lengthIs < transformListApplyLimit => + rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) => + tpd.New(defn.ConsType, List(elem, acc)) + + case _ => + tree else tree @@ -43,6 +63,9 @@ class ArrayApply extends MiniPhase { sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) + private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean = + sym == defn.ListModule_apply || sym == defn.SeqModule_apply + /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` * - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ`` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index e7cd20ba98b2..a2d37b8399e5 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -160,4 +160,29 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } } + @Test def testListApplyAvoidsIntermediateArray = { + val source = + """ + |class Foo { + | def meth1: List[String] = List("1", "2", "3") + | def meth2: List[String] = + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "the List.apply method " + + diffInstructions(instructions1, instructions2)) + } + } + } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala new file mode 100644 index 000000000000..4e25444689cc --- /dev/null +++ b/tests/run/list-apply-eval.scala @@ -0,0 +1,21 @@ +object Test: + + var counter = 0 + + def next = + counter += 1 + counter.toString + + def main(args: Array[String]): Unit = + //List.apply is subject to an optimisation in cleanup + //ensure that the arguments are evaluated in the currect order + // Rewritten to: + // val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil))); + val myList = List(next, next, next) + assert(myList == List("1", "2", "3"), myList) + + val mySeq = Seq(next, next, next) + assert(mySeq == Seq("4", "5", "6"), mySeq) + + val emptyList = List[Int]() + assert(emptyList == Nil) From a8ccfd63a4b1a0d2271cded0ad6ba8c68e85653b Mon Sep 17 00:00:00 2001 From: Decel Date: Thu, 26 Oct 2023 14:04:28 +0200 Subject: [PATCH 2/7] Introduce Boxing for Singletons [Cherry-picked ca7bd7d22630d299a4cab9ad147a6c5ab033cd0a] --- compiler/src/dotty/tools/dotc/transform/ArrayApply.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 3446931cad4f..f80645d1e065 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -52,7 +52,7 @@ class ArrayApply extends MiniPhase { if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && rest.elems.lengthIs < transformListApplyLimit => rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) => - tpd.New(defn.ConsType, List(elem, acc)) + tpd.New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) case _ => tree From a3de566c6777ef666f08073430d4682ae8be8016 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 15 Nov 2023 22:35:27 +0000 Subject: [PATCH 3/7] Initial pre-change cleanups [Cherry-picked 90aea07f6fb298a81f357f1ce54a54dab5b9cd8d] --- .../dotty/tools/dotc/core/Definitions.scala | 9 +++-- .../tools/dotc/transform/ArrayApply.scala | 37 +++++++------------ 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6296a77f3bc3..bd58148da2b4 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -531,8 +531,11 @@ class Definitions { List(AnyType), EmptyScope) @tu lazy val SingletonType: TypeRef = SingletonClass.typeRef - @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") - @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") + @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply) + def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq) def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass @tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply) @tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head) @@ -540,8 +543,6 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq") - @tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply) @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index f80645d1e065..af17acdf3cbd 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -1,15 +1,11 @@ -package dotty.tools.dotc +package dotty.tools +package dotc package transform -import core.* +import ast.tpd +import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.* +import reporting.trace import MegaPhase.* -import Contexts.* -import Symbols.* -import Flags.* -import StdNames.* -import dotty.tools.dotc.ast.tpd - - /** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode. * @@ -22,25 +18,18 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description - private var transformListApplyLimit = 8 - - private def reducingTransformListApply[A](depth: Int)(body: => A): A = { - val saved = transformListApplyLimit - transformListApplyLimit -= depth - try body - finally transformListApplyLimit = saved - } + private val transformListApplyLimit = 8 - override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = + override def transformApply(tree: Apply)(using Context): Tree = if isArrayModuleApply(tree.symbol) then tree.args match - case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil + case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => seqLit - case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil + case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) => - tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) + JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) case _ => tree @@ -48,11 +37,11 @@ class ArrayApply extends MiniPhase { else if isListOrSeqModuleApply(tree.symbol) then tree.args match // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types - case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && rest.elems.lengthIs < transformListApplyLimit => - rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) => - tpd.New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) case _ => tree From f7aaa53fc66720d5f32cc07bbd556662f9032231 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 15 Nov 2023 22:36:06 +0000 Subject: [PATCH 4/7] Fixup and finish List optimisation [Cherry-picked 2ac7c1c560e358b3fa2cbe855e2a12b81e0526af] --- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../tools/dotc/transform/ArrayApply.scala | 23 +++++-- .../tools/backend/jvm/ArrayApplyOptTest.scala | 65 +++++++++++++++++-- tests/run/list-apply-eval.scala | 14 +++- 4 files changed, 91 insertions(+), 12 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index bd58148da2b4..c0435494497f 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -517,6 +517,7 @@ class Definitions { def ListType: TypeRef = ListClass.typeRef @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List) @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") def NilType: TermRef = NilModule.termRef @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index af17acdf3cbd..264e34da0e46 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -34,14 +34,15 @@ class ArrayApply extends MiniPhase { case _ => tree - else if isListOrSeqModuleApply(tree.symbol) then + else if isSeqApply(tree) then tree.args match // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && rest.elems.lengthIs < transformListApplyLimit => - rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + val consed = rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + consed.cast(tree.tpe) case _ => tree @@ -52,8 +53,22 @@ class ArrayApply extends MiniPhase { sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) - private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean = - sym == defn.ListModule_apply || sym == defn.SeqModule_apply + private def isListApply(tree: Tree)(using Context): Boolean = + (tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.ListModule + || sym == defn.ListModuleAlias + case _ => false + + private def isSeqApply(tree: Tree)(using Context): Boolean = + isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.SeqModule + || sym == defn.SeqModuleAlias + || sym == defn.CollectionSeqType.symbol.companionModule + case _ => false /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index a2d37b8399e5..ff6fae11dde9 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -1,4 +1,5 @@ -package dotty.tools.backend.jvm +package dotty.tools +package backend.jvm import org.junit.Test import org.junit.Assert._ @@ -161,26 +162,76 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } @Test def testListApplyAvoidsIntermediateArray = { - val source = - """ + checkApplyAvoidsIntermediateArray("List"): + """import scala.collection.immutable.{ ::, Nil } |class Foo { | def meth1: List[String] = List("1", "2", "3") - | def meth2: List[String] = - | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + | def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray = { + checkApplyAvoidsIntermediateArray("Seq"): + """import scala.collection.immutable.{ ::, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) |} """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray2 = { + checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"): + """import scala.collection.immutable.{ ::, Seq, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray3 = { + checkApplyAvoidsIntermediateArray("scala.collection.Seq"): + """import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + def checkApplyAvoidsIntermediateArray(name: String)(source: String) = { checkBCode(source) { dir => val clsIn = dir.lookupName("Foo.class", directory = false).input val clsNode = loadClassNode(clsIn) val meth1 = getMethod(clsNode, "meth1") val meth2 = getMethod(clsNode, "meth2") - val instructions1 = instructionsFromMethod(meth1) + val instructions1 = instructionsFromMethod(meth1) match + case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => + instr :+ ret + case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => + // List.apply[?A] doesn't, strictly, return List[?A], + // because it cascades to its definition on IterableFactory + // where it returns CC[A]. The erasure of that is Object, + // which is why Erasure's Typer adds a cast to compensate. + // If we drop that cast while optimising (because using + // the constructor for :: doesn't require the cast like + // List.apply did) then then cons construction chain will + // be typed as ::. + // Unfortunately the LUB of :: and Nil.type is Product + // instead of List, so a cast remains necessary, + // across whatever causes the lub, like `if` or `try` branches. + // Therefore if we dropping the cast may cause a needed cast + // to be necessary, we shouldn't drop the cast, + // which was only motivated by the assert here. + instr :+ ret + case instr => instr val instructions2 = instructionsFromMethod(meth2) assert(instructions1 == instructions2, - "the List.apply method " + + s"the $name.apply method\n" + diffInstructions(instructions1, instructions2)) } } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala index 4e25444689cc..afc594e28101 100644 --- a/tests/run/list-apply-eval.scala +++ b/tests/run/list-apply-eval.scala @@ -6,7 +6,7 @@ object Test: counter += 1 counter.toString - def main(args: Array[String]): Unit = + def main(args: Array[String]): Unit = //List.apply is subject to an optimisation in cleanup //ensure that the arguments are evaluated in the currect order // Rewritten to: @@ -19,3 +19,15 @@ object Test: val emptyList = List[Int]() assert(emptyList == Nil) + + // just assert it doesn't throw CCE to List + val queue = scala.collection.mutable.Queue[String]() + + // test for the cast instruction described in checkApplyAvoidsIntermediateArray + def lub(b: Boolean): List[(String, String)] = + if b then List(("foo", "bar")) else Nil + + // from minimising CI failure in oslib + // again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce) + def lub2(b: Boolean): Unit = + Seq(1) ++ (if (b) Seq(2) else Nil) From afd622ce2d137054667e4ded24dc87a601868f81 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Thu, 16 Nov 2023 14:12:08 +0000 Subject: [PATCH 5/7] Implement list optimisation limit [Cherry-picked ea1731a3b3812255fd60b3cf8325d013f019ee7b] --- .../tools/dotc/transform/ArrayApply.scala | 38 +++++++++---- tests/run/list-apply-eval.scala | 56 +++++++++++++++++++ 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 264e34da0e46..c8d7a6548870 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -5,6 +5,7 @@ package transform import ast.tpd import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.* import reporting.trace +import util.Property import MegaPhase.* /** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode. @@ -18,7 +19,16 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description - private val transformListApplyLimit = 8 + private val TransformListApplyBudgetKey = new Property.Key[Int] + private def transformListApplyBudget(using Context) = ctx.property(TransformListApplyBudgetKey).getOrElse(8) + + override def prepareForApply(tree: Apply)(using Context): Context = + if isSeqApply(tree) then + val args = seqApplyArgsOrNull(tree) + if args != null then + ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - args.elems.length) + else ctx + else ctx override def transformApply(tree: Apply)(using Context): Tree = if isArrayModuleApply(tree.symbol) then @@ -35,17 +45,12 @@ class ArrayApply extends MiniPhase { tree else if isSeqApply(tree) then - tree.args match - // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types - case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil - if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && - rest.elems.lengthIs < transformListApplyLimit => - val consed = rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => - New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) - consed.cast(tree.tpe) - - case _ => - tree + val args = seqApplyArgsOrNull(tree) + if args != null && (transformListApplyBudget > 0 || args.elems.isEmpty) then + val consed = args.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + consed.cast(tree.tpe) + else tree else tree @@ -70,6 +75,15 @@ class ArrayApply extends MiniPhase { || sym == defn.CollectionSeqType.symbol.companionModule case _ => false + private def seqApplyArgsOrNull(tree: Apply)(using Context): JavaSeqLiteral | Null = + // assumes isSeqApply(tree) + tree.args match + // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil + if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) => + rest + case _ => null + /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` * - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ`` diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala index afc594e28101..19ebaaa69812 100644 --- a/tests/run/list-apply-eval.scala +++ b/tests/run/list-apply-eval.scala @@ -31,3 +31,59 @@ object Test: // again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce) def lub2(b: Boolean): Unit = Seq(1) ++ (if (b) Seq(2) else Nil) + + // Examples of arity and nesting arity + // to find the thresholds and reproduce the behaviour of nsc + // tested manually, comparing -Xprint across compilers (ran out of time) + def examples(): Unit = + val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil + val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil + val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6")) + val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6")) + + val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array + val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array + val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7 + val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2 + + val max5 = + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + )))))))) // 7 cons + 1 nil + + val over5 = + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( List[Object]() + )))))))) // 7 cons + 1-sized array wrapping nil + + val max6 = + List[Object]( // ::( + "1", "2", List[Object]( // 1, ::(2, ::(::( + "3", "4", List[Object]( // 3, ::(4, ::(::( + List[Object]() // Nil, Nil + ) // ), Nil)) + ) // ), Nil)) + ) // ) + // 7 cons + 4 string heads + 4 nils for nested lists + + val max7 = + List[Object]( // ::( + "1", "2", List[Object]( // 1, ::(2, ::(::( + "3", "4", List[Object]( // 3, ::(4, ::(::( + "5" // 5, Nil + ) // ), Nil)) + ) // ), Nil)) + ) // ) + // 7 cons + 5 string heads + 3 nils for nested lists From 09a2c4d65574adc86da8d8ef55461d0c4e6d22d5 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Thu, 16 Nov 2023 15:10:23 +0000 Subject: [PATCH 6/7] Automate list optimisation threshhold cases [Cherry-picked 3b9f0c9052596110fce398f9f574c55f8ef765a7] --- .../tools/backend/jvm/ArrayApplyOptTest.scala | 118 ++++++++++++++---- tests/run/list-apply-eval.scala | 1 - 2 files changed, 96 insertions(+), 23 deletions(-) diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index ff6fae11dde9..37e7d5316f9d 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -201,34 +201,108 @@ class ArrayApplyOptTest extends DottyBytecodeTest { """.stripMargin } - def checkApplyAvoidsIntermediateArray(name: String)(source: String) = { + @Test def testListApplyAvoidsIntermediateArray_max1 = { + checkApplyAvoidsIntermediateArray_examples("max1"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7") + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::("7", Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max2 = { + checkApplyAvoidsIntermediateArray_examples("max2"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(Nil, Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max3 = { + checkApplyAvoidsIntermediateArray_examples("max3"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(new ::("6", Nil), Nil)))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max4 = { + checkApplyAvoidsIntermediateArray_examples("max4"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", List[Object]("5", "6")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::(new ::("5", new ::("6", Nil)), Nil))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over1 = { + checkApplyAvoidsIntermediateArray_examples("over1"): + """ def meth1: List[Object] = List("1", "2", "3", "4", "5", "6", "7", "8") + | def meth2: List[Object] = List(wrapRefArray(Array("1", "2", "3", "4", "5", "6", "7", "8"))*) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over2 = { + checkApplyAvoidsIntermediateArray_examples("over2"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) + | def meth2: List[Object] = List(wrapRefArray(Array[Object]("1", "2", "3", "4", "5", "6", "7", Nil))*) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over3 = { + checkApplyAvoidsIntermediateArray_examples("over3"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(List(wrapRefArray(Array[Object]("7"))*), Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over4 = { + checkApplyAvoidsIntermediateArray_examples("over4"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(List(wrapRefArray(Array[Object]("6", "7"))*), Nil)))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max5 = { + checkApplyAvoidsIntermediateArray_examples("max5"): + """ def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]()))))))) + | def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(Nil, Nil), Nil), Nil), Nil), Nil), Nil), Nil) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over5 = { + checkApplyAvoidsIntermediateArray_examples("over5"): + """ def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]())))))))) + | def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(List[Object](wrapRefArray(Array[Object](Nil))*), Nil), Nil), Nil), Nil), Nil), Nil), Nil) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max6 = { + checkApplyAvoidsIntermediateArray_examples("max6"): + """ def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object](List[Object]()))) + | def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::(Nil, Nil), Nil))), Nil))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over6 = { + checkApplyAvoidsIntermediateArray_examples("over6"): + """ def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object]("5"))) + | def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::("5", Nil), Nil))), Nil))) + """.stripMargin + } + + def checkApplyAvoidsIntermediateArray_examples(name: String)(body: String): Unit = { + checkApplyAvoidsIntermediateArray(s"List_$name"): + s"""import scala.collection.immutable.{ ::, Nil }, scala.runtime.ScalaRunTime.wrapRefArray + |class Foo { + |$body + |} + """.stripMargin + } + + def checkApplyAvoidsIntermediateArray(name: String)(source: String): Unit = { checkBCode(source) { dir => val clsIn = dir.lookupName("Foo.class", directory = false).input val clsNode = loadClassNode(clsIn) val meth1 = getMethod(clsNode, "meth1") val meth2 = getMethod(clsNode, "meth2") - val instructions1 = instructionsFromMethod(meth1) match - case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => - instr :+ ret - case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) => - // List.apply[?A] doesn't, strictly, return List[?A], - // because it cascades to its definition on IterableFactory - // where it returns CC[A]. The erasure of that is Object, - // which is why Erasure's Typer adds a cast to compensate. - // If we drop that cast while optimising (because using - // the constructor for :: doesn't require the cast like - // List.apply did) then then cons construction chain will - // be typed as ::. - // Unfortunately the LUB of :: and Nil.type is Product - // instead of List, so a cast remains necessary, - // across whatever causes the lub, like `if` or `try` branches. - // Therefore if we dropping the cast may cause a needed cast - // to be necessary, we shouldn't drop the cast, - // which was only motivated by the assert here. - instr :+ ret - case instr => instr - val instructions2 = instructionsFromMethod(meth2) + val instructions1 = instructionsFromMethod(meth1).filter { case TypeOp(CHECKCAST, _) => false case _ => true } + val instructions2 = instructionsFromMethod(meth2).filter { case TypeOp(CHECKCAST, _) => false case _ => true } assert(instructions1 == instructions2, s"the $name.apply method\n" + diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala index 19ebaaa69812..4cbba6d3e6c2 100644 --- a/tests/run/list-apply-eval.scala +++ b/tests/run/list-apply-eval.scala @@ -34,7 +34,6 @@ object Test: // Examples of arity and nesting arity // to find the thresholds and reproduce the behaviour of nsc - // tested manually, comparing -Xprint across compilers (ran out of time) def examples(): Unit = val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil From 063cdddf3339e9cd4843c74648141885e1bce06b Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Mon, 27 Nov 2023 16:52:56 +0000 Subject: [PATCH 7/7] Document 8 & use Option extractor over | Null [Cherry-picked bdb89d89649413aac85a51fd22830748bf9281b2] --- .../tools/dotc/transform/ArrayApply.scala | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index c8d7a6548870..98ca8f2e2b5b 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -20,15 +20,13 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description private val TransformListApplyBudgetKey = new Property.Key[Int] - private def transformListApplyBudget(using Context) = ctx.property(TransformListApplyBudgetKey).getOrElse(8) + private def transformListApplyBudget(using Context) = + ctx.property(TransformListApplyBudgetKey).getOrElse(8) // default is 8, as originally implemented in nsc - override def prepareForApply(tree: Apply)(using Context): Context = - if isSeqApply(tree) then - val args = seqApplyArgsOrNull(tree) - if args != null then - ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - args.elems.length) - else ctx - else ctx + override def prepareForApply(tree: Apply)(using Context): Context = tree match + case SeqApplyArgs(elems) => + ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - elems.length) + case _ => ctx override def transformApply(tree: Apply)(using Context): Tree = if isArrayModuleApply(tree.symbol) then @@ -44,15 +42,12 @@ class ArrayApply extends MiniPhase { case _ => tree - else if isSeqApply(tree) then - val args = seqApplyArgsOrNull(tree) - if args != null && (transformListApplyBudget > 0 || args.elems.isEmpty) then - val consed = args.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + else tree match + case SeqApplyArgs(elems) if transformListApplyBudget > 0 || elems.isEmpty => + val consed = elems.foldRight(ref(defn.NilModule)): (elem, acc) => New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) consed.cast(tree.tpe) - else tree - - else tree + case _ => tree private def isArrayModuleApply(sym: Symbol)(using Context): Boolean = sym.name == nme.apply @@ -75,14 +70,17 @@ class ArrayApply extends MiniPhase { || sym == defn.CollectionSeqType.symbol.companionModule case _ => false - private def seqApplyArgsOrNull(tree: Apply)(using Context): JavaSeqLiteral | Null = - // assumes isSeqApply(tree) - tree.args match - // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types - case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil - if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) => - rest - case _ => null + private object SeqApplyArgs: + def unapply(tree: Apply)(using Context): Option[List[Tree]] = + if isSeqApply(tree) then + tree.args match + // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil + if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) => + Some(rest.elems) + case _ => None + else None + /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])`