Skip to content

Commit e76b51f

Browse files
authored
Fix RowCoderGenerator to use the encodingPositions when encoding and decoding the bit set representing null fields. (#32389)
1 parent e8c6a8c commit e76b51f

File tree

4 files changed

+256
-24
lines changed

4 files changed

+256
-24
lines changed

sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.beam.sdk.transforms.SerializableFunctions;
2626
import org.apache.beam.sdk.values.Row;
2727
import org.apache.beam.sdk.values.TypeDescriptors;
28+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
2829
import org.checkerframework.checker.nullness.qual.Nullable;
2930

3031
/** A sub-class of SchemaCoder that can only encode {@link Row} instances. */
@@ -35,7 +36,12 @@ public static RowCoder of(Schema schema) {
3536

3637
/** Override encoding positions for the given schema. */
3738
public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
38-
SchemaCoder.overrideEncodingPositions(uuid, encodingPositions);
39+
RowCoderGenerator.overrideEncodingPositions(uuid, encodingPositions);
40+
}
41+
42+
@VisibleForTesting
43+
static void clearGeneratedRowCoders() {
44+
RowCoderGenerator.clearRowCoderCache();
3945
}
4046

4147
private RowCoder(Schema schema) {

sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java

Lines changed: 113 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Map;
3131
import java.util.UUID;
3232
import javax.annotation.Nullable;
33+
import javax.annotation.concurrent.GuardedBy;
3334
import net.bytebuddy.ByteBuddy;
3435
import net.bytebuddy.description.modifier.FieldManifestation;
3536
import net.bytebuddy.description.modifier.Ownership;
@@ -53,10 +54,14 @@
5354
import org.apache.beam.sdk.schemas.Schema.Field;
5455
import org.apache.beam.sdk.schemas.Schema.FieldType;
5556
import org.apache.beam.sdk.schemas.SchemaCoder;
57+
import org.apache.beam.sdk.util.StringUtils;
5658
import org.apache.beam.sdk.util.common.ReflectHelpers;
5759
import org.apache.beam.sdk.values.Row;
60+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
5861
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
5962
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
63+
import org.slf4j.Logger;
64+
import org.slf4j.LoggerFactory;
6065

6166
/**
6267
* A utility for automatically generating a {@link Coder} for {@link Row} objects corresponding to a
@@ -109,30 +114,113 @@ public abstract class RowCoderGenerator {
109114
private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
110115
private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS";
111116

117+
static class WithStackTrace<T> {
118+
private final T value;
119+
private final String stackTrace;
120+
121+
public WithStackTrace(T value, String stackTrace) {
122+
this.value = value;
123+
this.stackTrace = stackTrace;
124+
}
125+
126+
public T getValue() {
127+
return value;
128+
}
129+
130+
public String getStackTrace() {
131+
return stackTrace;
132+
}
133+
}
134+
112135
// Cache for Coder class that are already generated.
113-
private static final Map<UUID, Coder<Row>> GENERATED_CODERS = Maps.newConcurrentMap();
114-
private static final Map<UUID, Map<String, Integer>> ENCODING_POSITION_OVERRIDES =
115-
Maps.newConcurrentMap();
136+
@GuardedBy("cacheLock")
137+
private static final Map<UUID, WithStackTrace<Coder<Row>>> GENERATED_CODERS = Maps.newHashMap();
138+
139+
@GuardedBy("cacheLock")
140+
private static final Map<UUID, WithStackTrace<Map<String, Integer>>> ENCODING_POSITION_OVERRIDES =
141+
Maps.newHashMap();
142+
143+
private static final Object cacheLock = new Object();
144+
145+
private static final Logger LOG = LoggerFactory.getLogger(RowCoderGenerator.class);
146+
147+
private static String getStackTrace() {
148+
return StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10);
149+
}
116150

117151
public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
118-
ENCODING_POSITION_OVERRIDES.put(uuid, encodingPositions);
152+
final String stackTrace = getStackTrace();
153+
synchronized (cacheLock) {
154+
@Nullable
155+
WithStackTrace<Map<String, Integer>> previousEncodingPositions =
156+
ENCODING_POSITION_OVERRIDES.put(
157+
uuid, new WithStackTrace<>(encodingPositions, stackTrace));
158+
@Nullable WithStackTrace<Coder<Row>> existingCoder = GENERATED_CODERS.get(uuid);
159+
if (previousEncodingPositions == null) {
160+
if (existingCoder != null) {
161+
LOG.error(
162+
"Received encoding positions for uuid {} too late after creating RowCoder. Created: {}\n Override: {}",
163+
uuid,
164+
existingCoder.getStackTrace(),
165+
stackTrace);
166+
} else {
167+
LOG.info("Received encoding positions {} for uuid {}.", encodingPositions, uuid);
168+
}
169+
} else if (!previousEncodingPositions.getValue().equals(encodingPositions)) {
170+
if (existingCoder == null) {
171+
LOG.error(
172+
"Received differing encoding positions for uuid {} before coder creation. Was {} at {}\n Now {} at {}",
173+
uuid,
174+
previousEncodingPositions.getValue(),
175+
encodingPositions,
176+
previousEncodingPositions.getStackTrace(),
177+
stackTrace);
178+
} else {
179+
LOG.error(
180+
"Received differing encoding positions for uuid {} after coder creation at {}\n. "
181+
+ "Was {} at {}\n Now {} at {}\n",
182+
uuid,
183+
existingCoder.getStackTrace(),
184+
previousEncodingPositions.getValue(),
185+
encodingPositions,
186+
previousEncodingPositions.getStackTrace(),
187+
stackTrace);
188+
}
189+
}
190+
}
191+
}
192+
193+
@VisibleForTesting
194+
static void clearRowCoderCache() {
195+
synchronized (cacheLock) {
196+
GENERATED_CODERS.clear();
197+
}
119198
}
120199

121200
@SuppressWarnings("unchecked")
122201
public static Coder<Row> generate(Schema schema) {
123-
// Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested
124-
// coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11.
125-
Coder<Row> rowCoder = GENERATED_CODERS.get(schema.getUUID());
126-
if (rowCoder == null) {
202+
String stackTrace = getStackTrace();
203+
UUID uuid = Preconditions.checkNotNull(schema.getUUID());
204+
// Avoid using computeIfAbsent which may cause issues with nested schemas.
205+
synchronized (cacheLock) {
206+
@Nullable WithStackTrace<Coder<Row>> existingRowCoder = GENERATED_CODERS.get(uuid);
207+
if (existingRowCoder != null) {
208+
return existingRowCoder.getValue();
209+
}
127210
TypeDescription.Generic coderType =
128211
TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build();
129212
DynamicType.Builder<Coder> builder =
130213
(DynamicType.Builder<Coder>) BYTE_BUDDY.subclass(coderType);
131214
builder = implementMethods(schema, builder);
132215

133216
int[] encodingPosToRowIndex = new int[schema.getFieldCount()];
217+
@Nullable
218+
WithStackTrace<Map<String, Integer>> existingEncodingPositions =
219+
ENCODING_POSITION_OVERRIDES.get(uuid);
134220
Map<String, Integer> encodingPositions =
135-
ENCODING_POSITION_OVERRIDES.getOrDefault(schema.getUUID(), schema.getEncodingPositions());
221+
existingEncodingPositions == null
222+
? schema.getEncodingPositions()
223+
: existingEncodingPositions.getValue();
136224
for (int recordIndex = 0; recordIndex < schema.getFieldCount(); ++recordIndex) {
137225
String name = schema.getField(recordIndex).getName();
138226
int encodingPosition = encodingPositions.get(name);
@@ -163,6 +251,7 @@ public static Coder<Row> generate(Schema schema) {
163251
.withParameters(Coder[].class, int[].class)
164252
.intercept(new GeneratedCoderConstructor());
165253

254+
Coder<Row> rowCoder;
166255
try {
167256
rowCoder =
168257
builder
@@ -179,9 +268,14 @@ public static Coder<Row> generate(Schema schema) {
179268
| InvocationTargetException e) {
180269
throw new RuntimeException("Unable to generate coder for schema " + schema, e);
181270
}
182-
GENERATED_CODERS.put(schema.getUUID(), rowCoder);
271+
GENERATED_CODERS.put(uuid, new WithStackTrace<>(rowCoder, stackTrace));
272+
LOG.debug(
273+
"Created row coder for uuid {} with encoding positions {} at {}",
274+
uuid,
275+
encodingPositions,
276+
stackTrace);
277+
return rowCoder;
183278
}
184-
return rowCoder;
185279
}
186280

187281
private static class GeneratedCoderConstructor implements Implementation {
@@ -326,7 +420,7 @@ static void encodeDelegate(
326420
}
327421

328422
// Encode a bitmap for the null fields to save having to encode a bunch of nulls.
329-
NULL_LIST_CODER.encode(scanNullFields(fieldValues), outputStream);
423+
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
330424
for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) {
331425
@Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]];
332426
if (fieldValue != null) {
@@ -348,14 +442,15 @@ static void encodeDelegate(
348442

349443
// Figure out which fields of the Row are null, and returns a BitSet. This allows us to save
350444
// on encoding each null field separately.
351-
private static BitSet scanNullFields(Object[] fieldValues) {
445+
private static BitSet scanNullFields(Object[] fieldValues, int[] encodingPosToIndex) {
446+
Preconditions.checkState(fieldValues.length == encodingPosToIndex.length);
352447
BitSet nullFields = new BitSet(fieldValues.length);
353-
for (int idx = 0; idx < fieldValues.length; ++idx) {
354-
if (fieldValues[idx] == null) {
355-
nullFields.set(idx);
448+
for (int encodingPos = 0; encodingPos < encodingPosToIndex.length; ++encodingPos) {
449+
int fieldIndex = encodingPosToIndex[encodingPos];
450+
if (fieldValues[fieldIndex] == null) {
451+
nullFields.set(encodingPos);
356452
}
357453
}
358-
359454
return nullFields;
360455
}
361456
}
@@ -425,7 +520,7 @@ static Row decodeDelegate(
425520
// in which case we drop the extra fields.
426521
if (encodingPos < coders.length) {
427522
int rowIndex = encodingPosToIndex[encodingPos];
428-
if (nullFields.get(rowIndex)) {
523+
if (nullFields.get(encodingPos)) {
429524
fieldValues[rowIndex] = null;
430525
} else {
431526
Object fieldValue = coders[encodingPos].decode(inputStream);

sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ public String toString() {
164164
}
165165

166166
// Sets the schema id, and then recursively ensures that all schemas have ids set.
167-
private static void setSchemaIds(Schema schema) {
167+
private static void setSchemaIds(@Nullable Schema schema) {
168+
if (schema == null) {
169+
return;
170+
}
168171
if (schema.getUUID() == null) {
169172
schema.setUUID(UUID.randomUUID());
170173
}
@@ -187,7 +190,7 @@ private static void setSchemaIds(FieldType fieldType) {
187190
return;
188191

189192
case ARRAY:
190-
case ITERABLE:;
193+
case ITERABLE:
191194
setSchemaIds(fieldType.getCollectionElementType());
192195
return;
193196

0 commit comments

Comments
 (0)