Skip to content

Commit 2ce1ff7

Browse files
l46kok authored and copybara-github committed
Handle unknown fields in messagelite
PiperOrigin-RevId: 752522622
1 parent a3f0064 commit 2ce1ff7

File tree

5 files changed

+277
-33
lines changed

5 files changed

+277
-33
lines changed

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

+51-7
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818

19+
import com.google.auto.value.AutoValue;
1920
import com.google.common.annotations.VisibleForTesting;
2021
import com.google.common.base.Defaults;
2122
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.ImmutableMap;
24+
import com.google.common.collect.Multimap;
25+
import com.google.common.collect.Multimaps;
2326
import com.google.common.primitives.UnsignedLong;
2427
import com.google.errorprone.annotations.Immutable;
2528
import com.google.protobuf.ByteString;
@@ -41,6 +44,7 @@
4144
import java.util.LinkedHashMap;
4245
import java.util.List;
4346
import java.util.Map;
47+
import java.util.TreeMap;
4448

4549
/**
4650
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -221,27 +225,33 @@ private ImmutableList<Object> readPackedRepeatedFields(
221225
private Map.Entry<Object, Object> readSingleMapEntry(
222226
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
223227
ImmutableMap<String, Object> singleMapEntry =
224-
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName());
228+
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName())
229+
.values();
225230
Object key = checkNotNull(singleMapEntry.get("key"));
226231
Object value = checkNotNull(singleMapEntry.get("value"));
227232

228233
return new AbstractMap.SimpleEntry<>(key, value);
229234
}
230235

231236
@VisibleForTesting
232-
ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
233-
throws IOException {
234-
// TODO: Handle unknown fields by collecting them into a separate map.
237+
MessageFields readAllFields(byte[] bytes, String protoTypeName) throws IOException {
235238
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
236239
CodedInputStream inputStream = CodedInputStream.newInstance(bytes);
237240

241+
Multimap<Integer, Object> unknownFields =
242+
Multimaps.newMultimap(new TreeMap<>(), ArrayList::new);
238243
ImmutableMap.Builder<String, Object> fieldValues = ImmutableMap.builder();
239244
Map<Integer, List<Object>> repeatedFieldValues = new LinkedHashMap<>();
240245
Map<Integer, Map<Object, Object>> mapFieldValues = new LinkedHashMap<>();
241246
for (int tag = inputStream.readTag(); tag != 0; tag = inputStream.readTag()) {
242247
int tagWireType = WireFormat.getTagWireType(tag);
243248
int fieldNumber = WireFormat.getTagFieldNumber(tag);
244-
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber);
249+
FieldLiteDescriptor fieldDescriptor =
250+
messageDescriptor.findByFieldNumber(fieldNumber).orElse(null);
251+
if (fieldDescriptor == null) {
252+
unknownFields.put(fieldNumber, readUnknownField(tagWireType, inputStream));
253+
continue;
254+
}
245255

246256
Object payload;
247257
switch (tagWireType) {
@@ -318,12 +328,32 @@ ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
318328

319329
// Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
320330
// we accept the last value encountered.
321-
return fieldValues.buildKeepingLast();
331+
return MessageFields.create(fieldValues.buildKeepingLast(), unknownFields);
322332
}
323333

324334
ImmutableMap<String, Object> readAllFields(MessageLite msg, String protoTypeName)
325335
throws IOException {
326-
return readAllFields(msg.toByteArray(), protoTypeName);
336+
return readAllFields(msg.toByteArray(), protoTypeName).values();
337+
}
338+
339+
/**
* Reads the payload of a field that has no descriptor in the message, using only the wire type
* taken from the field's tag to decide which read call to make.
*
* @param tagWireType wire type extracted from the field's tag
* @param inputStream stream the payload is read from
* @return the value produced by the matching {@code CodedInputStream} call: {@code readInt64} for
*     varint, {@code readFixed64} for fixed64, {@code readFixed32} for fixed32 and {@code
*     readBytes} for length-delimited payloads
* @throws IOException declared because the underlying stream reads may fail
* @throws UnsupportedOperationException if the wire type marks the start or end of a group
* @throws IllegalArgumentException if the wire type matches none of the handled constants
*/
private static Object readUnknownField(int tagWireType, CodedInputStream inputStream)
340+
throws IOException {
341+
switch (tagWireType) {
342+
case WireFormat.WIRETYPE_VARINT:
343+
return inputStream.readInt64();
344+
case WireFormat.WIRETYPE_FIXED64:
345+
return inputStream.readFixed64();
346+
case WireFormat.WIRETYPE_LENGTH_DELIMITED:
347+
// The payload is returned as raw bytes; nothing in this method decodes it further.
return inputStream.readBytes();
348+
case WireFormat.WIRETYPE_FIXED32:
349+
return inputStream.readFixed32();
350+
case WireFormat.WIRETYPE_START_GROUP:
351+
case WireFormat.WIRETYPE_END_GROUP:
352+
// TODO: Support groups
353+
throw new UnsupportedOperationException("Groups are not supported");
354+
default:
355+
throw new IllegalArgumentException("Unknown wire type: " + tagWireType);
356+
}
327357
}
328358

