From 3b91875fe3f23a7d03ffff3be25e754e3a24f4f4 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 11 Feb 2024 12:58:04 +0800 Subject: [PATCH] more tests pass --- unroll/plugin/src-3/UnrollPhaseScala3.scala | 47 ++++++++++----------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/unroll/plugin/src-3/UnrollPhaseScala3.scala b/unroll/plugin/src-3/UnrollPhaseScala3.scala index 77083b0..81499e3 100644 --- a/unroll/plugin/src-3/UnrollPhaseScala3.scala +++ b/unroll/plugin/src-3/UnrollPhaseScala3.scala @@ -115,11 +115,6 @@ class UnrollPhaseScala3() extends PluginPhase { newDefDef } - def isCaseFromProduct(t: Tree)(using Context) = t match{ - case defdef: DefDef => defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass) - case _ => false - } - def generateFromProduct(startParamIndex: Int, paramCount: Int, defdef: DefDef)(using Context) = { cpy.DefDef(defdef)( name = defdef.name, @@ -152,7 +147,7 @@ class UnrollPhaseScala3() extends PluginPhase { ).setDefTree } - def generateSyntheticDefs(tree: Tree)(using Context): Seq[Tree] = tree match{ + def generateSyntheticDefs(tree: Tree)(using Context): (Option[Symbol], Seq[Tree]) = tree match{ case defdef: DefDef if defdef.paramss.nonEmpty => import dotty.tools.dotc.core.NameOps.isConstructorName @@ -162,7 +157,7 @@ class UnrollPhaseScala3() extends PluginPhase { val isCaseApply = defdef.name.toString == "apply" && defdef.symbol.owner.companionClass.is(CaseClass) - val isCaseFromProduct = this.isCaseFromProduct(defdef) + val isCaseFromProduct = defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass) val annotated = if (isCaseCopy) defdef.symbol.owner.primaryConstructor @@ -172,44 +167,48 @@ class UnrollPhaseScala3() extends PluginPhase { val firstValueParamClauseIndex = annotated.paramSymss.indexWhere(!_.headOption.exists(_.isType)) - if (firstValueParamClauseIndex == -1) Nil + if (firstValueParamClauseIndex == -1) (None, Nil) else { val paramCount = annotated.paramSymss(firstValueParamClauseIndex).size annotated .paramSymss(firstValueParamClauseIndex) .indexWhere(_.annotations.exists(_.symbol.fullName.toString == "unroll.Unroll")) match{ - case -1 => Nil + case -1 => (None, Nil) case startParamIndex => if (isCaseFromProduct) { - Seq(generateFromProduct(startParamIndex, paramCount, defdef)) + (Some(defdef.symbol), Seq(generateFromProduct(startParamIndex, paramCount, defdef))) } else { - for (paramIndex <- Range(startParamIndex, paramCount)) yield { - generateSingleForwarder( - defdef, - defdef.symbol.info, - defdef.paramss, - firstValueParamClauseIndex, - paramIndex, - isCaseApply - ) - } + ( + None, + for (paramIndex <- Range(startParamIndex, paramCount)) yield { + generateSingleForwarder( + defdef, + defdef.symbol.info, + defdef.paramss, + firstValueParamClauseIndex, + paramIndex, + isCaseApply + ) + } + ) } } } - case _ => Nil + case _ => (None, Nil) } override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = { + val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip + val (None, generatedConstr) = generateSyntheticDefs(tmpl.constr) + val removed = removed0.flatten super.transformTemplate( cpy.Template(tmpl)( tmpl.constr, tmpl.parents, tmpl.derived, tmpl.self, - tmpl.body.filter(!this.isCaseFromProduct(_)) ++ - tmpl.body.flatMap(generateSyntheticDefs) ++ - generateSyntheticDefs(tmpl.constr) + tmpl.body.filter(t => !removed.contains(t.symbol)) ++ generatedDefs.flatten ++ generatedConstr ) ) }