diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go index 503390627..5e1c12909 100644 --- a/encoding/protojson/decode.go +++ b/encoding/protojson/decode.go @@ -40,6 +40,10 @@ type UnmarshalOptions struct { // If DiscardUnknown is set, unknown fields and enum name values are ignored. DiscardUnknown bool + // If ForceUppercaseEnums is set, enum values are first uppercased + // before being looked up by name. + ForceUppercaseEnums bool + // Resolver is used for looking up types when unmarshaling // google.protobuf.Any messages or extension fields. // If nil, this defaults to using protoregistry.GlobalTypes. @@ -331,7 +335,7 @@ func (d decoder) unmarshalScalar(fd protoreflect.FieldDescriptor) (protoreflect. } case protoreflect.EnumKind: - if v, ok := unmarshalEnum(tok, fd, d.opts.DiscardUnknown); ok { + if v, ok := unmarshalEnum(tok, fd, d.opts.DiscardUnknown, d.opts.ForceUppercaseEnums); ok { return v, nil } @@ -476,11 +480,16 @@ func unmarshalBytes(tok json.Token) (protoreflect.Value, bool) { return protoreflect.ValueOfBytes(b), true } -func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor, discardUnknown bool) (protoreflect.Value, bool) { +func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor, discardUnknown bool, forceUppercaseEnums bool) (protoreflect.Value, bool) { switch tok.Kind() { case json.String: // Lookup EnumNumber based on name. s := tok.ParsedString() + + if forceUppercaseEnums { + s = strings.ToUpper(s) + } + if enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)); enumVal != nil { return protoreflect.ValueOfEnum(enumVal.Number()), true } diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go index 417da1dea..b9fd99908 100644 --- a/encoding/protojson/decode_test.go +++ b/encoding/protojson/decode_test.go @@ -2474,22 +2474,35 @@ func TestUnmarshal(t *testing.T) { }, }, }, { - desc: "weak fields", - inputMessage: &testpb.TestWeak{}, - inputText: `{"weak_message1":{"a":1}}`, - wantMessage: func() *testpb.TestWeak { - m := new(testpb.TestWeak) - m.SetWeakMessage1(&weakpb.WeakImportMessage1{A: proto.Int32(1)}) - return m - }(), - skip: !flags.ProtoLegacy, - }, { - desc: "weak fields; unknown field", - inputMessage: &testpb.TestWeak{}, - inputText: `{"weak_message1":{"a":1}, "weak_message2":{"a":1}}`, - wantErr: `unknown field "weak_message2"`, // weak_message2 is unknown since the package containing it is not imported - skip: !flags.ProtoLegacy, - }} + desc: "ForceUppercaseEnums: lowercase value with enum", + umo: protojson.UnmarshalOptions{ + ForceUppercaseEnums: true, + }, + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": "one" + }`, + wantMessage: &pb3.Enums{ + SEnum: pb3.Enum_ONE, + }, + }, + { + desc: "weak fields", + inputMessage: &testpb.TestWeak{}, + inputText: `{"weak_message1":{"a":1}}`, + wantMessage: func() *testpb.TestWeak { + m := new(testpb.TestWeak) + m.SetWeakMessage1(&weakpb.WeakImportMessage1{A: proto.Int32(1)}) + return m + }(), + skip: !flags.ProtoLegacy, + }, { + desc: "weak fields; unknown field", + inputMessage: &testpb.TestWeak{}, + inputText: `{"weak_message1":{"a":1}, "weak_message2":{"a":1}}`, + wantErr: `unknown field "weak_message2"`, // weak_message2 is unknown since the package containing it is not imported + skip: !flags.ProtoLegacy, + }} for _, tt := range tests { tt := tt