diff --git a/decode.go b/decode.go index c8da155..b6e0bcc 100644 --- a/decode.go +++ b/decode.go @@ -54,6 +54,10 @@ type UnmarshalOptions struct { protoregistry.MessageTypeResolver protoregistry.ExtensionTypeResolver } + + // DiscardUnknown specifies whether to discard unknown fields instead of + // returning an error. + DiscardUnknown bool } // Unmarshal a Protobuf message from the given YAML data. @@ -643,6 +647,10 @@ func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, f message.ProtoReflect().Descriptor().FullName(), getNodeKind(node.Kind)) return } + u.unmarshalMessageFields(node, message, forAny) +} + +func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Message, forAny bool) { // Decode the fields msgDesc := message.ProtoReflect().Descriptor() for i := 0; i < len(node.Content); i += 2 { @@ -670,7 +678,9 @@ func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, f field, err := u.findField(key, msgDesc) switch { case errors.Is(err, protoregistry.NotFound): - u.addErrorf(keyNode, "unknown field %#v, expected one of %v", key, getFieldNames(msgDesc.Fields())) + if !u.options.DiscardUnknown { + u.addErrorf(keyNode, "unknown field %#v, expected one of %v", key, getFieldNames(msgDesc.Fields())) + } case err != nil: u.addError(keyNode, err) default: diff --git a/decode_test.go b/decode_test.go index 271b861..c956a45 100644 --- a/decode_test.go +++ b/decode_test.go @@ -97,3 +97,23 @@ func TestExtension(t *testing.T) { require.NoError(t, err) require.Equal(t, "hi", proto.GetExtension(actual, testv1.E_P2TStringExt)) } + +func TestDiscardUnknown(t *testing.T) { + t.Parallel() + + data := []byte(` +unknown: hi +values: + - oneof_string_value: hi +`) + + actual := &testv1.Proto2Test{} + err := Unmarshal(data, actual) + require.Error(t, err) + + err = UnmarshalOptions{ + DiscardUnknown: true, + }.Unmarshal(data, actual) + require.NoError(t, err) + require.Equal(t, "hi", actual.GetValues()[0].GetOneofStringValue()) +}