Skip to content

Commit

Permalink
More efficient decoding from JSON AST (#1317)
Browse files Browse the repository at this point in the history
  • Loading branch information
plokhotnyuk authored Feb 14, 2025
1 parent 198a512 commit a9875d6
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 108 deletions.
33 changes: 18 additions & 15 deletions zio-json/shared/src/main/scala-2.x/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ object DeriveJsonDecoder {
if (allFieldNames.length != allFieldNames.distinct.length) {
val aliasNames = aliases.map(_._1)
val collisions = aliasNames
.filter(alias => names.contains(alias) || aliases.count { case (a, _) => a == alias } > 1)
.filter(alias => names.contains(alias) || aliases.count(a => a._1 == alias) > 1)
.distinct
val msg = s"Field names and aliases in case class ${ctx.typeName.full} must be distinct, " +
s"alias(es) ${collisions.mkString(",")} collide with a field or another alias"
Expand Down Expand Up @@ -356,16 +356,16 @@ object DeriveJsonDecoder {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(keyValues) =>
case o: Json.Obj =>
val ps = new Array[Any](len)
for ((key, value) <- keyValues) {
namesMap.get(key) match {
o.fields.foreach { kv =>
namesMap.get(kv._1) match {
case Some(idx) =>
if (ps(idx) != null) Lexer.error("duplicate", trace)
val default = defaults(idx)
ps(idx) =
if ((default ne null) && (value eq Json.Null)) default()
else tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, value)
if ((default ne null) && (kv._2 eq Json.Null)) default()
else tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, kv._2)
case _ =>
if (no_extra) Lexer.error("invalid extra field", trace)
}
Expand Down Expand Up @@ -425,10 +425,10 @@ object DeriveJsonDecoder {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(chunk) if chunk.size == 1 =>
val keyValue = chunk.head
namesMap.get(keyValue._1) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, keyValue._2).asInstanceOf[A]
case o: Json.Obj if o.fields.length == 1 =>
val kv = o.fields(0)
namesMap.get(kv._1) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, kv._2).asInstanceOf[A]
case _ => Lexer.error("invalid disambiguator", trace)
}
case _ => Lexer.error("expected single field object", trace)
Expand Down Expand Up @@ -459,9 +459,12 @@ object DeriveJsonDecoder {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(fields) =>
fields.find { case (key, _) => key == hintfield } match {
case Some((_, Json.Str(name))) =>
case o: Json.Obj =>
o.fields.collectFirst {
case kv if kv._1 == hintfield && kv._2.isInstanceOf[Json.Str] =>
kv._2.asInstanceOf[Json.Str].value
} match {
case Some(name) =>
namesMap.get(name) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, json).asInstanceOf[A]
case _ => Lexer.error("invalid disambiguator", trace)
Expand Down Expand Up @@ -628,8 +631,8 @@ object DeriveJsonEncoder {

override def toJsonAST(a: A): Either[String, Json] = ctx.split(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).flatMap {
case Json.Obj(fields) =>
new Right(Json.Obj((hintfield -> Json.Str(names(sub.index))) +: fields)) // hint field is always first
case o: Json.Obj =>
new Right(Json.Obj((hintfield -> Json.Str(names(sub.index))) +: o.fields)) // hint field is always first
case _ =>
new Left("expected object")
}
Expand Down
28 changes: 15 additions & 13 deletions zio-json/shared/src/main/scala-3/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,16 @@ sealed class JsonDecoderDerivation(config: JsonCodecConfiguration) extends Deriv

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(keyValues) =>
case o: Json.Obj =>
val ps = new Array[Any](len)
for ((key, value) <- keyValues) {
namesMap.get(key) match {
o.fields.foreach { kv =>
namesMap.get(kv._1) match {
case Some(idx) =>
if (ps(idx) != null) Lexer.error("duplicate", trace)
val default = defaults(idx)
ps(idx) =
if ((default ne null) && (value eq Json.Null)) default()
else tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, value)
if ((default ne null) && (kv._2 eq Json.Null)) default()
else tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, kv._2)
case _ =>
if (no_extra) Lexer.error("invalid extra field", trace)
}
Expand Down Expand Up @@ -427,7 +427,7 @@ sealed class JsonDecoderDerivation(config: JsonCodecConfiguration) extends Deriv

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Str(typeName) => namesMap.get(typeName) match {
case s: Json.Str => namesMap.get(s.value) match {
case Some(idx) => tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case _ => Lexer.error("invalid enumeration value", trace)
}
Expand All @@ -453,8 +453,8 @@ sealed class JsonDecoderDerivation(config: JsonCodecConfiguration) extends Deriv

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(chunk) if chunk.size == 1 =>
val keyValue = chunk.head
case o: Json.Obj if o.fields.length == 1 =>
val keyValue = o.fields(0)
namesMap.get(keyValue._1) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, keyValue._2).asInstanceOf[A]
case _ => Lexer.error("invalid disambiguator", trace)
Expand Down Expand Up @@ -487,9 +487,11 @@ sealed class JsonDecoderDerivation(config: JsonCodecConfiguration) extends Deriv

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(fields) =>
fields.find { case (key, _) => key == hintfield } match {
case Some((_, Json.Str(name))) =>
case o: Json.Obj =>
o.fields.collectFirst { case kv if kv._1 == hintfield && kv._2.isInstanceOf[Json.Str] =>
kv._2.asInstanceOf[Json.Str].value
} match {
case Some(name) =>
namesMap.get(name) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, json).asInstanceOf[A]
case _ => Lexer.error("invalid disambiguator", trace)
Expand Down Expand Up @@ -705,11 +707,11 @@ sealed class JsonEncoderDerivation(config: JsonCodecConfiguration) extends Deriv

override final def toJsonAST(a: A): Either[String, Json] = ctx.choose(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).flatMap {
case Json.Obj(fields) =>
case o: Json.Obj =>
val name = sub.annotations.collectFirst {
case jsonHint(name) => name
}.getOrElse(jsonHintFormat(sub.typeInfo.short))
new Right(new Json.Obj((hintField -> new Json.Str(name)) +: fields)) // hint field is always first
new Right(Json.Obj((hintField -> new Json.Str(name)) +: o.fields)) // hint field is always first
case _ =>
new Left("expected object")
}
Expand Down
100 changes: 50 additions & 50 deletions zio-json/shared/src/main/scala/zio/json/JsonDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): String =
json match {
case Json.Str(value) => value
case _ => Lexer.error("expected string", trace)
case s: Json.Str => s.value
case _ => Lexer.error("expected string", trace)
}
}

Expand All @@ -283,8 +283,8 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Boolean =
json match {
case Json.Bool(value) => value
case _ => Lexer.error("expected boolean", trace)
case b: Json.Bool => b.value
case _ => Lexer.error("expected boolean", trace)
}
}

Expand All @@ -293,8 +293,8 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Char =
json match {
case Json.Str(s) if s.length == 1 => s.charAt(0)
case _ => Lexer.error("expected single character string", trace)
case s: Json.Str if s.value.length == 1 => s.value.charAt(0)
case _ => Lexer.error("expected single character string", trace)
}
}

Expand All @@ -314,13 +314,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Byte =
json match {
case Json.Num(value) =>
try value.byteValueExact
case n: Json.Num =>
try n.value.byteValueExact
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.byte(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.byte(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}

Expand All @@ -338,13 +338,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Short =
json match {
case Json.Num(value) =>
try value.shortValueExact
case n: Json.Num =>
try n.value.shortValueExact
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.short(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.short(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}

Expand All @@ -362,13 +362,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Int =
json match {
case Json.Num(value) =>
try value.intValueExact
case n: Json.Num =>
try n.value.intValueExact
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.int(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.int(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val long: JsonDecoder[Long] = new JsonDecoder[Long] {
Expand All @@ -385,13 +385,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Long =
json match {
case Json.Num(value) =>
try value.longValueExact
case n: Json.Num =>
try n.value.longValueExact
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.long(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.long(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}

Expand All @@ -409,13 +409,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): java.math.BigInteger =
json match {
case Json.Num(value) =>
try value.toBigIntegerExact
case n: Json.Num =>
try n.value.toBigIntegerExact
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.bigInteger(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.bigInteger(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val scalaBigInt: JsonDecoder[BigInt] = new JsonDecoder[BigInt] {
Expand All @@ -432,13 +432,13 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): BigInt =
json match {
case Json.Num(value) =>
try BigInt(value.toBigIntegerExact)
case n: Json.Num =>
try BigInt(n.value.toBigIntegerExact)
catch {
case ex: ArithmeticException => Lexer.error(ex.getMessage, trace)
}
case Json.Str(value) => Lexer.bigInt(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case s: Json.Str => Lexer.bigInt(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val float: JsonDecoder[Float] = new JsonDecoder[Float] {
Expand All @@ -455,9 +455,9 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Float =
json match {
case Json.Num(value) => value.floatValue
case Json.Str(value) => Lexer.float(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case n: Json.Num => n.value.floatValue
case s: Json.Str => Lexer.float(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val double: JsonDecoder[Double] = new JsonDecoder[Double] {
Expand All @@ -474,9 +474,9 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Double =
json match {
case Json.Num(value) => value.doubleValue
case Json.Str(value) => Lexer.double(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case n: Json.Num => n.value.doubleValue
case s: Json.Str => Lexer.double(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val bigDecimal: JsonDecoder[java.math.BigDecimal] = new JsonDecoder[java.math.BigDecimal] {
Expand All @@ -493,9 +493,9 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): java.math.BigDecimal =
json match {
case Json.Num(value) => value
case Json.Str(value) => Lexer.bigDecimal(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case n: Json.Num => n.value
case s: Json.Str => Lexer.bigDecimal(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
implicit val scalaBigDecimal: JsonDecoder[BigDecimal] = new JsonDecoder[BigDecimal] {
Expand All @@ -512,9 +512,9 @@ object JsonDecoder extends GeneratedTupleDecoders with DecoderLowPriority1 with

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): BigDecimal =
json match {
case Json.Num(value) => new BigDecimal(value, BigDecimal.defaultMathContext)
case Json.Str(value) => Lexer.bigDecimal(trace, new FastStringReader(value))
case _ => Lexer.error("expected number", trace)
case n: Json.Num => new BigDecimal(n.value, BigDecimal.defaultMathContext)
case s: Json.Str => Lexer.bigDecimal(trace, new FastStringReader(s.value))
case _ => Lexer.error("expected number", trace)
}
}
// Option treats empty and null values as Nothing and passes values to the decoder.
Expand Down Expand Up @@ -692,8 +692,8 @@ private[json] trait DecoderLowPriority1 extends DecoderLowPriority2 {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): Chunk[A] =
json match {
case Json.Arr(elements) =>
elements.map {
case a: Json.Arr =>
a.elements.map {
var i = 0
json =>
val span = new JsonError.ArrayAccess(i)
Expand Down Expand Up @@ -884,8 +884,8 @@ private[json] trait DecoderLowPriority3 extends DecoderLowPriority4 {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Str(value) => parseJavaTime(trace, value)
case _ => Lexer.error("expected string", trace)
case s: Json.Str => parseJavaTime(trace, s.value)
case _ => Lexer.error("expected string", trace)
}

// Commonized handling for decoding from string to java.time Class
Expand Down Expand Up @@ -915,8 +915,8 @@ private[json] trait DecoderLowPriority3 extends DecoderLowPriority4 {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): UUID =
json match {
case Json.Str(value) => parseUUID(trace, value)
case _ => Lexer.error("expected string", trace)
case s: Json.Str => parseUUID(trace, s.value)
case _ => Lexer.error("expected string", trace)
}

@inline private[this] def parseUUID(trace: List[JsonError], s: String): UUID =
Expand All @@ -932,8 +932,8 @@ private[json] trait DecoderLowPriority3 extends DecoderLowPriority4 {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): java.util.Currency =
json match {
case Json.Str(value) => parseCurrency(trace, value)
case _ => Lexer.error("expected string", trace)
case s: Json.Str => parseCurrency(trace, s.value)
case _ => Lexer.error("expected string", trace)
}

@inline private[this] def parseCurrency(trace: List[JsonError], s: String): java.util.Currency =
Expand Down
Loading

0 comments on commit a9875d6

Please sign in to comment.