Skip to content

Commit

Permalink
Robust throwable kryo coder (#5318)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Apr 3, 2024
1 parent 3a3df03 commit b8d1301
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 170 deletions.
12 changes: 10 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ ThisBuild / githubWorkflowAddedJobs ++= Seq(
ThisBuild / mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.testing.TransformOverride.ofSource"
),
// removal of private classes
ProblemFilters.exclude[MissingClassProblem](
"com.spotify.scio.coders.instances.kryo.GaxApiExceptionSerializer"
),
ProblemFilters.exclude[MissingClassProblem](
"com.spotify.scio.coders.instances.kryo.StatusRuntimeExceptionSerializer"
),
ProblemFilters.exclude[MissingClassProblem](
"com.spotify.scio.coders.instances.kryo.BigtableRetriesExhaustedExceptionSerializer"
)
)

Expand Down Expand Up @@ -624,8 +634,6 @@ lazy val `scio-core` = project
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
"com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonVersion,
"com.google.api" % "gax" % gaxVersion,
"com.google.api" % "gax-grpc" % gaxVersion,
"com.google.api" % "gax-httpjson" % gaxVersion,
"com.google.api-client" % "google-api-client" % googleApiClientVersion,
"com.google.auto.service" % "auto-service-annotations" % autoServiceVersion,
"com.google.auto.service" % "auto-service" % autoServiceVersion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

package com.spotify.scio.coders

import java.io.{EOFException, InputStream, OutputStream}
import java.nio.file.Path
import java.util.concurrent.atomic.AtomicInteger
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.{InputChunked, OutputChunked}
import com.esotericsoftware.kryo.serializers.JavaSerializer
import com.google.protobuf.{ByteString, Message}
import com.spotify.scio.coders.instances.kryo._
import com.spotify.scio.coders.instances.JavaCollectionWrappers
import com.spotify.scio.coders.instances.kryo._
import com.spotify.scio.options.ScioOptions

import java.io.{EOFException, InputStream, OutputStream}
import java.nio.file.Path
import java.util.concurrent.atomic.AtomicInteger
import com.twitter.chill._
import com.twitter.chill.algebird.AlgebirdRegistrar
import com.twitter.chill.protobuf.ProtobufSerializer
Expand All @@ -43,8 +43,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.reflect.ClassP
import org.joda.time.{DateTime, LocalDate, LocalDateTime, LocalTime}
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters._
import scala.collection.mutable
import scala.jdk.CollectionConverters._

