diff --git a/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java index a6933ac7..db20b966 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java @@ -16,6 +16,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Defaults; import com.google.common.collect.ImmutableList; @@ -221,7 +222,8 @@ private ImmutableList readPackedRepeatedFields( private Map.Entry readSingleMapEntry( CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException { ImmutableMap singleMapEntry = - readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName()); + readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName()) + .values(); Object key = checkNotNull(singleMapEntry.get("key")); Object value = checkNotNull(singleMapEntry.get("value")); @@ -229,19 +231,23 @@ private Map.Entry readSingleMapEntry( } @VisibleForTesting - ImmutableMap readAllFields(byte[] bytes, String protoTypeName) - throws IOException { - // TODO: Handle unknown fields by collecting them into a separate map. + MessageFields readAllFields(byte[] bytes, String protoTypeName) throws IOException { MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName); CodedInputStream inputStream = CodedInputStream.newInstance(bytes); + ImmutableMap.Builder unknownFields = ImmutableMap.builder(); ImmutableMap.Builder fieldValues = ImmutableMap.builder(); Map> repeatedFieldValues = new LinkedHashMap<>(); Map> mapFieldValues = new LinkedHashMap<>(); for (int tag = inputStream.readTag(); tag != 0; tag = inputStream.readTag()) { int tagWireType = WireFormat.getTagWireType(tag); int fieldNumber = WireFormat.getTagFieldNumber(tag); - FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber); + FieldLiteDescriptor fieldDescriptor = + messageDescriptor.findByFieldNumber(fieldNumber).orElse(null); + if (fieldDescriptor == null) { + unknownFields.put(fieldNumber, readUnknownField(tagWireType, inputStream)); + continue; + } Object payload; switch (tagWireType) { @@ -318,12 +324,32 @@ ImmutableMap readAllFields(byte[] bytes, String protoTypeName) // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields, // we accept the last value encountered. - return fieldValues.buildKeepingLast(); + return MessageFields.create(fieldValues.buildKeepingLast(), unknownFields.buildKeepingLast()); } ImmutableMap readAllFields(MessageLite msg, String protoTypeName) throws IOException { - return readAllFields(msg.toByteArray(), protoTypeName); + return readAllFields(msg.toByteArray(), protoTypeName).values(); + } + + private static Object readUnknownField(int tagWireType, CodedInputStream inputStream) + throws IOException { + switch (tagWireType) { + case WireFormat.WIRETYPE_VARINT: + return inputStream.readInt64(); + case WireFormat.WIRETYPE_FIXED64: + return inputStream.readFixed64(); + case WireFormat.WIRETYPE_LENGTH_DELIMITED: + return inputStream.readBytes(); + case WireFormat.WIRETYPE_FIXED32: + return inputStream.readFixed32(); + case WireFormat.WIRETYPE_START_GROUP: + case WireFormat.WIRETYPE_END_GROUP: + // TODO: Support groups + throw new UnsupportedOperationException("Groups are not supported"); + default: + throw new IllegalArgumentException("Unknown wire type: " + tagWireType); + } } @Override @@ -342,6 +368,19 @@ public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg return super.fromWellKnownProtoToCelValue(msg, wellKnownProto); } + @AutoValue + abstract static class MessageFields { + + abstract ImmutableMap values(); + + abstract ImmutableMap unknowns(); + + static MessageFields create( + ImmutableMap fieldValues, ImmutableMap unknownFields) { + return new AutoValue_ProtoLiteCelValueConverter_MessageFields(fieldValues, unknownFields); + } + } + private ProtoLiteCelValueConverter(CelLiteDescriptorPool celLiteDescriptorPool) { this.descriptorPool = celLiteDescriptorPool; } diff --git a/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java b/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java index c1471b93..680f9b6f 100644 --- a/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java +++ b/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java @@ -17,9 +17,12 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.protobuf.ByteString; import com.google.protobuf.Duration; +import com.google.protobuf.ExtensionRegistryLite; import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; @@ -32,9 +35,11 @@ import dev.cel.common.internal.CelLiteDescriptorPool; import dev.cel.common.internal.DefaultLiteDescriptorPool; import dev.cel.common.internal.WellKnownProto; +import dev.cel.common.values.ProtoLiteCelValueConverter.MessageFields; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.expr.conformance.proto3.TestAllTypesProto3CelDescriptor; import java.time.Instant; +import java.util.LinkedHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -107,4 +112,181 @@ public void fromProtoMessageToCelValue_withWellKnownProto_convertsToEquivalentCe assertThat(convertedCelValue).isEqualTo(testCase.celValue); } + + /** Test cases for repeated_int64: 1L,2L,3L */ + @SuppressWarnings("ImmutableEnumChecker") // Test only + private enum RepeatedFieldBytesTestCase { + PACKED(new byte[] {(byte) 0x82, 0x2, 0x3, 0x1, 0x2, 0x3}), + NON_PACKED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x80, 0x2, 0x2, (byte) 0x80, 0x2, 0x3}), + // 1L is not packed, but 2L and 3L are + MIXED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x82, 0x2, 0x2, 0x2, 0x3}); + + private final byte[] bytes; + + RepeatedFieldBytesTestCase(byte[] bytes) { + this.bytes = bytes; + } + } + + @Test + public void readAllFields_repeatedFields_packedBytesCombinations( + @TestParameter RepeatedFieldBytesTestCase testCase) throws Exception { + MessageFields fields = + PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields( + testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes"); + + assertThat(fields.values()).containsExactly("repeated_int64", ImmutableList.of(1L, 2L, 3L)); + } + + /** + * Unknown test with the following hypothetical fields: + * + *
{@code
+   * message TestAllTypes {
+   *   int64 single_int64_unknown = 2500;
+   *   fixed32 single_fixed32_unknown = 2501;
+   *   fixed64 single_fixed64_unknown = 2502;
+   *   string single_string_unknown = 2503;
+   *   repeated int64 repeated_int64_unknown = 2504;
+   *   map map_string_int64_unknown = 2505;
+   * }
+   * }
+ */ + @SuppressWarnings("ImmutableEnumChecker") // Test only + private enum UnknownFieldsTestCase { + INT64(new byte[] {-96, -100, 1, 1}, "2500: 1", ImmutableMap.of(2500, 1L)), + FIXED32(new byte[] {-83, -100, 1, 2, 0, 0, 0}, "2501: 0x00000002", ImmutableMap.of(2501, 2)), + FIXED64( + new byte[] {-79, -100, 1, 3, 0, 0, 0, 0, 0, 0, 0}, + "2502: 0x0000000000000003", + ImmutableMap.of(2502, 3L)), + STRING( + new byte[] {-70, -100, 1, 11, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}, + "2503: \"Hello world\"", + ImmutableMap.of(2503, ByteString.copyFromUtf8("Hello world"))), + REPEATED_INT64( + new byte[] {-62, -100, 1, 2, 4, 5}, + "2504: \"\\004\\005\"", + ImmutableMap.of(2504, ByteString.copyFrom(new byte[] {4, 5}))), + MAP_STRING_INT64( + new byte[] { + -54, -100, 1, 7, 10, 3, 102, 111, 111, 16, 4, -54, -100, 1, 7, 10, 3, 98, 97, 114, 16, 5 + }, + "2505: {\n" + + " 1: \"foo\"\n" + + " 2: 4\n" + + "}\n" + + "2505: {\n" + + " 1: \"bar\"\n" + + " 2: 5\n" + + "}", + ImmutableMap.of(2505, ByteString.copyFrom(new byte[] {10, 3, 98, 97, 114, 16, 5}))), + ; + + private final byte[] bytes; + private final String formattedOutput; + private final ImmutableMap unknownMap; + + UnknownFieldsTestCase( + byte[] bytes, String formattedOutput, ImmutableMap unknownMap) { + this.bytes = bytes; + this.formattedOutput = formattedOutput; + this.unknownMap = unknownMap; + } + } + + @Test + public void readAllFields_unknownFields(@TestParameter UnknownFieldsTestCase testCase) + throws Exception { + TestAllTypes parsedMsg = + TestAllTypes.parseFrom(testCase.bytes, ExtensionRegistryLite.getEmptyRegistry()); + + MessageFields messageFields = + PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields( + testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes"); + + assertThat(messageFields.values()).isEmpty(); + assertThat(messageFields.unknowns()).containsExactlyEntriesIn(testCase.unknownMap).inOrder(); + assertThat(parsedMsg.toString().trim()).isEqualTo(testCase.formattedOutput); + } + + /** + * Tests the following message: + * + *
{@code
+   * TestAllTypes.newBuilder()
+   *     // Unknowns
+   *     .setSingleInt64Unknown(1L)
+   *     .setSingleFixed32Unknown(2)
+   *     .setSingleFixed64Unknown(3L)
+   *     .setSingleStringUnknown("Hello world")
+   *     .addAllRepeatedInt64Unknown(ImmutableList.of(4L, 5L))
+   *     .putMapStringInt64Unknown("foo", 4L)
+   *     .putMapStringInt64Unknown("bar", 5L)
+   *     // Known values
+   *     .putMapBoolDouble(true, 1.5d)
+   *     .putMapBoolDouble(false, 2.5d)
+   *     .build();
+   * }
+ */ + @Test + @SuppressWarnings("unchecked") + public void readAllFields_unknownFieldsWithValues() throws Exception { + byte[] unknownMessageBytes = { + -70, 4, 11, 8, 1, 17, 0, 0, 0, 0, 0, 0, -8, 63, -70, 4, 11, 8, 0, 17, 0, 0, 0, 0, 0, 0, 4, 64, + -96, -100, 1, 1, -83, -100, 1, 2, 0, 0, 0, -79, -100, 1, 3, 0, 0, 0, 0, 0, 0, 0, -70, -100, 1, + 11, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, -62, -100, 1, 2, 4, 5, -54, -100, 1, + 7, 10, 3, 102, 111, 111, 16, 4, -54, -100, 1, 7, 10, 3, 98, 97, 114, 16, 5 + }; + TestAllTypes parsedMsg = + TestAllTypes.parseFrom(unknownMessageBytes, ExtensionRegistryLite.getEmptyRegistry()); + + MessageFields fields = + PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields( + unknownMessageBytes, "cel.expr.conformance.proto3.TestAllTypes"); + + assertThat(parsedMsg.toString()) + .isEqualTo( + "map_bool_double {\n" + + " key: false\n" + + " value: 2.5\n" + + "}\n" + + "map_bool_double {\n" + + " key: true\n" + + " value: 1.5\n" + + "}\n" + + "2500: 1\n" + + "2501: 0x00000002\n" + + "2502: 0x0000000000000003\n" + + "2503: \"Hello world\"\n" + + "2504: \"\\004\\005\"\n" + + "2505: {\n" + + " 1: \"foo\"\n" + + " 2: 4\n" + + "}\n" + + "2505: {\n" + + " 1: \"bar\"\n" + + " 2: 5\n" + + "}\n"); + assertThat(fields.values()).containsKey("map_bool_double"); + LinkedHashMap mapBoolDoubleValues = + (LinkedHashMap) fields.values().get("map_bool_double"); + assertThat(mapBoolDoubleValues).containsExactly(true, 1.5d, false, 2.5d).inOrder(); + ImmutableMap unknownValues = fields.unknowns(); + assertThat(unknownValues) + .containsExactly( + 2500, + 1L, + 2501, + 2, + 2502, + 3L, + 2503, + ByteString.copyFromUtf8("Hello world"), + 2504, + ByteString.copyFrom(new byte[] {0x04, 0x05}), + 2505, + ByteString.copyFrom(new byte[] {10, 3, 98, 97, 114, 16, 5})) + .inOrder(); + } } diff --git a/common/src/test/java/dev/cel/common/values/ProtoMessageLiteValueTest.java b/common/src/test/java/dev/cel/common/values/ProtoMessageLiteValueTest.java index 35e7b78e..be30481c 100644 --- a/common/src/test/java/dev/cel/common/values/ProtoMessageLiteValueTest.java +++ b/common/src/test/java/dev/cel/common/values/ProtoMessageLiteValueTest.java @@ -271,29 +271,4 @@ public void selectField_defaultValue(@TestParameter DefaultValueTestCase testCas assertThat(selectedValue).isEqualTo(testCase.celValue); assertThat(selectedValue.isZeroValue()).isTrue(); } - - /** Test cases for repeated_int64: 1L,2L,3L */ - @SuppressWarnings("ImmutableEnumChecker") // Test only - private enum RepeatedFieldBytesTestCase { - PACKED(new byte[] {(byte) 0x82, 0x2, 0x3, 0x1, 0x2, 0x3}), - NON_PACKED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x80, 0x2, 0x2, (byte) 0x80, 0x2, 0x3}), - // 1L is not packed, but 2L and 3L are - MIXED(new byte[] {(byte) 0x80, 0x2, 0x1, (byte) 0x82, 0x2, 0x2, 0x2, 0x3}); - - private final byte[] bytes; - - RepeatedFieldBytesTestCase(byte[] bytes) { - this.bytes = bytes; - } - } - - @Test - public void readAllFields_repeatedFields_packedBytesCombinations( - @TestParameter RepeatedFieldBytesTestCase testCase) throws Exception { - ImmutableMap fields = - PROTO_LITE_CEL_VALUE_CONVERTER.readAllFields( - testCase.bytes, "cel.expr.conformance.proto3.TestAllTypes"); - - assertThat(fields).containsExactly("repeated_int64", ImmutableList.of(1L, 2L, 3L)); - } } diff --git a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java index e72f0776..45b06d3b 100644 --- a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java +++ b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java @@ -23,7 +23,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; +import java.util.Optional; import java.util.function.Supplier; /** @@ -78,12 +80,18 @@ public List getFieldDescriptors() { return fieldLiteDescriptors; } + public Optional findByFieldNumber(int fieldNumber) { + return Optional.ofNullable(fieldNumberToFieldDescriptors.get(fieldNumber)); + } + public FieldLiteDescriptor getByFieldNameOrThrow(String fieldName) { return Objects.requireNonNull(fieldNameToFieldDescriptors.get(fieldName)); } public FieldLiteDescriptor getByFieldNumberOrThrow(int fieldNumber) { - return Objects.requireNonNull(fieldNumberToFieldDescriptors.get(fieldNumber)); + return findByFieldNumber(fieldNumber) + .orElseThrow( + () -> new NoSuchElementException("Could not find field number: " + fieldNumber)); } /** Gets the builder for the message. Returns null for maps. */ diff --git a/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java b/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java index b434bb68..b7c23fc8 100644 --- a/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java +++ b/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java @@ -43,6 +43,7 @@ public static void createFile(String filePath, JavaFileGeneratorOption option) cfg.setDefaultEncoding("UTF-8"); cfg.setBooleanFormat("c"); cfg.setAPIBuiltinEnabled(true); + cfg.setNumberFormat("#"); // Prevent thousandth separator in numbers (eg: 1000 instead of 1,000) DefaultObjectWrapperBuilder wrapperBuilder = new DefaultObjectWrapperBuilder(version); wrapperBuilder.setExposeFields(true); cfg.setObjectWrapper(wrapperBuilder.build());