329359
@Override
@@ -342,6 +372,20 @@ public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg
342372
return super.fromWellKnownProtoToCelValue(msg, wellKnownProto);
343373
}
344374

375+
/**
* Pairs the field values decoded from a serialized message with the payloads collected for field
* numbers that had no descriptor.
*/
@AutoValue
376+
@SuppressWarnings("AutoValueImmutableFields") // Unknowns are inaccessible to users.
377+
abstract static class MessageFields {
378+
379+
/**
* Decoded field values. NOTE(review): keys appear to be field names (the tests look up
* "repeated_int64" and "map_bool_double") — confirm against the decoder.
*/
abstract ImmutableMap<String, Object> values();
380+
381+
/** Payloads of fields without a descriptor; the decoder inserts them keyed by field number. */
abstract Multimap<Integer, Object> unknowns();
382+
383+
/**
* Creates an instance from the decoded values and the collected unknowns. NOTE(review): the
* multimap is handed to the generated constructor as-is; whether it is copied is not visible
* here.
*/
static MessageFields create(
384+
ImmutableMap<String, Object> fieldValues, Multimap<Integer, Object> unknownFields) {
385+
return new AutoValue_ProtoLiteCelValueConverter_MessageFields(fieldValues, unknownFields);
386+
}
387+
}
388+
345389
private ProtoLiteCelValueConverter(CelLiteDescriptorPool celLiteDescriptorPool) {
346390
this.descriptorPool = celLiteDescriptorPool;
347391
}

common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java

+216
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
import static com.google.common.truth.Truth.assertThat;
1818
import static java.nio.charset.StandardCharsets.UTF_8;
1919

20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableListMultimap;
2022
import com.google.common.collect.ImmutableSet;
23+
import com.google.common.collect.Multimap;
2124
import com.google.protobuf.ByteString;
2225
import com.google.protobuf.Duration;
26+
import com.google.protobuf.ExtensionRegistryLite;
2327
import com.google.protobuf.FloatValue;
2428
import com.google.protobuf.Int32Value;
2529
import com.google.protobuf.Int64Value;
@@ -32,9 +36,11 @@
3236
import dev.cel.common.internal.CelLiteDescriptorPool;
3337
import dev.cel.common.internal.DefaultLiteDescriptorPool;
3438
import dev.cel.common.internal.WellKnownProto;
39+
import dev.cel.common.values.ProtoLiteCelValueConverter.MessageFields;
3540
import dev.cel.expr.conformance.proto3.TestAllTypes;
3641
import dev.cel.expr.conformance.proto3.TestAllTypesProto3CelDescriptor;
3742
import java.time.Instant;
43+
import java.util.LinkedHashMap;
3844
import org.junit.Test;
3945
import org.junit.runner.RunWith;
4046

@@ -107,4 +113,214 @@ public void fromProtoMessageToCelValue_withWellKnownProto_convertsToEquivalentCe
107113