private object KryoRegistrarLoader {
private[this] val logger = LoggerFactory.getLogger(this.getClass)
Expand Down Expand Up @@ -121,9 +121,7 @@ final private class ScioKryoRegistrar extends IKryoRegistrar {
k.forSubclass[ByteString](new ByteStringSerializer)
k.forClass(new KVSerializer)
k.forClass[io.grpc.Status](new StatusSerializer)
k.forSubclass[io.grpc.StatusRuntimeException](new StatusRuntimeExceptionSerializer)
k.forSubclass[com.google.api.gax.rpc.ApiException](new GaxApiExceptionSerializer)
k.addDefaultSerializer(classOf[Throwable], new JavaSerializer)
k.addDefaultSerializer(classOf[Throwable], new ThrowableSerializer)
()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ package com.spotify.scio.coders.instances.kryo

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.google.api.gax.grpc.GrpcStatusCode
import com.google.api.gax.httpjson.HttpJsonStatusCode
import com.google.api.gax.rpc.{ApiException, ApiExceptionFactory}
import com.twitter.chill.KSerializer
import io.grpc.{Metadata, Status, StatusRuntimeException}
import io.grpc.Status

private[coders] class StatusSerializer extends KSerializer[Status] {
override def write(kryo: Kryo, output: Output, status: Status): Unit = {
Expand All @@ -43,71 +40,3 @@ private[coders] class StatusSerializer extends KSerializer[Status] {
.withCause(cause)
}
}

private[coders] class StatusRuntimeExceptionSerializer extends KSerializer[StatusRuntimeException] {
private lazy val statusSer = new StatusSerializer()

override def write(kryo: Kryo, output: Output, e: StatusRuntimeException): Unit = {
kryo.writeObject(output, e.getStatus, statusSer)
kryo.writeObjectOrNull(output, e.getTrailers, classOf[Metadata])
}

override def read(
kryo: Kryo,
input: Input,
`type`: Class[StatusRuntimeException]
): StatusRuntimeException = {
val status = kryo.readObject(input, classOf[Status], statusSer)
val trailers = kryo.readObjectOrNull(input, classOf[Metadata])

new StatusRuntimeException(status, trailers)
}
}

private[coders] class GaxApiExceptionSerializer extends KSerializer[ApiException] {
private lazy val statusSer = new StatusSerializer()
override def write(kryo: Kryo, output: Output, e: ApiException): Unit = {
kryo.writeObject(output, e.getMessage)
kryo.writeClassAndObject(output, e.getCause)
e.getStatusCode match {
case grpc: GrpcStatusCode =>
kryo.writeClass(output, classOf[GrpcStatusCode])
kryo.writeObject(output, grpc.getTransportCode.toStatus, statusSer)
case http: HttpJsonStatusCode =>
kryo.writeClass(output, classOf[HttpJsonStatusCode])
kryo.writeObject(output, http.getTransportCode)
case statusCode =>
kryo.writeClass(output, statusCode.getClass)
}
kryo.writeObject(output, e.isRetryable)
// kryo.writeObjectOrNull(output, e.getErrorDetails, classOf[ErrorDetails])
}

override def read(
kryo: Kryo,
input: Input,
`type`: Class[ApiException]
): ApiException = {
val message = kryo.readObject(input, classOf[String])
val cause = kryo.readClassAndObject(input).asInstanceOf[Throwable]
val codeClass = kryo.readClass(input).getType
val code = if (codeClass == classOf[GrpcStatusCode]) {
val status = kryo.readObject(input, classOf[Status], statusSer)
GrpcStatusCode.of(status.getCode)
} else if (codeClass == classOf[HttpJsonStatusCode]) {
val status = kryo.readObject(input, classOf[Integer])
HttpJsonStatusCode.of(status)
} else {
null
}
val retryable = kryo.readObjectOrNull(input, classOf[Boolean])
// val errorDetails = kryo.readObjectOrNull(input, classOf[ErrorDetails])

ApiExceptionFactory.createException(
message,
cause,
code,
retryable
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2024 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.spotify.scio.coders.instances.kryo

import com.esotericsoftware.kryo.KryoException
import com.google.api.gax.rpc.{ApiException, ApiExceptionFactory, StatusCode}
import com.twitter.chill.{Input, KSerializer, Kryo, Output}
import io.grpc.{Status, StatusException, StatusRuntimeException}

import java.io.{InputStream, ObjectInputStream, ObjectOutputStream, OutputStream}

/**
* Java based serialization for `Throwable`. This uses replace/resolve for throwable that do not
* respect the Serializable interface:
* - io.grpc.StatusException
* - io.grpc.StatusRuntimeException
* - com.google.api.gax.rpc.ApiException
*/
private object ThrowableSerializer {
final private case class SerializableStatusException(
code: Status.Code,
desc: String,
cause: Throwable
)
final private case class SerializableStatusRuntimeException(
code: Status.Code,
desc: String,
cause: Throwable
)
final private case class SerializableApiException(
message: String,
code: StatusCode.Code,
retryable: Boolean,
cause: Throwable
)

final class ThrowableObjectOutputStream(out: OutputStream) extends ObjectOutputStream(out) {
enableReplaceObject(true)
override def replaceObject(obj: AnyRef): AnyRef = obj match {
case e: StatusException =>
SerializableStatusException(
e.getStatus.getCode,
e.getStatus.getDescription,
e.getStatus.getCause
)
case e: StatusRuntimeException =>
SerializableStatusRuntimeException(
e.getStatus.getCode,
e.getStatus.getDescription,
e.getStatus.getCause
)
case e: ApiException =>
SerializableApiException(e.getMessage, e.getStatusCode.getCode, e.isRetryable, e.getCause)
case _ => obj
}
}

final class ThrowableObjectInputStream(in: InputStream) extends ObjectInputStream(in) {
enableResolveObject(true)
override def resolveObject(obj: AnyRef): AnyRef = obj match {
case SerializableStatusException(code, desc, cause) =>
new StatusException(Status.fromCode(code).withDescription(desc).withCause(cause))
case SerializableStatusRuntimeException(code, desc, cause) =>
new StatusRuntimeException(Status.fromCode(code).withDescription(desc).withCause(cause))
case SerializableApiException(message, code, retryable, cause) =>
// generic status code. we lost transport information during serialization
val c = new StatusCode() {
override def getCode: StatusCode.Code = code
override def getTransportCode: AnyRef = null
}
ApiExceptionFactory.createException(message, cause, c, retryable)
case _ => obj
}
}
}

final private[coders] class ThrowableSerializer extends KSerializer[Throwable] {
import ThrowableSerializer._
override def write(kryo: Kryo, out: Output, obj: Throwable): Unit = {
try {
val objectStream = new ThrowableObjectOutputStream(out)
objectStream.writeObject(obj)
objectStream.flush()
} catch {
case e: Exception => throw new KryoException("Error during Java serialization.", e)
}
}

override def read(kryo: Kryo, input: Input, `type`: Class[Throwable]): Throwable =
try {
val objectStream = new ThrowableObjectInputStream(input)
objectStream.readObject.asInstanceOf[Throwable]
} catch {
case e: Exception => throw new KryoException("Error during Java deserialization.", e)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,17 @@
package com.spotify.scio.coders

import com.google.cloud.bigtable.data.v2.models.MutateRowsException
import com.google.cloud.bigtable.grpc.scanner.BigtableRetriesExhaustedException
import com.spotify.scio.bigquery.TableRow
import com.spotify.scio.coders.instances.kryo.{
BigtableRetriesExhaustedExceptionSerializer,
CoderSerializer,
MutateRowsExceptionSerializer
}
import com.spotify.scio.coders.instances.kryo.{CoderSerializer, MutateRowsExceptionSerializer}
import com.twitter.chill._
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder

@KryoRegistrar
class GcpKryoRegistrar extends IKryoRegistrar {
override def apply(k: Kryo): Unit = {
k.forClass[TableRow](new CoderSerializer(TableRowJsonCoder.of()))
k.forClass[BigtableRetriesExhaustedException](new BigtableRetriesExhaustedExceptionSerializer)
// if MutateRowsException is used as cause in another throwable,
// it will be serialized as generic InternalException gax ApiException
k.forClass[MutateRowsException](new MutateRowsExceptionSerializer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,18 @@

package com.spotify.scio.coders.instances.kryo

import com.esotericsoftware.kryo.serializers.DefaultSerializers.StringSerializer
import com.google.api.gax.rpc.ApiException
import com.google.cloud.bigtable.data.v2.models.MutateRowsException
import com.google.cloud.bigtable.grpc.scanner.BigtableRetriesExhaustedException
import com.twitter.chill._

private[coders] class BigtableRetriesExhaustedExceptionSerializer
extends KSerializer[BigtableRetriesExhaustedException] {

private lazy val stringSerializer = new StringSerializer()
private lazy val statusExceptionSerializer = new StatusRuntimeExceptionSerializer()

override def write(kryo: Kryo, output: Output, e: BigtableRetriesExhaustedException): Unit = {
kryo.writeObject(output, e.getMessage, stringSerializer)
kryo.writeObject(output, e.getCause, statusExceptionSerializer)
}

override def read(
kryo: Kryo,
input: Input,
`type`: Class[BigtableRetriesExhaustedException]
): BigtableRetriesExhaustedException = {
val message = kryo.readObject(input, classOf[String], stringSerializer)
val cause = kryo.readObject(input, classOf[Throwable], statusExceptionSerializer)
new BigtableRetriesExhaustedException(message, cause)
}
}

private[coders] class MutateRowsExceptionSerializer extends KSerializer[MutateRowsException] {
private lazy val apiExceptionSer = new GaxApiExceptionSerializer()
override def write(kryo: Kryo, output: Output, e: MutateRowsException): Unit = {
kryo.writeClassAndObject(output, e.getCause)
val failedMutations = e.getFailedMutations
kryo.writeObject(output, failedMutations.size())
failedMutations.forEach { fm =>
kryo.writeObject(output, fm.getIndex)
kryo.writeObject(output, fm.getError, apiExceptionSer)
kryo.writeObject(output, fm.getError)
}
kryo.writeObject(output, e.isRetryable)
}
Expand All @@ -67,7 +42,7 @@ private[coders] class MutateRowsExceptionSerializer extends KSerializer[MutateRo
val failedMutations = new _root_.java.util.ArrayList[MutateRowsException.FailedMutation](size)
(0 until size).foreach { _ =>
val index = kryo.readObject(input, classOf[Integer])
val error = kryo.readObject(input, classOf[ApiException], apiExceptionSer)
val error = kryo.readObject(input, classOf[ApiException])
failedMutations.add(MutateRowsException.FailedMutation.create(index, error))
}
val retryable = kryo.readObject(input, classOf[Boolean])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,13 @@ object GcpSerializerTest {
implicit val eqBigtableRetriesExhaustedException: Equality[BigtableRetriesExhaustedException] = {
case (a: BigtableRetriesExhaustedException, b: BigtableRetriesExhaustedException) =>
a.getMessage == b.getMessage &&
((Option(a.getCause), Option(b.getCause)) match {
case (None, None) => true
case (Some(ac: StatusRuntimeException), Some(bc: StatusRuntimeException)) =>
eqStatusRuntimeException.areEqual(ac, bc)
case _ =>
false
})
eqCause.areEqual(a.getCause, b.getCause)
case _ => false
}

implicit val eqMutateRowsException: Equality[MutateRowsException] = {
case (a: MutateRowsException, b: MutateRowsException) =>
// a.getCause == b.getCause &&
eqCause.areEqual(a.getCause, b.getCause) &&
a.getStatusCode == b.getStatusCode &&
a.isRetryable == b.isRetryable &&
a.getFailedMutations.size() == b.getFailedMutations.size() &&
Expand Down Expand Up @@ -79,7 +73,8 @@ class GcpSerializerTest extends AnyFlatSpec with Matchers {
val cause = new StatusRuntimeException(Status.OK)
val apiException = new InternalException(cause, GrpcStatusCode.of(Code.OK), false)
val failedMutations = List(MutateRowsException.FailedMutation.create(1, apiException))
MutateRowsException.create(cause, failedMutations.asJava, false) coderShould roundtrip()
}
val mutateRowsException = MutateRowsException.create(cause, failedMutations.asJava, false)

mutateRowsException coderShould roundtrip()
}
}
Loading

0 comments on commit b8d1301

Please sign in to comment.