diff --git a/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala b/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala index 92ac78f93..c83d11ed4 100644 --- a/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala +++ b/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala @@ -43,7 +43,7 @@ trait MacroImplicits extends MacrosCommon { this: upickle.core.Types => def macroW[T]: Writer[T] = macro MacroImplicits.applyW[T] def macroRW[T]: ReadWriter[T] = macro MacroImplicits.applyRW[ReadWriter[T]] - def macroR0[T, M[_]]: Reader[T] = macro internal.Macros.macroRImpl[T, M] - def macroW0[T, M[_]]: Writer[T] = macro internal.Macros.macroWImpl[T, M] + def macroR0[T, M[_]]: Reader[T] = macro internal.Macros2.macroRImpl[T, M] + def macroW0[T, M[_]]: Writer[T] = macro internal.Macros2.macroWImpl[T, M] } diff --git a/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala b/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala index 52350838f..3236716c3 100644 --- a/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala +++ b/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala @@ -10,6 +10,11 @@ import upickle.implicits.{MacrosCommon, key} import language.higherKinds import language.existentials +/** + * This file is deprecated and remained here for binary compatibility. + * Please use upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala instead. + */ + /** * Implementation of macros used by uPickle to serialize and deserialize * case classes automatically. You probably shouldn't need to use these @@ -177,7 +182,7 @@ object Macros { t.substituteTypes(typeParams, concrete) } else { - val TypeRef(pref, sym, _) = typeOf[Seq[Int]] + val TypeRef(pref, sym, args) = typeOf[Seq[Int]] import compat._ TypeRef(pref, sym, t.asInstanceOf[TypeRef].args) } diff --git a/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala b/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala new file mode 100644 index 000000000..0682adb92 --- /dev/null +++ b/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala @@ -0,0 +1,577 @@ +package upickle.implicits.internal + +import scala.annotation.{nowarn, StaticAnnotation} +import scala.language.experimental.macros +import compat._ + +import acyclic.file +import upickle.core.Annotator +import upickle.implicits.{MacrosCommon, flatten, key} +import language.higherKinds +import language.existentials + +/** + * Implementation of macros used by uPickle to serialize and deserialize + * case classes automatically. You probably shouldn't need to use these + * directly, since they are called implicitly when trying to read/write + * types you don't have a Reader/Writer in scope for. + */ +@nowarn("cat=deprecation") +object Macros2 { + + trait DeriveDefaults[M[_]] { + val c: scala.reflect.macros.blackbox.Context + private def fail(tpe: c.Type, s: String) = c.abort(c.enclosingPosition, s) + + import c.universe._ + private def companionTree(tpe: c.Type): Tree = { + val companionSymbol = tpe.typeSymbol.companionSymbol + + if (companionSymbol == NoSymbol && tpe.typeSymbol.isClass) { + val clsSymbol = tpe.typeSymbol.asClass + val msg = "[error] The companion symbol could not be determined for " + + s"[[${clsSymbol.name}]]. This may be due to a bug in scalac (SI-7567) " + + "that arises when a case class within a function is upickle. As a " + + "workaround, move the declaration to the module-level." + fail(tpe, msg) + } else { + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val pre = tpe.asInstanceOf[symTab.Type].prefix.asInstanceOf[Type] + c.universe.treeBuild.mkAttributedRef(pre, companionSymbol) + } + + } + + /** + * If a super-type is generic, find all the subtypes, but at the same time + * fill in all the generic type parameters that are based on the super-type's + * concrete type + */ + private def fleshedOutSubtypes(tpe: Type) = { + for{ + subtypeSym <- tpe.typeSymbol.asClass.knownDirectSubclasses.filter(!_.toString.contains("")) + if subtypeSym.isType + st = subtypeSym.asType.toType + baseClsArgs = st.baseType(tpe.typeSymbol).asInstanceOf[TypeRef].args + } yield { + tpe match{ + case ExistentialType(_, TypeRef(pre, sym, args)) => + st.substituteTypes(baseClsArgs.map(_.typeSymbol), args) + case ExistentialType(_, _) => st + case TypeRef(pre, sym, args) => + st.substituteTypes(baseClsArgs.map(_.typeSymbol), args) + } + } + } + + private def deriveObject(tpe: c.Type) = { + val mod = tpe.typeSymbol.asClass.module + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val pre = tpe.asInstanceOf[symTab.Type].prefix.asInstanceOf[Type] + val mod2 = c.universe.treeBuild.mkAttributedRef(pre, mod) + + annotate(tpe)(wrapObject(mod2)) + + } + + private[upickle] def mergeTrait(tagKey: Option[String], subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree + + private[upickle] def derive(tpe: c.Type) = { + if (tpe.typeSymbol.asClass.isTrait || (tpe.typeSymbol.asClass.isAbstractClass && !tpe.typeSymbol.isJava)) { + val derived = deriveTrait(tpe) + derived + } + else if (tpe.typeSymbol.isModuleClass) deriveObject(tpe) + else deriveClass(tpe) + } + + private def deriveTrait(tpe: c.Type): c.universe.Tree = { + val clsSymbol = tpe.typeSymbol.asClass + + if (!clsSymbol.isSealed) { + fail(tpe, s"[error] The referenced trait [[${clsSymbol.name}]] must be sealed.") + }else if (clsSymbol.knownDirectSubclasses.filter(!_.toString.contains("")).isEmpty) { + val msg = + s"The referenced trait [[${clsSymbol.name}]] does not have any sub-classes. This may " + + "happen due to a limitation of scalac (SI-7046). To work around this, " + + "try manually specifying the sealed trait picklers as described in " + + "https://com-lihaoyi.github.io/upickle/#ManualSealedTraitPicklers" + fail(tpe, msg) + }else{ + val tagKey = customKey(clsSymbol) + val subTypes = fleshedOutSubtypes(tpe).toSeq.sortBy(_.typeSymbol.fullName) + // println("deriveTrait") + val subDerives = subTypes.map(subCls => q"implicitly[${typeclassFor(subCls)}]") + // println(Console.GREEN + "subDerives " + Console.RESET + subDrivess) + val merged = mergeTrait(tagKey, subDerives, subTypes, tpe) + merged + } + } + + private[upickle] def typeclass: c.WeakTypeTag[M[_]] + + private def typeclassFor(t: Type) = { + // println("typeclassFor " + weakTypeOf[M[_]](typeclass)) + + weakTypeOf[M[_]](typeclass) match { + case TypeRef(a, b, _) => + import compat._ + TypeRef(a, b, List(t)) + case ExistentialType(_, TypeRef(a, b, _)) => + import compat._ + TypeRef(a, b, List(t)) + case x => + println("Dunno Wad Dis Typeclazz Is " + x) + println(x) + println(x.getClass) + ??? + } + } + + sealed trait Flatten + + object Flatten { + case class Class(companion: Tree, fields: List[Field], varArgs: Boolean) extends Flatten + case object Map extends Flatten + case object None extends Flatten + } + + case class Field( + name: String, + mappedName: String, + tpe: Type, + symbol: Symbol, + defaultValue: Option[Tree], + flatten: Flatten, + ) { + lazy val allFields: List[Field] = { + def loop(field: Field): List[Field] = + field.flatten match { + case Flatten.Class(_, fields, _) => fields.flatMap(loop) + case Flatten.Map => List(field) + case Flatten.None => List(field) + } + loop(this) + } + } + + private def getFields(tpe: c.Type): (c.Tree, List[Field], Boolean) = { + def applyTypeArguments(t: c.Type ): c.Type = { + val typeParams = tpe.typeSymbol.asClass.typeParams + val typeArguments = tpe.normalize.asInstanceOf[TypeRef].args + if (t.typeSymbol != definitions.RepeatedParamClass) { + t.substituteTypes(typeParams, typeArguments) + } else { + val TypeRef(pref, sym, _) = typeOf[Seq[Int]] + internal.typeRef(pref, sym, t.asInstanceOf[TypeRef].args) + } + } + + val companion = companionTree(tpe) + //tickle the companion members -- Not doing this leads to unexpected runtime behavior + //I wonder if there is an SI related to this? + companion.tpe.members.foreach(_ => ()) + tpe.members.find(x => x.isMethod && x.asMethod.isPrimaryConstructor) match { + case None => fail(tpe, "Can't find primary constructor of " + tpe) + case Some(primaryConstructor) => + val params = primaryConstructor.asMethod.paramLists.flatten + val varArgs = params.lastOption.exists(_.typeSignature.typeSymbol == definitions.RepeatedParamClass) + val fields = params.zipWithIndex.map { case (param, i) => + val name = param.name.decodedName.toString + val mappedName = customKey(param).getOrElse(name) + val tpeOfField = applyTypeArguments(param.typeSignature) + val defaultValue = if (param.asTerm.isParamWithDefault) + Some(q"$companion.${TermName("apply$default$" + (i + 1))}") + else + None + val flatten = param.annotations.find(_.tree.tpe =:= typeOf[flatten]) match { + case Some(_) => + if (tpeOfField.typeSymbol == typeOf[collection.immutable.Map[_, _]].typeSymbol) Flatten.Map + else if (tpeOfField.typeSymbol.isClass && tpeOfField.typeSymbol.asClass.isCaseClass) { + val (nestedCompanion, fields, nestedVarArgs) = getFields(tpeOfField) + Flatten.Class(nestedCompanion, fields, nestedVarArgs) + } + else fail(tpeOfField, + s"""Invalid type for flattening: $tpeOfField. + | Flatten only works on case classes and Maps""".stripMargin) + case None => + Flatten.None + } + Field(param.name.toString, mappedName, tpeOfField, param, defaultValue, flatten) + } + (companion, fields, varArgs) + } + } + + private def deriveClass(tpe: c.Type) = { + val (companion, fields, varArgs) = getFields(tpe) + // According to @retronym, this is necessary in order to force the + // default argument `apply$default$n` methods to be synthesized + companion.tpe.member(TermName("apply")).info + + val allFields = fields.flatMap(_.allFields) + validateFlattenAnnotation(allFields) + + val derive = + // Otherwise, reading and writing are kinda identical + wrapCaseN( + companion, + fields, + varArgs, + targetType = tpe, + ) + + annotate(tpe)(derive) + } + + private def validateFlattenAnnotation(fields: List[Field]): Unit = { + if (fields.count(_.flatten == Flatten.Map) > 1) { + fail(NoType, "Only one Map can be annotated with @upickle.implicits.flatten in the same level") + } + if (fields.map(_.mappedName).distinct.length != fields.length) { + fail(NoType, "There are multiple fields with the same key") + } + if (fields.exists(field => field.flatten == Flatten.Map && !(field.tpe <:< typeOf[Map[String, _]]))) { + fail(NoType, "The key type of a Map annotated with @flatten must be String.") + } + } + + /** If there is a sealed base class, annotate the derived tree in the JSON + * representation with a class label. + */ + private def annotate(tpe: c.Type)(derived: c.universe.Tree) = { + val sealedParents = tpe.baseClasses.filter(_.asClass.isSealed) + + if (sealedParents.isEmpty) derived + else { + val tagKey = MacrosCommon.tagKeyFromParents( + tpe.typeSymbol.name.toString, + sealedParents, + customKey, + (_: c.Symbol).name.toString, + fail(tpe, _), + ) + + val sealedClassSymbol: Option[Symbol] = sealedParents.find(_ == tpe.typeSymbol) + val segments = + sealedClassSymbol.toList.map(_.fullName.split('.')) ++ + sealedParents + .flatMap(_.asClass.knownDirectSubclasses) + .map(_.fullName.split('.')) + + + // -1 because even if there is only one subclass, and so no name segments + // are needed to differentiate between them, we want to keep at least + // the rightmost name segment + val identicalSegmentCount = Range(0, segments.map(_.length).max - 1) + .takeWhile(i => segments.map(_.lift(i)).distinct.size == 1) + .length + + val tagValue = customKey(tpe.typeSymbol) + .getOrElse(TypeName(tpe.typeSymbol.fullName).decodedName.toString) + + val shortTagValue = customKey(tpe.typeSymbol) + .getOrElse( + TypeName( + tpe.typeSymbol.fullName.split('.').drop(identicalSegmentCount).mkString(".") + ).decodedName.toString + ) + + val tagKeyExpr = tagKey match { + case Some(v) => q"$v" + case None => q"${c.prefix}.tagName" + } + q"${c.prefix}.annotate($derived, $tagKeyExpr, $tagValue, $shortTagValue)" + } + } + + private def customKey(sym: c.Symbol): Option[String] = { + sym.annotations + .find(_.tpe == typeOf[key]) + .flatMap(_.scalaArgs.headOption) + .map{case Literal(Constant(s)) => s.toString} + } + + private[upickle] def serializeDefaults(sym: c.Symbol): Option[Boolean] = { + sym.annotations + .find(_.tpe == typeOf[upickle.implicits.serializeDefaults]) + .flatMap(_.scalaArgs.headOption) + .map{case Literal(Constant(s)) => s.asInstanceOf[Boolean]} + } + + private[upickle] def wrapObject(obj: Tree): Tree + + private[upickle] def wrapCaseN(companion: Tree, + fields: List[Field], + varargs: Boolean, + targetType: c.Type): Tree + } + + abstract class Reading[M[_]] extends DeriveDefaults[M] { + val c: scala.reflect.macros.blackbox.Context + import c.universe._ + def wrapObject(t: c.Tree) = q"new ${c.prefix}.SingletonReader($t)" + + def wrapCaseN(companion: c.universe.Tree, fields: List[Field], varargs: Boolean, targetType: c.Type): c.universe.Tree = { + val allowUnknownKeysAnnotation = targetType.typeSymbol + .annotations + .find(_.tree.tpe == typeOf[upickle.implicits.allowUnknownKeys]) + .flatMap(_.tree.children.tail.headOption) + .map { case Literal(Constant(b)) => b.asInstanceOf[Boolean] } + + val allFields = fields.flatMap(_.allFields).toArray.filter(_.flatten != Flatten.Map) + val (hasFlattenOnMap, valueTypeOfMap) = fields.flatMap(_.allFields).find(_.flatten == Flatten.Map) match { + case Some(f) => + val TypeRef(_, _, _ :: valueType :: Nil) = f.tpe + (true, valueType) + case None => (false, NoType) + } + val numberOfFields = allFields.length + val (localReaders, aggregates) = allFields.zipWithIndex.map { case (_, idx) => + (TermName(s"localReader$idx"), TermName(s"aggregated$idx")) + }.unzip + + val fieldToId = allFields.zipWithIndex.toMap + def constructClass(companion: c.universe.Tree, fields: List[Field], varargs: Boolean): c.universe.Tree = + q""" + $companion.apply( + ..${ + fields.map { field => + field.flatten match { + case Flatten.Class(c, f, v) => constructClass(c, f, v) + case Flatten.Map => + val termName = TermName(s"aggregatedMap") + q"$termName.toMap" + case Flatten.None => + val idx = fieldToId(field) + val termName = TermName(s"aggregated$idx") + if (field == fields.last && varargs) q"$termName:_*" + else q"$termName" + } + } + } + ) + """ + + q""" + ..${ + for (i <- allFields.indices) + yield q"private[this] lazy val ${localReaders(i)} = implicitly[${c.prefix}.Reader[${allFields(i).tpe}]]" + } + ..${ + if (hasFlattenOnMap) + List( + q"private[this] lazy val localReaderMap = implicitly[${c.prefix}.Reader[$valueTypeOfMap]]", + ) + else Nil + } + new ${c.prefix}.CaseClassReader[$targetType] { + override def visitObject(length: Int, jsonableKeys: Boolean, index: Int) = new ${if (numberOfFields <= 64) tq"_root_.upickle.implicits.CaseObjectContext[$targetType]" else tq"_root_.upickle.implicits.HugeCaseObjectContext[$targetType]"}(${numberOfFields}) { + ..${ + for (i <- allFields.indices) + yield q"private[this] var ${aggregates(i)}: ${allFields(i).tpe} = _" + } + ..${ + if (hasFlattenOnMap) + List( + q"private[this] lazy val aggregatedMap: scala.collection.mutable.ListBuffer[(String, $valueTypeOfMap)] = scala.collection.mutable.ListBuffer.empty", + ) + else Nil + } + + def storeAggregatedValue(currentIndex: Int, v: Any): Unit = currentIndex match { + case ..${ + for (i <- aggregates.indices) + yield cq"$i => ${aggregates(i)} = v.asInstanceOf[${allFields(i).tpe}]" + } + case ..${ + if (hasFlattenOnMap) + List(cq"-1 => aggregatedMap += currentKey -> v.asInstanceOf[$valueTypeOfMap]") + else Nil + } + case _ => throw new java.lang.IndexOutOfBoundsException(currentIndex.toString) + } + + def visitKeyValue(s: Any) = { + storeToMap = false + currentKey = ${c.prefix}.objectAttributeKeyReadMap(s.toString).toString + currentIndex = currentKey match { + case ..${ + for (i <- allFields.indices) + yield cq"${allFields(i).mappedName} => $i" + } + case _ => + ${ + (allowUnknownKeysAnnotation, hasFlattenOnMap) match { + case (_, true) => q"storeToMap = true; -1" + case (None, false) => + q""" + if (${ c.prefix }.allowUnknownKeys) -1 + else throw new _root_.upickle.core.Abort("Unknown Key: " + s.toString) + """ + case (Some(false), false) => q"""throw new _root_.upickle.core.Abort(" Unknown Key: " + s.toString)""" + case (Some(true), false) => q"-1" + } + } + } + } + + def visitEnd(index: Int) = { + ..${ + for(i <- allFields.indices if allFields(i).defaultValue.isDefined) + yield q"this.storeValueIfNotFound($i, ${allFields(i).defaultValue.get})" + } + + // Special-case 64 because java bit shifting ignores any RHS values above 63 + // https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19 + if (${ + if (numberOfFields <= 64) q"this.checkErrorMissingKeys(${if (numberOfFields == 64) -1 else (1L << numberOfFields) - 1})" + else q"this.checkErrorMissingKeys(${numberOfFields})" + }) { + this.errorMissingKeys(${numberOfFields}, ${allFields.map(_.mappedName)}) + } + + ${constructClass(companion, fields, varargs)} + } + + def subVisitor: _root_.upickle.core.Visitor[_, _] = currentIndex match { + case -1 => + ${ + if (hasFlattenOnMap) + q"localReaderMap" + else + q"_root_.upickle.core.NoOpVisitor" + } + case ..${ + for (i <- allFields.indices) + yield cq"$i => ${localReaders(i)} " + } + case _ => throw new java.lang.IndexOutOfBoundsException(currentIndex.toString) + } + } + } + """ + } + + override def mergeTrait(tagKey: Option[String], subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = { + val tagKeyExpr = tagKey match { + case Some(v) => q"$v" + case None => q"${c.prefix}.tagName" + } + q"${c.prefix}.Reader.merge[$targetType]($tagKeyExpr, ..$subtrees)" + } + } + + abstract class Writing[M[_]] extends DeriveDefaults[M] { + val c: scala.reflect.macros.blackbox.Context + import c.universe._ + def wrapObject(obj: c.Tree) = q"new ${c.prefix}.SingletonWriter($obj)" + + def internal = q"${c.prefix}.Internal" + + def wrapCaseN(companion: c.universe.Tree, fields: List[Field], varargs: Boolean, targetType: c.Type): c.universe.Tree = { + def serDfltVals(field: Field) = { + val b: Option[Boolean] = serializeDefaults(field.symbol).orElse(serializeDefaults(targetType.typeSymbol)) + b match { + case Some(b) => q"${b}" + case None => q"${c.prefix}.serializeDefaults" + } + } + + def write(field: Field, outer: c.universe.Tree): List[c.universe.Tree] = { + val select = Select(outer, TermName(field.name)) + field.flatten match { + case Flatten.Class(_, fields, _) => + fields.flatMap(write(_, select)) + case Flatten.Map => + val TypeRef(_, _, _ :: valueType :: Nil) = field.tpe + q""" + $select.foreach { case (key, value) => + this.writeSnippetMappedName[R, $valueType]( + ctx, + key.toString, + implicitly[${c.prefix}.Writer[$valueType]], + value + ) + } + """ :: Nil + case Flatten.None => + val snippet = + q""" + this.writeSnippetMappedName[R, ${field.tpe}]( + ctx, + ${c.prefix}.objectAttributeKeyWriteMap(${field.mappedName}), + implicitly[${c.prefix}.Writer[${field.tpe}]], + $select + ) + """ + val default = if (field.defaultValue.isEmpty) snippet + else q"""if (${serDfltVals(field)} || $select != ${field.defaultValue.get}) $snippet""" + default :: Nil + } + } + + def getLength(field: Field, outer: c.universe.Tree): List[c.universe.Tree] = { + val select = Select(outer, TermName(field.name)) + field.flatten match { + case Flatten.Class(_, fields, _) => fields.flatMap(getLength(_, select)) + case Flatten.Map => q"${select}.size" :: Nil + case Flatten.None => + ( + if (field.defaultValue.isEmpty) q"1" + else q"""if (${serDfltVals(field)} || ${select} != ${field.defaultValue}.get) 1 else 0""" + ) :: Nil + } + } + + q""" + new ${c.prefix}.CaseClassWriter[$targetType]{ + def length(v: $targetType) = { + ${ + fields.flatMap(getLength(_, q"v")) + .foldLeft[Tree](q"0") { case (prev, next) => q"$prev + $next" } + } + } + override def write0[R](out: _root_.upickle.core.Visitor[_, R], v: $targetType): R = { + if (v == null) out.visitNull(-1) + else { + val ctx = out.visitObject(length(v), true, -1) + ..${fields.flatMap(write(_, q"v"))} + ctx.visitEnd(-1) + } + } + def writeToObject[R](ctx: _root_.upickle.core.ObjVisitor[_, R], + v: $targetType): Unit = { + ..${fields.flatMap(write(_, q"v"))} + } + } + """ + } + + override def mergeTrait(tagKey: Option[String], subtree: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = { + q"${c.prefix}.Writer.merge[$targetType](..$subtree)" + } + } + def macroRImpl[T, R[_]](c0: scala.reflect.macros.blackbox.Context) + (implicit e1: c0.WeakTypeTag[T], e2: c0.WeakTypeTag[R[_]]): c0.Expr[R[T]] = { + import c0.universe._ + val res = new Reading[R]{ + val c: c0.type = c0 + def typeclass = e2 + }.derive(e1.tpe) +// println(c0.universe.showCode(res)) + c0.Expr[R[T]](res) + } + + def macroWImpl[T, W[_]](c0: scala.reflect.macros.blackbox.Context) + (implicit e1: c0.WeakTypeTag[T], e2: c0.WeakTypeTag[W[_]]): c0.Expr[W[T]] = { + import c0.universe._ + val res = new Writing[W]{ + val c: c0.type = c0 + def typeclass = e2 + }.derive(e1.tpe) +// println(c0.universe.showCode(res)) + c0.Expr[W[T]](res) + } +} + diff --git a/upickle/implicits/src-3/upickle/implicits/Readers.scala b/upickle/implicits/src-3/upickle/implicits/Readers.scala index 7540cbbe8..aa93f4067 100644 --- a/upickle/implicits/src-3/upickle/implicits/Readers.scala +++ b/upickle/implicits/src-3/upickle/implicits/Readers.scala @@ -5,6 +5,7 @@ import deriving.Mirror import scala.util.NotGiven import upickle.core.{Annotator, ObjVisitor, Visitor, Abort, CurrentlyDeriving} import upickle.implicits.BaseCaseObjectContext +import scala.collection.mutable trait ReadersVersionSpecific extends MacrosCommon @@ -15,28 +16,46 @@ trait ReadersVersionSpecific abstract class CaseClassReader3[T](paramCount: Int, missingKeyCount: Long, allowUnknownKeys: Boolean, - construct: Array[Any] => T) extends CaseClassReader[T] { + construct: (Array[Any], scala.collection.mutable.Map[String, Any]) => T) extends CaseClassReader[T] { - def visitors0: Product - lazy val visitors = visitors0 - def fromProduct(p: Product): T + def visitors0: (AnyRef, Array[AnyRef]) + lazy val (visitorMap, visitors) = visitors0 + lazy val hasFlattenOnMap = visitorMap ne null def keyToIndex(x: String): Int def allKeysArray: Array[String] def storeDefaults(x: upickle.implicits.BaseCaseObjectContext): Unit trait ObjectContext extends ObjVisitor[Any, T] with BaseCaseObjectContext{ private val params = new Array[Any](paramCount) - - def storeAggregatedValue(currentIndex: Int, v: Any): Unit = params(currentIndex) = v + private val map = scala.collection.mutable.Map.empty[String, Any] + + def storeAggregatedValue(currentIndex: Int, v: Any): Unit = + if (currentIndex == -1) { + if (storeToMap) { + map(currentKey) = v + } + } else { + params(currentIndex) = v + } def subVisitor: Visitor[_, _] = - if (currentIndex == -1) upickle.core.NoOpVisitor - else visitors.productElement(currentIndex).asInstanceOf[Visitor[_, _]] + if (currentIndex == -1) { + if (hasFlattenOnMap) visitorMap.asInstanceOf[Visitor[_, _]] + else upickle.core.NoOpVisitor + } + else { + visitors(currentIndex).asInstanceOf[Visitor[_, _]] + } def visitKeyValue(v: Any): Unit = - val k = objectAttributeKeyReadMap(v.toString).toString - currentIndex = keyToIndex(k) - if (currentIndex == -1 && !allowUnknownKeys) { - throw new upickle.core.Abort("Unknown Key: " + k.toString) + storeToMap = false + currentKey = objectAttributeKeyReadMap(v.toString).toString + currentIndex = keyToIndex(currentKey) + if (currentIndex == -1) { + if (hasFlattenOnMap) { + storeToMap = true + } else if (!allowUnknownKeys) { + throw new upickle.core.Abort("Unknown Key: " + currentKey.toString) + } } def visitEnd(index: Int): T = @@ -47,7 +66,7 @@ trait ReadersVersionSpecific if (this.checkErrorMissingKeys(missingKeyCount)) this.errorMissingKeys(paramCount, allKeysArray) - construct(params) + construct(params, map) } override def visitObject(length: Int, jsonableKeys: Boolean, @@ -58,16 +77,18 @@ trait ReadersVersionSpecific inline def macroR[T](using m: Mirror.Of[T]): Reader[T] = inline m match { case m: Mirror.ProductOf[T] => + macros.validateFlattenAnnotation[T]() + val paramCount = macros.paramsCount[T] val reader = new CaseClassReader3[T]( - macros.paramsCount[T], - macros.checkErrorMissingKeysCount[T](), + paramCount, + if (paramCount <= 64) if (paramCount == 64) -1 else (1L << paramCount) - 1 + else paramCount, macros.extractIgnoreUnknownKeys[T]().headOption.getOrElse(this.allowUnknownKeys), - params => macros.applyConstructor[T](params) + (params: Array[Any], map :scala.collection.mutable.Map[String ,Any]) => macros.applyConstructor[T](params, map) ){ - override def visitors0 = compiletime.summonAll[Tuple.Map[m.MirroredElemTypes, Reader]] - override def fromProduct(p: Product): T = m.fromProduct(p) + override def visitors0 = macros.allReaders[T, Reader] override def keyToIndex(x: String): Int = macros.keyToIndex[T](x) - override def allKeysArray = macros.fieldLabels[T].map(_._2).toArray + override def allKeysArray = macros.allFieldsMappedName[T].toArray override def storeDefaults(x: upickle.implicits.BaseCaseObjectContext): Unit = macros.storeDefaults[T](x) } diff --git a/upickle/implicits/src-3/upickle/implicits/Writers.scala b/upickle/implicits/src-3/upickle/implicits/Writers.scala index 9cdae09d1..db2791446 100644 --- a/upickle/implicits/src-3/upickle/implicits/Writers.scala +++ b/upickle/implicits/src-3/upickle/implicits/Writers.scala @@ -23,7 +23,7 @@ trait WritersVersionSpecific if (v == null) out.visitNull(-1) else { val ctx = out.visitObject(length(v), true, -1) - macros.writeSnippets[R, T, Tuple.Map[m.MirroredElemTypes, Writer]]( + macros.writeSnippets[R, T, Writer]( outerThis, this, v, @@ -34,7 +34,7 @@ trait WritersVersionSpecific } def writeToObject[R](ctx: _root_.upickle.core.ObjVisitor[_, R], v: T): Unit = - macros.writeSnippets[R, T, Tuple.Map[m.MirroredElemTypes, Writer]]( + macros.writeSnippets[R, T, Writer]( outerThis, this, v, diff --git a/upickle/implicits/src-3/upickle/implicits/macros.scala b/upickle/implicits/src-3/upickle/implicits/macros.scala index 50d50ee9d..0a65eba9c 100644 --- a/upickle/implicits/src-3/upickle/implicits/macros.scala +++ b/upickle/implicits/src-3/upickle/implicits/macros.scala @@ -3,11 +3,13 @@ package upickle.implicits.macros import scala.quoted.{ given, _ } import deriving._, compiletime._ import upickle.implicits.{MacrosCommon, ReadersVersionSpecific} -type IsInt[A <: Int] = A def getDefaultParamsImpl0[T](using Quotes, Type[T]): Map[String, Expr[AnyRef]] = import quotes.reflect._ - val unwrapped = TypeRepr.of[T] match{case AppliedType(p, v) => p case t => t} + val unwrapped = TypeRepr.of[T] match { + case AppliedType(p, v) => p + case t => t + } val sym = unwrapped.typeSymbol if (!sym.isClassDef) Map.empty @@ -60,27 +62,109 @@ def extractIgnoreUnknownKeysImpl[T](using Quotes, Type[T]): Expr[List[Boolean]] .toList ) +def extractFlatten[A](using Quotes)(sym: quotes.reflect.Symbol): Boolean = + import quotes.reflect._ + sym + .annotations + .exists(_.tpe =:= TypeRepr.of[upickle.implicits.flatten]) + inline def paramsCount[T]: Int = ${paramsCountImpl[T]} def paramsCountImpl[T](using Quotes, Type[T]) = { - Expr(fieldLabelsImpl0[T].size) + import quotes.reflect._ + val fields = allFields[T] + val count = fields.filter {case (_, _, _, _, flattenMap) => !flattenMap}.length + Expr(count) +} + +inline def allReaders[T, R[_]]: (AnyRef, Array[AnyRef]) = ${allReadersImpl[T, R]} +def allReadersImpl[T, R[_]](using Quotes, Type[T], Type[R]): Expr[(AnyRef, Array[AnyRef])] = { + import quotes.reflect._ + val fields = allFields[T] + val (readerMap, readers) = fields.partitionMap { case (_, _, tpe, _, isFlattenMap) => + if (isFlattenMap) { + val valueTpe = tpe.typeArgs(1) + val readerTpe = TypeRepr.of[R].appliedTo(valueTpe) + val reader = readerTpe.asType match { + case '[t] => '{summonInline[t].asInstanceOf[AnyRef]} + } + Left(reader) + } + else { + val readerTpe = TypeRepr.of[R].appliedTo(tpe) + val reader = readerTpe.asType match { + case '[t] => '{summonInline[t].asInstanceOf[AnyRef]} + } + Right(reader) + } + } + Expr.ofTuple( + ( + readerMap.headOption.getOrElse('{null}.asInstanceOf[Expr[AnyRef]]), + '{${Expr.ofList(readers)}.toArray}, + ) + ) +} + +inline def allFieldsMappedName[T]: List[String] = ${allFieldsMappedNameImpl[T]} +def allFieldsMappedNameImpl[T](using Quotes, Type[T]): Expr[List[String]] = { + import quotes.reflect._ + Expr(allFields[T].map { case (_, label, _, _, _) => label }) } inline def storeDefaults[T](inline x: upickle.implicits.BaseCaseObjectContext): Unit = ${storeDefaultsImpl[T]('x)} def storeDefaultsImpl[T](x: Expr[upickle.implicits.BaseCaseObjectContext])(using Quotes, Type[T]) = { import quotes.reflect.* - - val statements = fieldLabelsImpl0[T] + val statements = allFields[T] + .filter(!_._5) .zipWithIndex - .map { case ((rawLabel, label), i) => - val defaults = getDefaultParamsImpl0[T] - if (defaults.contains(label)) '{${x}.storeValueIfNotFound(${Expr(i)}, ${defaults(label)})} - else '{} + .map { case ((_, _, _, default, _), i) => + default match { + case Some(defaultValue) => '{${x}.storeValueIfNotFound(${Expr(i)}, ${defaultValue})} + case None => '{} + } } Expr.block(statements, '{}) } -inline def fieldLabels[T]: List[(String, String)] = ${fieldLabelsImpl[T]} +def allFields[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, String, quotes.reflect.TypeRepr, Option[Expr[Any]], Boolean)] = { + import quotes.reflect._ + + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, defaults: Map[String, Expr[Object]]): List[(Symbol, String, TypeRepr, Option[Expr[Any]], Boolean)] = { + val flatten = extractFlatten(field) + val substitutedTypeRepr = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = substitutedTypeRepr.typeSymbol + if (flatten) { + if (isMap(substitutedTypeRepr)) { + (field, label, substitutedTypeRepr, defaults.get(label), true) :: Nil + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newDefaults) + } + case _ => + report.errorAndAbort(s"Unsupported type $typeSymbol for flattening") + } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else { + (field, label, substitutedTypeRepr, defaults.get(label), false) :: Nil + } + } + + fieldLabelsImpl0[T] + .flatMap{ (rawLabel, label) => + val defaults = getDefaultParamsImpl0[T] + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, defaults) + } +} + def fieldLabelsImpl0[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, String)] = import quotes.reflect._ val fields: List[Symbol] = TypeRepr.of[T].typeSymbol @@ -96,16 +180,14 @@ def fieldLabelsImpl0[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, Str case None => (sym, sym.name) } -def fieldLabelsImpl[T](using Quotes, Type[T]): Expr[List[(String, String)]] = - Expr.ofList(fieldLabelsImpl0[T].map((a, b) => Expr((a.name, b)))) - inline def keyToIndex[T](inline x: String): Int = ${keyToIndexImpl[T]('x)} def keyToIndexImpl[T](x: Expr[String])(using Quotes, Type[T]): Expr[Int] = { import quotes.reflect.* + val fields = allFields[T].filter { case (_, _, _, _, isFlattenMap) => !isFlattenMap } val z = Match( x.asTerm, - fieldLabelsImpl0[T].map(_._2).zipWithIndex.map{(f, i) => - CaseDef(Literal(StringConstant(f)), None, Literal(IntConstant(i))) + fields.zipWithIndex.map{case ((_, label, _, _, _), i) => + CaseDef(Literal(StringConstant(label)), None, Literal(IntConstant(i))) } ++ Seq( CaseDef(Wildcard(), None, Literal(IntConstant(-1))) ) @@ -126,71 +208,138 @@ def serDfltVals(using quotes: Quotes)(thisOuter: Expr[upickle.core.Types with up case None => '{ ${ thisOuter }.serializeDefaults } } } + def writeLengthImpl[T](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], v: Expr[T]) (using quotes: Quotes, t: Type[T]): Expr[Int] = import quotes.reflect.* + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, select: Select, defaults: Map[String, Expr[Object]]): List[Expr[Int]] = + val flatten = extractFlatten(field) + if (flatten) { + val subsitituted = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = subsitituted.typeSymbol + if (isMap(subsitituted)) { + List( + '{${select.asExprOf[Map[_, _]]}.size} + ) + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newSelect = Select.unique(select, rawLabel.name) + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newSelect, newDefaults) + } + case _ => + report.errorAndAbort("Unsupported type for flattening") + } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else if (!defaults.contains(label)) List('{1}) + else { + val serDflt = serDfltVals(thisOuter, field, classTypeRepr.typeSymbol) + List( + '{if (${serDflt} || ${select.asExprOf[Any]} != ${defaults(label)}) 1 else 0} + ) + } + fieldLabelsImpl0[T] - .map{(rawLabel, label) => + .flatMap { (rawLabel, label) => val defaults = getDefaultParamsImpl0[T] - val select = Select.unique(v.asTerm, rawLabel.name).asExprOf[Any] - - if (!defaults.contains(label)) '{1} - else { - val serDflt = serDfltVals(thisOuter, rawLabel, TypeRepr.of[T].typeSymbol) - '{if (${serDflt} || ${select} != ${defaults(label)}) 1 else 0} - } + val select = Select.unique(v.asTerm, rawLabel.name) + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, select, defaults) } .foldLeft('{0}) { case (prev, next) => '{$prev + $next} } -inline def checkErrorMissingKeysCount[T](): Long = - ${checkErrorMissingKeysCountImpl[T]()} - -def checkErrorMissingKeysCountImpl[T]()(using Quotes, Type[T]): Expr[Long] = - import quotes.reflect.* - val paramCount = fieldLabelsImpl0[T].size - if (paramCount <= 64) if (paramCount == 64) Expr(-1) else Expr((1L << paramCount) - 1) - else Expr(paramCount) - -inline def writeSnippets[R, T, WS <: Tuple](inline thisOuter: upickle.core.Types with upickle.implicits.MacrosCommon, +inline def writeSnippets[R, T, W[_]](inline thisOuter: upickle.core.Types with upickle.implicits.MacrosCommon, inline self: upickle.implicits.CaseClassReadWriters#CaseClassWriter[T], inline v: T, inline ctx: _root_.upickle.core.ObjVisitor[_, R]): Unit = - ${writeSnippetsImpl[R, T, WS]('thisOuter, 'self, 'v, 'ctx)} + ${writeSnippetsImpl[R, T, W]('thisOuter, 'self, 'v, 'ctx)} -def writeSnippetsImpl[R, T, WS <: Tuple](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], +def writeSnippetsImpl[R, T, W[_]](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], self: Expr[upickle.implicits.CaseClassReadWriters#CaseClassWriter[T]], v: Expr[T], ctx: Expr[_root_.upickle.core.ObjVisitor[_, R]]) - (using Quotes, Type[T], Type[R], Type[WS]): Expr[Unit] = + (using Quotes, Type[T], Type[R], Type[W]): Expr[Unit] = import quotes.reflect.* - Expr.block( - for (((rawLabel, label), i) <- fieldLabelsImpl0[T].zipWithIndex) yield { - - val tpe0 = TypeRepr.of[T].memberType(rawLabel).asType - tpe0 match - case '[tpe] => - val defaults = getDefaultParamsImpl0[T] - Literal(IntConstant(i)).tpe.asType match - case '[IsInt[index]] => - val select = Select.unique(v.asTerm, rawLabel.name).asExprOf[Any] - val snippet = '{ - ${self}.writeSnippetMappedName[R, tpe]( - ${ctx}, - ${thisOuter}.objectAttributeKeyWriteMap(${Expr(label)}), - summonInline[Tuple.Elem[WS, index]], - ${select}, - ) + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, select: Select, defaults: Map[String, Expr[Object]]): List[Expr[Any]] = + val flatten = extractFlatten(field) + val fieldTypeRepr = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = fieldTypeRepr.typeSymbol + if (flatten) { + if (isMap(fieldTypeRepr)) { + val (keyTpe0, valueTpe0) = fieldTypeRepr.typeArgs match { + case key :: value :: Nil => (key, value) + case _ => report.errorAndAbort(s"Unsupported type ${typeSymbol} for flattening", v.asTerm.pos) } - if (!defaults.contains(label)) snippet - else { - val serDflt = serDfltVals(thisOuter, rawLabel, TypeRepr.of[T].typeSymbol) - '{if ($serDflt || ${select} != ${defaults(label)}) $snippet} + val writerTpe0 = TypeRepr.of[W].appliedTo(valueTpe0) + (keyTpe0.asType, valueTpe0.asType, writerTpe0.asType) match { + case ('[keyTpe], '[valueTpe], '[writerTpe])=> + val snippet = '{ + ${select.asExprOf[Map[keyTpe, valueTpe]]}.foreach { (k, v) => + ${self}.writeSnippetMappedName[R, valueTpe]( + ${ctx}, + k.toString, + summonInline[writerTpe], + v, + ) + } + } + List(snippet) + } + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newSelect = Select.unique(select, rawLabel.name) + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newSelect, newDefaults) + } + case _ => + report.errorAndAbort("Unsupported type for flattening", v) } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map", v.asTerm.pos) + } + else { + val tpe0 = fieldTypeRepr + val writerTpe0 = TypeRepr.of[W].appliedTo(tpe0) + (tpe0.asType, writerTpe0.asType) match + case ('[tpe], '[writerTpe]) => + val snippet = '{ + ${self}.writeSnippetMappedName[R, tpe]( + ${ctx}, + ${thisOuter}.objectAttributeKeyWriteMap(${Expr(label)}), + summonInline[writerTpe], + ${select.asExprOf[Any]}, + ) + } + List( + if (!defaults.contains(label)) snippet + else { + val serDflt = serDfltVals(thisOuter, field, classTypeRepr.typeSymbol) + '{if ($serDflt || ${select.asExprOf[Any]} != ${defaults(label)}) $snippet} + } + ) + } - }, + Expr.block( + fieldLabelsImpl0[T] + .flatMap { (rawLabel, label) => + val defaults = getDefaultParamsImpl0[T] + val select = Select.unique(v.asTerm, rawLabel.name) + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, select, defaults) + }, '{()} ) @@ -221,11 +370,19 @@ def tagKeyImpl[T](using Quotes, Type[T])(thisOuter: Expr[upickle.core.Types with case None => '{${thisOuter}.tagName} } -inline def applyConstructor[T](params: Array[Any]): T = ${ applyConstructorImpl[T]('params) } -def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Array[Any]]): Expr[T] = +def substituteTypeArgs(using Quotes)(tpe: quotes.reflect.TypeRepr, subsitituted: quotes.reflect.TypeRepr): quotes.reflect.TypeRepr = { + import quotes.reflect._ + val constructorSym = tpe.typeSymbol.primaryConstructor + val constructorParamSymss = constructorSym.paramSymss + + val tparams0 = constructorParamSymss.flatten.filter(_.isType) + subsitituted.substituteTypes(tparams0 ,tpe.typeArgs) +} + +inline def applyConstructor[T](params: Array[Any], map: scala.collection.mutable.Map[String, Any]): T = ${ applyConstructorImpl[T]('params, 'map) } +def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Array[Any]], map: Expr[scala.collection.mutable.Map[String, Any]]): Expr[T] = import quotes.reflect._ - def apply(typeApply: Option[List[TypeRepr]]) = { - val tpe = TypeRepr.of[T] + def apply(tpe: TypeRepr, typeArgs: List[TypeRepr], offset: Int): (Term, Int) = { val companion: Symbol = tpe.classSymbol.get.companionModule val constructorSym = tpe.typeSymbol.primaryConstructor val constructorParamSymss = constructorSym.paramSymss @@ -233,39 +390,64 @@ def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Arra val (tparams0, params0) = constructorParamSymss.flatten.partition(_.isType) val constructorTpe = tpe.memberType(constructorSym).widen - val rhs = params0.zipWithIndex.map { - case (sym0, i) => - val lhs = '{$params(${ Expr(i) })} + val (rhs, nextOffset) = params0.foldLeft((List.empty[Term], offset)) { case ((terms, i), sym0) => val tpe0 = constructorTpe.memberType(sym0) - - typeApply.map(tps => tpe0.substituteTypes(tparams0, tps)).getOrElse(tpe0) match { - case AnnotatedType(AppliedType(base, Seq(arg)), x) - if x.tpe =:= defn.RepeatedAnnot.typeRef => - arg.asType match { - case '[t] => - Typed( - lhs.asTerm, - TypeTree.of(using AppliedType(defn.RepeatedParamClass.typeRef, List(arg)).asType) - ) + val appliedTpe = tpe0.substituteTypes(tparams0, typeArgs) + val typeSymbol = appliedTpe.typeSymbol + val flatten = extractFlatten(sym0) + if (flatten) { + if (isMap(appliedTpe)) { + val keyTpe0 = appliedTpe.typeArgs.head + val valueTpe0 = appliedTpe.typeArgs(1) + (keyTpe0.asType, valueTpe0.asType) match { + case ('[keyTpe], '[valueTpe]) => + val typedMap = '{${map}.asInstanceOf[collection.mutable.Map[keyTpe, valueTpe]]}.asTerm + val term = Select.unique(typedMap, "toMap") + (term :: terms, i) } - case tpe => - tpe.asType match { - case '[t] => '{ $lhs.asInstanceOf[t] }.asTerm + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + val newTpe = TypeRepr.of[t] + val (term, nextOffset) = newTpe match { + case t: AppliedType => apply(newTpe, t.args, i) + case t: TypeRef => apply(newTpe, List.empty, i) + case t: TermRef => (Ref(t.classSymbol.get.companionModule), i) + } + (term :: terms, nextOffset) + case _ => + report.errorAndAbort(s"Unsupported type $typeSymbol for flattening") } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else { + val lhs = '{$params(${ Expr(i) })} + val term = appliedTpe match { + case AnnotatedType(AppliedType(base, Seq(arg)), x) if x.tpe =:= defn.RepeatedAnnot.typeRef => + arg.asType match { + case '[t] => + Typed( + lhs.asTerm, + TypeTree.of(using AppliedType(defn.RepeatedParamClass.typeRef, List(arg)).asType) + ) + } + case tpe => + tpe.asType match { + case '[t] => '{ $lhs.asInstanceOf[t] }.asTerm + } + } + (term :: terms, i + 1) } - } - typeApply match{ - case None => Select.overloaded(Ref(companion), "apply", Nil, rhs).asExprOf[T] - case Some(args) => - Select.overloaded(Ref(companion), "apply", args, rhs).asExprOf[T] - } + (Select.overloaded(Ref(companion), "apply", typeArgs, rhs.reverse), nextOffset) } - TypeRepr.of[T] match{ - case t: AppliedType => apply(Some(t.args)) - case t: TypeRef => apply(None) + val tpe = TypeRepr.of[T] + tpe match{ + case t: AppliedType => apply(tpe, t.args, 0)._1.asExprOf[T] + case t: TypeRef => apply(tpe, List.empty, 0)._1.asExprOf[T] case t: TermRef => '{${Ref(t.classSymbol.get.companionModule).asExprOf[Any]}.asInstanceOf[T]} } @@ -389,3 +571,27 @@ def defineEnumVisitorsImpl[T0, T <: Tuple](prefix: Expr[Any], macroX: String)(us Block(allDefs.map(_._1), Ident(allDefs.head._2.termRef)).asExprOf[T0] +inline def validateFlattenAnnotation[T](): Unit = ${ validateFlattenAnnotationImpl[T] } +def validateFlattenAnnotationImpl[T](using Quotes, Type[T]): Expr[Unit] = + import quotes.reflect._ + val fields = allFields[T] + if (fields.count(_._5) > 1) { + report.errorAndAbort("Only one Map can be annotated with @upickle.implicits.flatten in the same level") + } + if (fields.map(_._2).distinct.length != fields.length) { + report.errorAndAbort("There are multiple fields with the same key") + } + if (fields.exists {case (_, _, tpe, _, isFlattenMap) => isFlattenMap && !(tpe.typeArgs.head.dealias =:= TypeRepr.of[String].dealias)}) { + report.errorAndAbort("The key type of a Map annotated with @flatten must be String.") + } + '{()} + +private def isMap(using Quotes)(tpe: quotes.reflect.TypeRepr): Boolean = { + import quotes.reflect._ + tpe.typeSymbol == TypeRepr.of[collection.immutable.Map[_, _]].typeSymbol +} + +private def isCaseClass(using Quotes)(typeSymbol: quotes.reflect.Symbol): Boolean = { + import quotes.reflect._ + typeSymbol.isClassDef && typeSymbol.flags.is(Flags.Case) +} diff --git a/upickle/implicits/src/upickle/implicits/ObjectContexts.scala b/upickle/implicits/src/upickle/implicits/ObjectContexts.scala index cac1c17e6..49f33f225 100644 --- a/upickle/implicits/src/upickle/implicits/ObjectContexts.scala +++ b/upickle/implicits/src/upickle/implicits/ObjectContexts.scala @@ -4,6 +4,9 @@ import upickle.core.ObjVisitor trait BaseCaseObjectContext { + var currentKey = "" + var storeToMap = false + def storeAggregatedValue(currentIndex: Int, v: Any): Unit def visitKey(index: Int) = _root_.upickle.core.StringVisitor @@ -21,10 +24,13 @@ abstract class CaseObjectContext[V](fieldCount: Int) extends ObjVisitor[Any, V] var found = 0L def visitValue(v: Any, index: Int): Unit = { - if (currentIndex != -1 && ((found & (1L << currentIndex)) == 0)) { + if ((currentIndex != -1) && ((found & (1L << currentIndex)) == 0)) { storeAggregatedValue(currentIndex, v) found |= (1L << currentIndex) } + else if (storeToMap) { + storeAggregatedValue(currentIndex, v) + } } def storeValueIfNotFound(i: Int, v: Any) = { @@ -53,10 +59,13 @@ abstract class HugeCaseObjectContext[V](fieldCount: Int) extends ObjVisitor[Any, var found = new Array[Long](fieldCount / 64 + 1) def visitValue(v: Any, index: Int): Unit = { - if (currentIndex != -1 && ((found(currentIndex / 64) & (1L << currentIndex)) == 0)) { + if ((currentIndex != -1) && ((found(currentIndex / 64) & (1L << currentIndex)) == 0)) { storeAggregatedValue(currentIndex, v) found(currentIndex / 64) |= (1L << currentIndex) } + else if (storeToMap) { + storeAggregatedValue(currentIndex, v) + } } def storeValueIfNotFound(i: Int, v: Any) = { diff --git a/upickle/implicits/src/upickle/implicits/key.scala b/upickle/implicits/src/upickle/implicits/key.scala index 36c013d71..2be787810 100644 --- a/upickle/implicits/src/upickle/implicits/key.scala +++ b/upickle/implicits/src/upickle/implicits/key.scala @@ -30,4 +30,16 @@ class serializeDefaults(s: Boolean) extends StaticAnnotation */ class allowUnknownKeys(b: Boolean) extends StaticAnnotation + +/** + * An annotation that, when applied to a field in a case class, flattens the fields of the + * annotated `case class` or `Map` into the parent case class during serialization. + * This means the fields will appear at the same level as the parent case class's fields + * rather than nested under the field name. During deserialization, these fields are + * grouped back into the annotated `case class` or `Map`. + * + * **Limitations**: + * - Only works with `Map` types that are subtypes of `Map[String, _]`. + * - Cannot flatten more than two `Map` instances in a same level. + */ class flatten extends StaticAnnotation diff --git a/upickle/test/src/upickle/FailureTests.scala b/upickle/test/src/upickle/FailureTests.scala index 109781a16..9a70f0d76 100644 --- a/upickle/test/src/upickle/FailureTests.scala +++ b/upickle/test/src/upickle/FailureTests.scala @@ -37,6 +37,11 @@ object WrongTag { } +case class FlattenTwoMaps(@upickle.implicits.flatten map1: Map[String, String], @upickle.implicits.flatten map2: Map[String, String]) +case class ConflictingKeys(i: Int, @upickle.implicits.flatten cm: ConflictingMessage) +case class ConflictingMessage(i: Int) +case class MapWithNoneStringKey(@upickle.implicits.flatten map: Map[ConflictingMessage, String]) + object TaggedCustomSerializer{ sealed trait BooleanOrInt @@ -265,6 +270,9 @@ object FailureTests extends TestSuite { // compileError("""read[Array[Object]]("")""").msg // Make sure this doesn't hang the compiler =/ compileError("implicitly[upickle.default.Reader[Nothing]]") + compileError("upickle.default.macroRW[FlattenTwoMaps]") + compileError("upickle.default.macroRW[ConflictingKeys]") + compileError("upickle.default.macroRW[MapWithNoneStringKey]") } test("expWholeNumbers"){ upickle.default.read[Byte]("0e0") ==> 0.toByte diff --git a/upickle/test/src/upickle/MacroTests.scala b/upickle/test/src/upickle/MacroTests.scala index 2c9736317..9217e15ac 100644 --- a/upickle/test/src/upickle/MacroTests.scala +++ b/upickle/test/src/upickle/MacroTests.scala @@ -145,25 +145,63 @@ object TagName{ implicit val fooRw: TagNamePickler.ReadWriter[Foo] = TagNamePickler.macroRW } -case class Pagination(limit: Int, offset: Int, total: Int) +object Flatten { + case class FlattenTest(i: Int, s: String, @upickle.implicits.flatten n: Nested, @upickle.implicits.flatten n2: Nested2) -object Pagination { - implicit val rw: RW[Pagination] = upickle.default.macroRW -} + object FlattenTest { + implicit val rw: RW[FlattenTest] = upickle.default.macroRW + } -case class Users(Ids: List[Int], @upickle.implicits.flatten pagination: Pagination) + case class Nested(d: Double, @upickle.implicits.flatten m: Map[String, Int]) -object Users { - implicit val rw: RW[Users] = upickle.default.macroRW -} + object Nested { + implicit val rw: RW[Nested] = upickle.default.macroRW + } + + case class Nested2(name: String) + + object Nested2 { + implicit val rw: RW[Nested2] = upickle.default.macroRW + } + + case class FlattenTestWithType[T](i: Int, @upickle.implicits.flatten t: T) + + object FlattenTestWithType { + // implicit def rw[T: RW]: RW[FlattenTestWithType[T]] = upickle.default.macroRW + implicit val rw: RW[FlattenTestWithType[Nested]] = upickle.default.macroRW + } + + case class InnerMost(a: String, b: Int) + + object InnerMost { + implicit val rw: RW[InnerMost] = upickle.default.macroRW + } + + case class Inner(@upickle.implicits.flatten innerMost: InnerMost, c: Boolean) + + object Inner { + implicit val rw: RW[Inner] = upickle.default.macroRW + } -case class PackageManifest( - name: String, - @upickle.implicits.flatten otherStuff: Map[String, ujson.Value] - ) + case class Outer(d: Double, @upickle.implicits.flatten inner: Inner) -object PackageManifest { - implicit val rw: RW[PackageManifest] = upickle.default.macroRW + object Outer { + implicit val rw: RW[Outer] = upickle.default.macroRW + } + + case class HasMap(@upickle.implicits.flatten map: Map[String, String], i: Int) + object HasMap { + implicit val rw: RW[HasMap] = upickle.default.macroRW + } + + case class FlattenWithDefault(i: Int, @upickle.implicits.flatten n: NestedWithDefault) + object FlattenWithDefault { + implicit val rw: RW[FlattenWithDefault] = upickle.default.macroRW + } + case class NestedWithDefault(k: Int = 100, l: String) + object NestedWithDefault { + implicit val rw: RW[NestedWithDefault] = upickle.default.macroRW + } } object MacroTests extends TestSuite { @@ -172,7 +210,7 @@ object MacroTests extends TestSuite { // case class A_(objects: Option[C_]); case class C_(nodes: Option[C_]) // implicitly[Reader[A_]] -// implicitly[upickle.old.Writer[upickle.MixedIn.Obj.ClsB]code] +// implicitly[upickle.old.Writer[upickle.MixedIn.Obj.ClsB]] // println(write(ADTs.ADTc(1, "lol", (1.1, 1.2)))) // implicitly[upickle.old.Writer[ADTs.ADTc]] @@ -904,9 +942,39 @@ object MacroTests extends TestSuite { } test("flatten"){ - val a = Users(List(1, 2, 3), Pagination(10, 20, 30)) - upickle.default.write[Users](a) ==> """{"Ids":[1,2,3],"limit":10,"offset":20,"total":30}""" + import Flatten._ + val a = FlattenTest(10, "test", Nested(3.0, Map("one" -> 1, "two" -> 2)), Nested2("hello")) + rw(a, """{"i":10,"s":"test","d":3,"one":1,"two":2,"name":"hello"}""") + } + + test("flattenTypeParam"){ + import Flatten._ + val a = FlattenTestWithType[Nested](10, Nested(5.0, Map("one" -> 1, "two" -> 2))) + rw(a, """{"i":10,"d":5,"one":1,"two":2}""") } + test("nestedFlatten") { + import Flatten._ + val value = Outer(1.1, Inner(InnerMost("test", 42), true)) + rw(value, """{"d":1.1,"a":"test","b":42,"c":true}""") + } + + test("flattenWithMap") { + import Flatten._ + val value = HasMap(Map("key1" -> "value1", "key2" -> "value2"), 10) + rw(value, """{"key1":"value1","key2":"value2","i":10}""") + } + + test("flattenEmptyMap") { + import Flatten._ + val value = HasMap(Map.empty, 10) + rw(value, """{"i":10}""") + } + + test("flattenWithDefaults") { + import Flatten._ + val value = FlattenWithDefault(10, NestedWithDefault(l = "default")) + rw(value, """{"i":10,"l":"default"}""") + } } }