108114
assertThat(convertedCelValue).isEqualTo(testCase.celValue);
109115
}
116+
117+
/**
* Test cases for repeated_int64: 1L,2L,3L
*
* <p>Every case carries the same three values under tag bytes that decode to field number 32; the
* cases differ only in whether the values are packed into one length-delimited record.
*/
118+
@SuppressWarnings("ImmutableEnumChecker") // Test only
119+
private enum RepeatedFieldBytesTestCase {
120+
// Tag 0x82 0x02 (field 32, length-delimited), length 3, then the varints 1, 2, 3.
PACKED(new byte[] {(byte) 0x82, 0x2, 0x3, 0x1, 0x2, 0x3}),
121+
// Tag 0x80 0x02 (field 32, varint) repeated in front of each of 1, 2 and 3.
NON_PACKED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x80, 0x2, 0x2, (byte) 0x80, 0x2, 0x3}),
122+
// 1L is not packed, but 2L and 3L are
123+
MIXED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x82, 0x2, 0x2, 0x2, 0x3});
124+
125+
// Serialized message bytes handed to readAllFields by the test.
private final byte[] bytes;
126+
127+
RepeatedFieldBytesTestCase(byte[] bytes) {
128+
this.bytes = bytes;
129+
}
130+
}
131+
132+
/**
* Verifies that packed, non-packed and mixed encodings of the same repeated field all decode to
* the single entry repeated_int64 = [1, 2, 3].
*/
@Test
133+
public void readAllFields_repeatedFields_packedBytesCombinations(
134+
@TestParameter RepeatedFieldBytesTestCase testCase) throws Exception {
135+
MessageFields fields =
136+
PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields(
137+
testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes");
138+
139+
// containsExactly also rules out any additional decoded field.
assertThat(fields.values()).containsExactly("repeated_int64", ImmutableList.of(1L, 2L, 3L));
140+
}
141+
142+
/**
143+
* Unknown test with the following hypothetical fields:
144+
*
145+
* <pre>{@code
146+
* message TestAllTypes {
147+
* int64 single_int64_unknown = 2500;
148+
* fixed32 single_fixed32_unknown = 2501;
149+
* fixed64 single_fixed64_unknown = 2502;
150+
* string single_string_unknown = 2503;
151+
* repeated int64 repeated_int64_unknown = 2504;
152+
* map<string, int64> map_string_int64_unknown = 2505;
153+
* }
154+
* }</pre>
*
* <p>Each case supplies the serialized bytes, the expected trimmed text output of the message
* parsed by the generated class, and the expected unknowns keyed by field number.
155+
*/
156+
@SuppressWarnings("ImmutableEnumChecker") // Test only
157+
private enum UnknownFieldsTestCase {
158+
// Tag -96, -100, 1 decodes to field 2500 with the varint wire type; payload 1.
INT64(new byte[] {-96, -100, 1, 1}, "2500: 1", ImmutableListMultimap.of(2500, 1L)),
159+
// Tag -83, -100, 1 decodes to field 2501 with the fixed32 wire type; little-endian payload 2.
FIXED32(
160+
new byte[] {-83, -100, 1, 2, 0, 0, 0},
161+
"2501: 0x00000002",
162+
ImmutableListMultimap.of(2501, 2)),
163+
// Tag -79, -100, 1 decodes to field 2502 with the fixed64 wire type; little-endian payload 3.
FIXED64(
164+
new byte[] {-79, -100, 1, 3, 0, 0, 0, 0, 0, 0, 0},
165+
"2502: 0x0000000000000003",
166+
ImmutableListMultimap.of(2502, 3L)),
167+
// Tag -70, -100, 1 decodes to field 2503, length-delimited; length 11 then the UTF-8 text.
STRING(
168+
new byte[] {-70, -100, 1, 11, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100},
169+
"2503: \"Hello world\"",
170+
ImmutableListMultimap.of(2503, ByteString.copyFromUtf8("Hello world"))),
171+
// Tag -62, -100, 1 decodes to field 2504, length-delimited; the packed values 4 and 5 are
// expected back as the raw two-byte payload rather than as a list.
REPEATED_INT64(
172+
new byte[] {-62, -100, 1, 2, 4, 5},
173+
"2504: \"\\004\\005\"",
174+
ImmutableListMultimap.of(2504, ByteString.copyFrom(new byte[] {4, 5}))),
175+
// Tag -54, -100, 1 decodes to field 2505, length-delimited; it appears twice, once per map
// entry, and each entry is expected back as its raw 7-byte payload.
MAP_STRING_INT64(
176+
new byte[] {
177+
-54, -100, 1, 7, 10, 3, 102, 111, 111, 16, 4, -54, -100, 1, 7, 10, 3, 98, 97, 114, 16, 5
178+
},
179+
"2505: {\n"
180+
+ " 1: \"foo\"\n"
181+
+ " 2: 4\n"
182+
+ "}\n"
183+
+ "2505: {\n"
184+
+ " 1: \"bar\"\n"
185+
+ " 2: 5\n"
186+
+ "}",
187+
ImmutableListMultimap.of(
188+
2505,
189+
ByteString.copyFromUtf8("\n\003foo\020\004"),
190+
2505,
191+
ByteString.copyFromUtf8("\n\003bar\020\005")));
192+
193+
// Serialized message bytes fed to both the generated parser and readAllFields.
private final byte[] bytes;
194+
// Expected text output (after trim) of the message parsed by the generated class.
private final String formattedOutput;
195+
// Expected unknowns, keyed by field number.
private final Multimap<Integer, Object> unknownMap;
196+
197+
UnknownFieldsTestCase(
198+
byte[] bytes, String formattedOutput, Multimap<Integer, Object> unknownMap) {
199+
this.bytes = bytes;
200+
this.formattedOutput = formattedOutput;
201+
this.unknownMap = unknownMap;
202+
}
203+
}
204+
205+
/**
* Verifies that when the same unknown field number occurs more than once, every occurrence is
* kept, and that the unknowns are reported in ascending field-number order.
*/
@Test
206+
public void unknowns_repeatedEncodedBytes_allRecordsKeptWithKeysSorted() throws Exception {
207+
// Encoded fields, in wire order:
// 2500: 2
208+
// 2504: "\004\005"
209+
// 2501: 0x00000002
210+
// 2500: 1
211+
byte[] bytes =
212+
new byte[] {
213+
-96, -100, 1, 2, // keep: 2500 (varint) = 2
214+
-62, -100, 1, 2, 4, 5, // keep: 2504 (length-delimited) = bytes 4, 5
215+
-83, -100, 1, 2, 0, 0, 0, // keep: 2501 (fixed32) = 2
216+
-96, -100, 1, 1 // keep: 2500 (varint) = 1, second occurrence
217+
};
218+
219+
MessageFields messageFields =
220+
PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields(
221+
bytes, "cel.expr.conformance.proto3.TestAllTypes");
222+
223+
assertThat(messageFields.values()).isEmpty();
224+
// Keys are expected in ascending order; the two values for 2500 keep their wire order (2, 1).
assertThat(messageFields.unknowns())
225+
.containsExactly(
226+
2500, 2L, 2500, 1L, 2501, 2, 2504, ByteString.copyFrom(new byte[] {0x04, 0x05}))
227+
.inOrder();
228+
}
229+
230+
/**
* For each {@code UnknownFieldsTestCase}, verifies that bytes carrying only unknown field numbers
* yield no decoded values and the expected unknowns, and that the generated parser's text output
* for the same bytes matches the case's formatted string.
*/
@Test
231+
public void readAllFields_unknownFields(@TestParameter UnknownFieldsTestCase testCase)
232+
throws Exception {
233+
// Reference parse of the same bytes with the generated message class.
TestAllTypes parsedMsg =
234+
TestAllTypes.parseFrom(testCase.bytes, ExtensionRegistryLite.getEmptyRegistry());
235+
236+
MessageFields messageFields =
237+
PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields(
238+
testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes");
239+
240+
assertThat(messageFields.values()).isEmpty();
241+
assertThat(messageFields.unknowns()).containsExactlyEntriesIn(testCase.unknownMap).inOrder();
242+
// trim() drops surrounding whitespace so the text output can be compared with the expected
// strings, which carry no trailing newline.
assertThat(parsedMsg.toString().trim()).isEqualTo(testCase.formattedOutput);
243+
}
244+
245+
/**
246+
* Tests the following message:
247+
*
248+
* <pre>{@code
249+
* TestAllTypes.newBuilder()
250+
* // Unknowns
251+
* .setSingleInt64Unknown(1L)
252+
* .setSingleFixed32Unknown(2)
253+
* .setSingleFixed64Unknown(3L)
254+
* .setSingleStringUnknown("Hello world")
255+
* .addAllRepeatedInt64Unknown(ImmutableList.of(4L, 5L))
256+
* .putMapStringInt64Unknown("foo", 4L)
257+
* .putMapStringInt64Unknown("bar", 5L)
258+
* // Known values
259+
* .putMapBoolDouble(true, 1.5d)
260+
* .putMapBoolDouble(false, 2.5d)
261+
* .build();
262+
* }</pre>
*
* <p>The test checks that the known map field is decoded into the values while the unknown field
* numbers 2500-2505 are collected separately.
263+
*/
264+
@Test
265+
@SuppressWarnings("unchecked")
266+
public void readAllFields_unknownFieldsWithValues() throws Exception {
267+
// Layout: two map_bool_double entries (each starting with the tag bytes -70, 4) for true -> 1.5
// and false -> 2.5, followed by the unknown fields 2500 through 2505 in ascending order.
byte[] unknownMessageBytes = {
268+
-70, 4, 11, 8, 1, 17, 0, 0, 0, 0, 0, 0, -8, 63, -70, 4, 11, 8, 0, 17, 0, 0, 0, 0, 0, 0, 4, 64,
269+
-96, -100, 1, 1, -83, -100, 1, 2, 0, 0, 0, -79, -100, 1, 3, 0, 0, 0, 0, 0, 0, 0, -70, -100, 1,
270+
11, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, -62, -100, 1, 2, 4, 5, -54, -100, 1,
271+
7, 10, 3, 102, 111, 111, 16, 4, -54, -100, 1, 7, 10, 3, 98, 97, 114, 16, 5
272+
};
273+
// Reference parse of the same bytes with the generated message class.
TestAllTypes parsedMsg =
274+
TestAllTypes.parseFrom(unknownMessageBytes, ExtensionRegistryLite.getEmptyRegistry());
275+
276+
MessageFields fields =
277+
PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields(
278+
unknownMessageBytes, "cel.expr.conformance.proto3.TestAllTypes");
279+
280+
// The text output lists the false entry before the true entry, then the unknown fields.
assertThat(parsedMsg.toString())
281+
.isEqualTo(
282+
"map_bool_double {\n"
283+
+ " key: false\n"
284+
+ " value: 2.5\n"
285+
+ "}\n"
286+
+ "map_bool_double {\n"
287+
+ " key: true\n"
288+
+ " value: 1.5\n"
289+
+ "}\n"
290+
+ "2500: 1\n"
291+
+ "2501: 0x00000002\n"
292+
+ "2502: 0x0000000000000003\n"
293+
+ "2503: \"Hello world\"\n"
294+
+ "2504: \"\\004\\005\"\n"
295+
+ "2505: {\n"
296+
+ " 1: \"foo\"\n"
297+
+ " 2: 4\n"
298+
+ "}\n"
299+
+ "2505: {\n"
300+
+ " 1: \"bar\"\n"
301+
+ " 2: 5\n"
302+
+ "}\n");
303+
assertThat(fields.values()).containsKey("map_bool_double");
304+
// The cast fails at runtime unless the stored map value is a LinkedHashMap.
LinkedHashMap<Boolean, Double> mapBoolDoubleValues =
305+
(LinkedHashMap<Boolean, Double>) fields.values().get("map_bool_double");
306+
// Entries are expected in wire order (true first), unlike the text output above.
assertThat(mapBoolDoubleValues).containsExactly(true, 1.5d, false, 2.5d).inOrder();
307+
Multimap<Integer, Object> unknownValues = fields.unknowns();
308+
// Length-delimited unknowns (2503, 2504, 2505) are expected back as raw ByteString payloads.
assertThat(unknownValues)
309+
.containsExactly(
310+
2500,
311+
1L,
312+
2501,
313+
2,
314+
2502,
315+
3L,
316+
2503,
317+
ByteString.copyFromUtf8("Hello world"),
318+
2504,
319+
ByteString.copyFrom(new byte[] {0x04, 0x05}),
320+
2505,
321+
ByteString.copyFromUtf8("\n\003foo\020\004"),
322+
2505,
323+
ByteString.copyFromUtf8("\n\003bar\020\005"))
324+
.inOrder();
325+
}
110326
}

common/src/test/java/dev/cel/common/values/ProtoMessageLiteValueTest.java

-25
Original file line numberDiff line numberDiff line change
@@ -271,29 +271,4 @@ public void selectField_defaultValue(@TestParameter DefaultValueTestCase testCas
271271
assertThat(selectedValue).isEqualTo(testCase.celValue);
272272
assertThat(selectedValue.isZeroValue()).isTrue();
273273
}
274-
275-
/** Test cases for repeated_int64: 1L,2L,3L */
276-
@SuppressWarnings("ImmutableEnumChecker") // Test only
277-
private enum RepeatedFieldBytesTestCase {
278-
PACKED(new byte[] {(byte) 0x82, 0x2, 0x3, 0x1, 0x2, 0x3}),
279-
NON_PACKED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x80, 0x2, 0x2, (byte) 0x80, 0x2, 0x3}),
280-
// 1L is not packed, but 2L and 3L are
281-
MIXED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x82, 0x2, 0x2, 0x2, 0x3});
282-
283-
private final byte[] bytes;
284-
285-
RepeatedFieldBytesTestCase(byte[] bytes) {
286-
this.bytes = bytes;
287-
}
288-
}
289-
290-
@Test
291-
public void readAllFields_repeatedFields_packedBytesCombinations(
292-
@TestParameter RepeatedFieldBytesTestCase testCase) throws Exception {
293-
ImmutableMap<String, Object> fields =
294-
PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields(
295-
testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes");
296-
297-
assertThat(fields).containsExactly("repeated_int64", ImmutableList.of(1L, 2L, 3L));
298-
}
299274
}

0 commit comments

Comments
 (0)