diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java index 9121b60666aae..8fa46dbbd2594 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; /** A sub-class of SchemaCoder that can only encode {@link Row} instances. */ @@ -35,7 +36,12 @@ public static RowCoder of(Schema schema) { /** Override encoding positions for the given schema. */ public static void overrideEncodingPositions(UUID uuid, Map encodingPositions) { - SchemaCoder.overrideEncodingPositions(uuid, encodingPositions); + RowCoderGenerator.overrideEncodingPositions(uuid, encodingPositions); + } + + @VisibleForTesting + static void clearGeneratedRowCoders() { + RowCoderGenerator.clearRowCoderCache(); } private RowCoder(Schema schema) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java index e3bd218945bfc..7a1b16d7e91fd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.UUID; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.modifier.FieldManifestation; import net.bytebuddy.description.modifier.Ownership; @@ -53,10 +54,14 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.util.StringUtils; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A utility for automatically generating a {@link Coder} for {@link Row} objects corresponding to a @@ -109,21 +114,99 @@ public abstract class RowCoderGenerator { private static final String CODERS_FIELD_NAME = "FIELD_CODERS"; private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS"; + static class WithStackTrace { + private final T value; + private final String stackTrace; + + public WithStackTrace(T value, String stackTrace) { + this.value = value; + this.stackTrace = stackTrace; + } + + public T getValue() { + return value; + } + + public String getStackTrace() { + return stackTrace; + } + } + // Cache for Coder class that are already generated. - private static final Map> GENERATED_CODERS = Maps.newConcurrentMap(); - private static final Map> ENCODING_POSITION_OVERRIDES = - Maps.newConcurrentMap(); + @GuardedBy("cacheLock") + private static final Map>> GENERATED_CODERS = Maps.newHashMap(); + + @GuardedBy("cacheLock") + private static final Map>> ENCODING_POSITION_OVERRIDES = + Maps.newHashMap(); + + private static final Object cacheLock = new Object(); + + private static final Logger LOG = LoggerFactory.getLogger(RowCoderGenerator.class); + + private static String getStackTrace() { + return StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10); + } public static void overrideEncodingPositions(UUID uuid, Map encodingPositions) { - ENCODING_POSITION_OVERRIDES.put(uuid, encodingPositions); + final String stackTrace = getStackTrace(); + synchronized (cacheLock) { + @Nullable + WithStackTrace> previousEncodingPositions = + ENCODING_POSITION_OVERRIDES.put( + uuid, new WithStackTrace<>(encodingPositions, stackTrace)); + @Nullable WithStackTrace> existingCoder = GENERATED_CODERS.get(uuid); + if (previousEncodingPositions == null) { + if (existingCoder != null) { + LOG.error( + "Received encoding positions for uuid {} too late after creating RowCoder. Created: {}\n Override: {}", + uuid, + existingCoder.getStackTrace(), + stackTrace); + } else { + LOG.info("Received encoding positions {} for uuid {}.", encodingPositions, uuid); + } + } else if (!previousEncodingPositions.getValue().equals(encodingPositions)) { + if (existingCoder == null) { + LOG.error( + "Received differing encoding positions for uuid {} before coder creation. Was {} at {}\n Now {} at {}", + uuid, + previousEncodingPositions.getValue(), + encodingPositions, + previousEncodingPositions.getStackTrace(), + stackTrace); + } else { + LOG.error( + "Received differing encoding positions for uuid {} after coder creation at {}\n. " + + "Was {} at {}\n Now {} at {}\n", + uuid, + existingCoder.getStackTrace(), + previousEncodingPositions.getValue(), + encodingPositions, + previousEncodingPositions.getStackTrace(), + stackTrace); + } + } + } + } + + @VisibleForTesting + static void clearRowCoderCache() { + synchronized (cacheLock) { + GENERATED_CODERS.clear(); + } } @SuppressWarnings("unchecked") public static Coder generate(Schema schema) { - // Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested - // coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11. - Coder rowCoder = GENERATED_CODERS.get(schema.getUUID()); - if (rowCoder == null) { + String stackTrace = getStackTrace(); + UUID uuid = Preconditions.checkNotNull(schema.getUUID()); + // Avoid using computeIfAbsent which may cause issues with nested schemas. + synchronized (cacheLock) { + @Nullable WithStackTrace> existingRowCoder = GENERATED_CODERS.get(uuid); + if (existingRowCoder != null) { + return existingRowCoder.getValue(); + } TypeDescription.Generic coderType = TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build(); DynamicType.Builder builder = @@ -131,8 +214,13 @@ public static Coder generate(Schema schema) { builder = implementMethods(schema, builder); int[] encodingPosToRowIndex = new int[schema.getFieldCount()]; + @Nullable + WithStackTrace> existingEncodingPositions = + ENCODING_POSITION_OVERRIDES.get(uuid); Map encodingPositions = - ENCODING_POSITION_OVERRIDES.getOrDefault(schema.getUUID(), schema.getEncodingPositions()); + existingEncodingPositions == null + ? schema.getEncodingPositions() + : existingEncodingPositions.getValue(); for (int recordIndex = 0; recordIndex < schema.getFieldCount(); ++recordIndex) { String name = schema.getField(recordIndex).getName(); int encodingPosition = encodingPositions.get(name); @@ -163,6 +251,7 @@ public static Coder generate(Schema schema) { .withParameters(Coder[].class, int[].class) .intercept(new GeneratedCoderConstructor()); + Coder rowCoder; try { rowCoder = builder @@ -179,9 +268,14 @@ public static Coder generate(Schema schema) { | InvocationTargetException e) { throw new RuntimeException("Unable to generate coder for schema " + schema, e); } - GENERATED_CODERS.put(schema.getUUID(), rowCoder); + GENERATED_CODERS.put(uuid, new WithStackTrace<>(rowCoder, stackTrace)); + LOG.debug( + "Created row coder for uuid {} with encoding positions {} at {}", + uuid, + encodingPositions, + stackTrace); + return rowCoder; } - return rowCoder; } private static class GeneratedCoderConstructor implements Implementation { @@ -326,7 +420,7 @@ static void encodeDelegate( } // Encode a bitmap for the null fields to save having to encode a bunch of nulls. - NULL_LIST_CODER.encode(scanNullFields(fieldValues), outputStream); + NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream); for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) { @Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]]; if (fieldValue != null) { @@ -348,14 +442,15 @@ static void encodeDelegate( // Figure out which fields of the Row are null, and returns a BitSet. This allows us to save // on encoding each null field separately. - private static BitSet scanNullFields(Object[] fieldValues) { + private static BitSet scanNullFields(Object[] fieldValues, int[] encodingPosToIndex) { + Preconditions.checkState(fieldValues.length == encodingPosToIndex.length); BitSet nullFields = new BitSet(fieldValues.length); - for (int idx = 0; idx < fieldValues.length; ++idx) { - if (fieldValues[idx] == null) { - nullFields.set(idx); + for (int encodingPos = 0; encodingPos < encodingPosToIndex.length; ++encodingPos) { + int fieldIndex = encodingPosToIndex[encodingPos]; + if (fieldValues[fieldIndex] == null) { + nullFields.set(encodingPos); } } - return nullFields; } } @@ -425,7 +520,7 @@ static Row decodeDelegate( // in which case we drop the extra fields. if (encodingPos < coders.length) { int rowIndex = encodingPosToIndex[encodingPos]; - if (nullFields.get(rowIndex)) { + if (nullFields.get(encodingPos)) { fieldValues[rowIndex] = null; } else { Object fieldValue = coders[encodingPos].decode(inputStream); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java index 323f4e98dc551..b93b64f7dbe87 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java @@ -164,7 +164,10 @@ public String toString() { } // Sets the schema id, and then recursively ensures that all schemas have ids set. - private static void setSchemaIds(Schema schema) { + private static void setSchemaIds(@Nullable Schema schema) { + if (schema == null) { + return; + } if (schema.getUUID() == null) { schema.setUUID(UUID.randomUUID()); } @@ -187,7 +190,7 @@ private static void setSchemaIds(FieldType fieldType) { return; case ARRAY: - case ITERABLE:; + case ITERABLE: setSchemaIds(fieldType.getCollectionElementType()); return; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java index f62a2611a1cf3..885ff8f1491ab 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java @@ -22,10 +22,12 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.UUID; import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -37,6 +39,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.junit.Assume; @@ -62,7 +65,7 @@ public void testPrimitiveTypes() throws Exception { .build(); DateTime dateTime = - new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); + new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); Row row = Row.withSchema(schema) .addValues( @@ -219,12 +222,14 @@ public FieldType getBaseType() { } @Override - public Value toBaseType(String input) { + @NonNull + public Value toBaseType(@NonNull String input) { return enumeration.valueOf(input); } @Override - public String toInputType(Value base) { + @NonNull + public String toInputType(@NonNull Value base) { return enumeration.toString(base); } } @@ -401,6 +406,129 @@ public void testEncodingPositionReorderFields() throws Exception { assertEquals(expected, decoded); } + @Test + public void testEncodingPositionReorderFieldsWithNulls() throws Exception { + Schema schema1 = + Schema.builder() + .addNullableField("f_int32", FieldType.INT32) + .addNullableField("f_string", FieldType.STRING) + .build(); + Schema schema2 = + Schema.builder() + .addNullableField("f_string", FieldType.STRING) + .addNullableField("f_int32", FieldType.INT32) + .build(); + schema2.setEncodingPositions(ImmutableMap.of("f_int32", 0, "f_string", 1)); + Row schema1row = + Row.withSchema(schema1) + .withFieldValue("f_int32", null) + .withFieldValue("f_string", "hello world!") + .build(); + + Row schema2row = + Row.withSchema(schema2) + .withFieldValue("f_int32", null) + .withFieldValue("f_string", "hello world!") + .build(); + + ByteArrayOutputStream os = new ByteArrayOutputStream(); + RowCoder.of(schema1).encode(schema1row, os); + Row schema1to2decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray())); + assertEquals(schema2row, schema1to2decoded); + + os.reset(); + RowCoder.of(schema2).encode(schema2row, os); + Row schema2to1decoded = RowCoder.of(schema1).decode(new ByteArrayInputStream(os.toByteArray())); + assertEquals(schema1row, schema2to1decoded); + } + + @Test + public void testEncodingPositionReorderViaStaticOverride() throws Exception { + Schema schema1 = + Schema.builder() + .addNullableField("failsafeTableRowPayload", FieldType.STRING) + .addByteArrayField("payload") + .addNullableField("timestamp", FieldType.INT32) + .addNullableField("unknownFieldsPayload", FieldType.STRING) + .build(); + UUID uuid = UUID.randomUUID(); + schema1.setUUID(uuid); + + Row row = + Row.withSchema(schema1) + .addValues("", "hello world!".getBytes(StandardCharsets.UTF_8), 1, "") + .build(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + RowCoder.of(schema1).encode(row, os); + // Pretend that we are restarting and want to recover from persisted state with a compatible + // schema using the + // overridden encoding positions. + RowCoder.clearGeneratedRowCoders(); + RowCoder.overrideEncodingPositions( + uuid, + ImmutableMap.of( + "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2, "unknownFieldsPayload", 3)); + + Schema schema2 = + Schema.builder() + .addByteArrayField("payload") + .addNullableField("timestamp", FieldType.INT32) + .addNullableField("unknownFieldsPayload", FieldType.STRING) + .addNullableField("failsafeTableRowPayload", FieldType.STRING) + .build(); + schema2.setUUID(uuid); + + Row expected = + Row.withSchema(schema2) + .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1, "", "") + .build(); + Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray())); + assertEquals(expected, decoded); + } + + @Test + public void testEncodingPositionReorderViaStaticOverrideWithNulls() throws Exception { + Schema schema1 = + Schema.builder() + .addNullableField("failsafeTableRowPayload", FieldType.BYTES) + .addByteArrayField("payload") + .addNullableField("timestamp", FieldType.INT32) + .addNullableField("unknownFieldsPayload", FieldType.BYTES) + .build(); + UUID uuid = UUID.randomUUID(); + schema1.setUUID(uuid); + + Row row = + Row.withSchema(schema1) + .addValues(null, "hello world!".getBytes(StandardCharsets.UTF_8), 1, null) + .build(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + RowCoder.of(schema1).encode(row, os); + // Pretend that we are restarting and want to recover from persisted state with a compatible + // schema using the overridden encoding positions. + RowCoder.clearGeneratedRowCoders(); + RowCoder.overrideEncodingPositions( + uuid, + ImmutableMap.of( + "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2, "unknownFieldsPayload", 3)); + + Schema schema2 = + Schema.builder() + .addByteArrayField("payload") + .addNullableField("timestamp", FieldType.INT32) + .addNullableField("unknownFieldsPayload", FieldType.BYTES) + .addNullableField("failsafeTableRowPayload", FieldType.BYTES) + .build(); + schema2.setUUID(uuid); + + Row expected = + Row.withSchema(schema2) + .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1, null, null) + .build(); + Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray())); + assertEquals(expected, decoded); + } + @Test public void testEncodingPositionAddNewFields() throws Exception { Schema schema1 =