From 7e2a72a51ff8ca68e07a67db4787f6abaef12dc2 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 10:15:13 -0800 Subject: [PATCH 1/8] Add cel support for scalar types --- decode.go | 130 ++++++++++++++++++++++-- internal/testdata/basic.proto3test.txt | 22 ++-- internal/testdata/basic.proto3test.yaml | 4 + internal/testdata/dynamic.const.txt | 18 ++-- 4 files changed, 145 insertions(+), 29 deletions(-) diff --git a/decode.go b/decode.go index d2b8f39..a6bd465 100644 --- a/decode.go +++ b/decode.go @@ -25,6 +25,8 @@ import ( "time" "github.com/bufbuild/protovalidate-go" + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types/ref" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -71,6 +73,9 @@ type UnmarshalOptions struct { // DiscardUnknown specifies whether to discard unknown fields instead of // returning an error. DiscardUnknown bool + + // CelEnv is the CEL environment to use for evaluating CEL expressions. + CelEnv *cel.Env } // Unmarshal a Protobuf message from the given YAML data. @@ -151,6 +156,15 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, validator: o.Validator, lines: strings.Split(string(data), "\n"), } + if o.CelEnv != nil { + unm.celEnv = o.CelEnv + } else { + var err error + unm.celEnv, err = cel.NewEnv() + if err != nil { + return err + } + } // Unwrap the document node if node.Kind == yaml.DocumentNode { @@ -196,6 +210,7 @@ type unmarshaler struct { errors []error validator Validator lines []string + celEnv *cel.Env } func (u *unmarshaler) addError(node *yaml.Node, err error) { @@ -254,6 +269,22 @@ func (u *unmarshaler) findAnyType(node *yaml.Node) (protoreflect.MessageType, er return u.resolveAnyType(typeURL) } +func (u *unmarshaler) celEval(node *yaml.Node) (ref.Val, error) { + ast, issues := u.celEnv.Compile(node.Value) + if issues != nil && issues.Err() != nil { + return nil, issues.Err() + } + prg, err := u.celEnv.Program(ast) + if err != nil { + return nil, err + } + out, _, err := prg.Eval(map[string]interface{}{}) + if err != nil { + return nil, err + } + return out, nil +} + // Unmarshal the field based on the field kind, ignoring IsList and IsMap, // which are handled by the caller. func (u *unmarshaler) unmarshalScalar( @@ -279,8 +310,7 @@ func (u *unmarshaler) unmarshalScalar( case protoreflect.DoubleKind: return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64)), true case protoreflect.StringKind: - u.checkKind(node, yaml.ScalarNode) - return protoreflect.ValueOfString(node.Value), true + return protoreflect.ValueOfString(u.unmarshalString(node)), true case protoreflect.BytesKind: return protoreflect.ValueOfBytes(u.unmarshalBytes(node)), true case protoreflect.EnumKind: @@ -291,12 +321,50 @@ func (u *unmarshaler) unmarshalScalar( } } +func (u *unmarshaler) unmarshalString(node *yaml.Node) string { + if !u.checkKind(node, yaml.ScalarNode) { + return "" + } + + if node.Tag == "!!cel" { + if val, celErr := u.celEval(node); celErr != nil { + u.addErrorf(node, "invalid CEL expression: %v", celErr) + } else if strVal, ok := val.Value().(string); ok { + return strVal + } else { + u.addErrorf(node, "expected string, got %v", val.Type()) + return "" + } + } + + return node.Value +} + // Base64 decodes the given node value. func (u *unmarshaler) unmarshalBytes(node *yaml.Node) []byte { if !u.checkKind(node, yaml.ScalarNode) { return nil } + tryCel := func() ([]byte, error) { + val, celErr := u.celEval(node) + if celErr != nil { + return nil, celErr + } + if bytesVal, ok := val.Value().([]byte); ok { + return bytesVal, nil + } + return nil, fmt.Errorf("expected bytes, got %v", val.Type()) + } + + if node.Tag == "!!cel" { + data, err := tryCel() + if err != nil { + u.addError(node, err) + } + return data + } + enc := base64.StdEncoding if strings.ContainsAny(node.Value, "-_") { enc = base64.URLEncoding @@ -308,6 +376,9 @@ func (u *unmarshaler) unmarshalBytes(node *yaml.Node) []byte { // base64 decode the value. data, err := enc.DecodeString(node.Value) if err != nil { + if data, celErr := tryCel(); celErr == nil { + return data + } u.addErrorf(node, "invalid base64: %v", err) } return data @@ -328,7 +399,14 @@ func (u *unmarshaler) unmarshalBool(node *yaml.Node, forKey bool) bool { } return false default: - u.addErrorf(node, "expected bool, got %#v", node.Value) + if val, celErr := u.celEval(node); celErr == nil { + if boolVal, ok := val.Value().(bool); ok { + return boolVal + } + u.addErrorf(node, "expected bool, got %v", val.Type()) + } else { + u.addErrorf(node, "expected bool, got %#v", node.Value) + } } } return false @@ -375,7 +453,15 @@ func (u *unmarshaler) unmarshalFloat(node *yaml.Node, bits int) float64 { parsed, err := strconv.ParseFloat(node.Value, bits) if err != nil { - u.addErrorf(node, "invalid float: %v", err) + if val, celErr := u.celEval(node); celErr == nil { + if floatVal, ok := val.Value().(float64); ok && (bits == 64 || float64(float32(floatVal)) == floatVal) { + parsed = floatVal + } else { + u.addErrorf(node, "invalid float: %v", err) + } + } else { + u.addErrorf(node, "invalid float: %v", err) + } } return parsed } @@ -388,10 +474,21 @@ func (u *unmarshaler) unmarshalUnsigned(node *yaml.Node, bits int) uint64 { parsed, err := parseUintLiteral(node.Value) if err != nil { - u.addErrorf(node, "invalid integer: %v", err) + if val, celErr := u.celEval(node); celErr == nil { + if uintVal, ok := val.Value().(uint64); ok { + parsed = uintVal + } else if intVal, ok := val.Value().(int64); ok && intVal >= 0 { + parsed = uint64(intVal) + } else { + u.addErrorf(node, "expected unsigned integer, got %v", val.Type()) + } + } else { + u.addErrorf(node, "invalid unsigned integer: %v", err) + } } + if bits < 64 && parsed >= 1< %v", 1< %v", 1< 4294967295 +internal/testdata/basic.proto3test.yaml:63:20 unsigned integer is too large: > 4294967295 63 | - single_uint32: 4294967296 63 | ...................^ -internal/testdata/basic.proto3test.yaml:65:20 invalid integer: precision loss +internal/testdata/basic.proto3test.yaml:65:20 invalid unsigned integer: precision loss 65 | - single_uint64: 18446744073709551616 65 | ...................^ @@ -233,3 +225,7 @@ internal/testdata/basic.proto3test.yaml:130:23 expected fields for google.protob internal/testdata/basic.proto3test.yaml:135:7 unknown field "@type", expected one of [value] 135 | "@type": type.googleapis.com/google.protobuf.Int32Value 135 | ......^ + +internal/testdata/basic.proto3test.yaml:142:28 expected unsigned integer, got int + 142 | - single_uint32_wrapper: 1 - 2 + 142 | ...........................^ diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index 03afa62..4437945 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -136,3 +136,7 @@ values: value: 1 - single_bytes: "nopad+" - single_bytes: "web-safe__" + - single_bytes: b"hi" + - single_string: !!cel '"hi" + "there"' + - single_int32_wrapper: 1 + 3 + - single_uint32_wrapper: 1 - 2 diff --git a/internal/testdata/dynamic.const.txt b/internal/testdata/dynamic.const.txt index 9e7103c..995db96 100644 --- a/internal/testdata/dynamic.const.txt +++ b/internal/testdata/dynamic.const.txt @@ -6,7 +6,7 @@ internal/testdata/dynamic.const.yaml:13:12 expected scalar, got mapping 13 | value: {} 13 | ...........^ -internal/testdata/dynamic.const.yaml:16:12 expected bool, got "null" +internal/testdata/dynamic.const.yaml:16:12 expected bool, got null_type 16 | value: null 16 | ...........^ @@ -82,11 +82,11 @@ internal/testdata/dynamic.const.yaml:70:7 expected scalar, got sequence 70 | []: true 70 | ......^ -internal/testdata/dynamic.const.yaml:71:7 expected bool, got "1" +internal/testdata/dynamic.const.yaml:71:7 expected bool, got int 71 | 1: true 71 | ......^ -internal/testdata/dynamic.const.yaml:72:13 expected bool, got "1" +internal/testdata/dynamic.const.yaml:72:13 expected bool, got int 72 | true: 1 72 | ............^ @@ -118,27 +118,27 @@ internal/testdata/dynamic.const.yaml:81:92 integer is too large: > 9223372036854 81 | repeated_int64: [1.5, -9223372036854775808, -9223372036854775809, 9223372036854775807, 9223372036854775808] 81 | ...........................................................................................^ -internal/testdata/dynamic.const.yaml:82:23 invalid integer: precision loss +internal/testdata/dynamic.const.yaml:82:23 expected unsigned integer, got double 82 | repeated_uint32: [1.5, -1, 0, 4294967295, 4294967296] 82 | ......................^ -internal/testdata/dynamic.const.yaml:82:28 invalid integer: strconv.ParseUint: parsing "-1": invalid syntax +internal/testdata/dynamic.const.yaml:82:28 expected unsigned integer, got int 82 | repeated_uint32: [1.5, -1, 0, 4294967295, 4294967296] 82 | ...........................^ -internal/testdata/dynamic.const.yaml:82:47 integer is too large: > 4294967295 +internal/testdata/dynamic.const.yaml:82:47 unsigned integer is too large: > 4294967295 82 | repeated_uint32: [1.5, -1, 0, 4294967295, 4294967296] 82 | ..............................................^ -internal/testdata/dynamic.const.yaml:83:23 invalid integer: precision loss +internal/testdata/dynamic.const.yaml:83:23 expected unsigned integer, got double 83 | repeated_uint64: [1.5, -1, 0, 18446744073709551615, 18446744073709551616] 83 | ......................^ -internal/testdata/dynamic.const.yaml:83:28 invalid integer: strconv.ParseUint: parsing "-1": invalid syntax +internal/testdata/dynamic.const.yaml:83:28 expected unsigned integer, got int 83 | repeated_uint64: [1.5, -1, 0, 18446744073709551615, 18446744073709551616] 83 | ...........................^ -internal/testdata/dynamic.const.yaml:83:57 invalid integer: precision loss +internal/testdata/dynamic.const.yaml:83:57 invalid unsigned integer: precision loss 83 | repeated_uint64: [1.5, -1, 0, 18446744073709551615, 18446744073709551616] 83 | ........................................................^ From 3412fe1bc4cb7a3094440c4df130e41faebff3f2 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 10:56:07 -0800 Subject: [PATCH 2/8] lint --- .golangci.yml | 1 + decode.go | 211 ++++++++++++++---------- internal/testdata/basic.proto3test.yaml | 2 +- internal/testdata/dynamic.const.txt | 8 +- 4 files changed, 131 insertions(+), 91 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index d4eebbf..27f6163 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -12,6 +12,7 @@ linters-settings: - github.com/bufbuild/protoyaml-go/decode - github.com/bufbuild/protovalidate-go - buf.build/gen/go/bufbuild/protovalidate + - github.com/google/cel-go errcheck: check-type-assertions: true forbidigo: diff --git a/decode.go b/decode.go index a6bd465..7ebc5fd 100644 --- a/decode.go +++ b/decode.go @@ -46,6 +46,10 @@ var ( wktUnmarshalers map[protoreflect.FullName]customUnmarshaler ) +const ( + celTag = "!cel" +) + // Validator is an interface for validating a Protobuf message produced from a given YAML node. type Validator interface { // Validate the given message. @@ -155,6 +159,7 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, options: o, validator: o.Validator, lines: strings.Split(string(data), "\n"), + this: message, } if o.CelEnv != nil { unm.celEnv = o.CelEnv @@ -211,6 +216,7 @@ type unmarshaler struct { validator Validator lines []string celEnv *cel.Env + this proto.Message } func (u *unmarshaler) addError(node *yaml.Node, err error) { @@ -326,18 +332,20 @@ func (u *unmarshaler) unmarshalString(node *yaml.Node) string { return "" } - if node.Tag == "!!cel" { - if val, celErr := u.celEval(node); celErr != nil { - u.addErrorf(node, "invalid CEL expression: %v", celErr) - } else if strVal, ok := val.Value().(string); ok { - return strVal - } else { - u.addErrorf(node, "expected string, got %v", val.Type()) - return "" - } + if node.Tag != celTag { + return node.Value } - return node.Value + val, celErr := u.celEval(node) + if celErr != nil { + u.addErrorf(node, "invalid CEL expression: %v", celErr) + return "" + } + if strVal, ok := val.Value().(string); ok { + return strVal + } + u.addErrorf(node, "expected string, got %v", val.Type()) + return "" } // Base64 decodes the given node value. @@ -346,25 +354,6 @@ func (u *unmarshaler) unmarshalBytes(node *yaml.Node) []byte { return nil } - tryCel := func() ([]byte, error) { - val, celErr := u.celEval(node) - if celErr != nil { - return nil, celErr - } - if bytesVal, ok := val.Value().([]byte); ok { - return bytesVal, nil - } - return nil, fmt.Errorf("expected bytes, got %v", val.Type()) - } - - if node.Tag == "!!cel" { - data, err := tryCel() - if err != nil { - u.addError(node, err) - } - return data - } - enc := base64.StdEncoding if strings.ContainsAny(node.Value, "-_") { enc = base64.URLEncoding @@ -375,41 +364,61 @@ func (u *unmarshaler) unmarshalBytes(node *yaml.Node) []byte { // base64 decode the value. data, err := enc.DecodeString(node.Value) - if err != nil { - if data, celErr := tryCel(); celErr == nil { - return data + if err == nil { + return data + } + + val, celErr := u.celEval(node) + if celErr != nil { + if node.Tag == celTag { + u.addErrorf(node, "invalid CEL expression: %v", celErr) + } else { + u.addErrorf(node, "invalid base64: %v", err) } - u.addErrorf(node, "invalid base64: %v", err) + return nil + } + + if bytesVal, ok := val.Value().([]byte); ok { + return bytesVal } - return data + u.addErrorf(node, "expected bytes, got %v", val.Type()) + return nil } // Unmarshal raw `true` or `false` values, only allowing for strings for keys. func (u *unmarshaler) unmarshalBool(node *yaml.Node, forKey bool) bool { - if u.checkKind(node, yaml.ScalarNode) { - switch node.Value { - case "true": - if !forKey { - u.checkTag(node, "!!bool") - } - return true - case "false": - if !forKey { - u.checkTag(node, "!!bool") - } - return false - default: - if val, celErr := u.celEval(node); celErr == nil { - if boolVal, ok := val.Value().(bool); ok { - return boolVal - } - u.addErrorf(node, "expected bool, got %v", val.Type()) + if !u.checkKind(node, yaml.ScalarNode) { + return false + } + + switch node.Value { + case "true": + if !forKey { + u.checkTag(node, "!!bool") + } + return true + case "false": + if !forKey { + u.checkTag(node, "!!bool") + } + return false + default: + val, celErr := u.celEval(node) + if celErr != nil { + if node.Tag == celTag { + u.addErrorf(node, "invalid CEL expression: %v", celErr) } else { u.addErrorf(node, "expected bool, got %#v", node.Value) } + return false + } + if boolVal, ok := val.Value().(bool); ok { + return boolVal + } else { + u.addErrorf(node, "expected bool, got %v", val.Type()) } + return false } - return false } // Unmarshal the given node into an enum value. @@ -452,16 +461,24 @@ func (u *unmarshaler) unmarshalFloat(node *yaml.Node, bits int) float64 { } parsed, err := strconv.ParseFloat(node.Value, bits) - if err != nil { - if val, celErr := u.celEval(node); celErr == nil { - if floatVal, ok := val.Value().(float64); ok && (bits == 64 || float64(float32(floatVal)) == floatVal) { - parsed = floatVal - } else { - u.addErrorf(node, "invalid float: %v", err) - } + if err == nil { + return parsed + } + + val, celErr := u.celEval(node) + if celErr != nil { + if node.Tag == celTag { + u.addErrorf(node, "invalid CEL expression: %v", celErr) } else { u.addErrorf(node, "invalid float: %v", err) } + return 0 + } + + if floatVal, ok := val.Value().(float64); ok && (bits == 64 || float64(float32(floatVal)) == floatVal) { + parsed = floatVal + } else { + u.addErrorf(node, "invalid float: %v", err) } return parsed } @@ -473,18 +490,29 @@ func (u *unmarshaler) unmarshalUnsigned(node *yaml.Node, bits int) uint64 { } parsed, err := parseUintLiteral(node.Value) - if err != nil { - if val, celErr := u.celEval(node); celErr == nil { - if uintVal, ok := val.Value().(uint64); ok { - parsed = uintVal - } else if intVal, ok := val.Value().(int64); ok && intVal >= 0 { - parsed = uint64(intVal) - } else { - u.addErrorf(node, "expected unsigned integer, got %v", val.Type()) - } + if err == nil { + if bits < 64 && parsed >= 1< %v", 1<= 0 { + parsed = uint64(intVal) + } else { + u.addErrorf(node, "expected unsigned integer, got %v", val.Type()) } if bits < 64 && parsed >= 1< Date: Wed, 1 Jan 2025 13:52:29 -0800 Subject: [PATCH 3/8] add this --- decode.go | 12 +++++++++++- internal/testdata/basic.proto3test.yaml | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index 7ebc5fd..824945f 100644 --- a/decode.go +++ b/decode.go @@ -170,6 +170,14 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, return err } } + var err error + unm.celEnv, err = unm.celEnv.Extend( + cel.Types(message), + cel.Variable("this", cel.ObjectType(string(message.ProtoReflect().Descriptor().FullName()))), + ) + if err != nil { + return err + } // Unwrap the document node if node.Kind == yaml.DocumentNode { @@ -284,7 +292,9 @@ func (u *unmarshaler) celEval(node *yaml.Node) (ref.Val, error) { if err != nil { return nil, err } - out, _, err := prg.Eval(map[string]interface{}{}) + out, _, err := prg.Eval(map[string]interface{}{ + "this": u.this, + }) if err != nil { return nil, err } diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index c0ad913..cc94d92 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -140,3 +140,4 @@ values: - single_string: !cel '"hi" + "there"' - single_int32_wrapper: 1 + 3 - single_uint32_wrapper: 1 - 2 + - single_bool: this.values[0].single_bool From 058c4c4a534c0594f34e35633d73bf0a0e590bb1 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 13:53:43 -0800 Subject: [PATCH 4/8] nit --- internal/testdata/basic.proto3test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index cc94d92..47d5daf 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -140,4 +140,4 @@ values: - single_string: !cel '"hi" + "there"' - single_int32_wrapper: 1 + 3 - single_uint32_wrapper: 1 - 2 - - single_bool: this.values[0].single_bool + - single_bool: "!this.values[0].single_bool" From 6bb5b324c042e8131755260f5be6c4b62edc0b07 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 15:51:17 -0800 Subject: [PATCH 5/8] support duration and timestamp --- decode.go | 60 +++++++++++++++++++++++-- internal/testdata/basic.proto3test.txt | 2 +- internal/testdata/basic.proto3test.yaml | 2 + 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/decode.go b/decode.go index 824945f..bbb976b 100644 --- a/decode.go +++ b/decode.go @@ -160,6 +160,7 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, validator: o.Validator, lines: strings.Split(string(data), "\n"), this: message, + now: time.Now(), } if o.CelEnv != nil { unm.celEnv = o.CelEnv @@ -174,6 +175,7 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, unm.celEnv, err = unm.celEnv.Extend( cel.Types(message), cel.Variable("this", cel.ObjectType(string(message.ProtoReflect().Descriptor().FullName()))), + cel.Variable("now", cel.TimestampType), ) if err != nil { return err @@ -225,6 +227,7 @@ type unmarshaler struct { lines []string celEnv *cel.Env this proto.Message + now time.Time } func (u *unmarshaler) addError(node *yaml.Node, err error) { @@ -294,6 +297,7 @@ func (u *unmarshaler) celEval(node *yaml.Node) (ref.Val, error) { } out, _, err := prg.Eval(map[string]interface{}{ "this": u.this, + "now": u.now, }) if err != nil { return nil, err @@ -938,6 +942,7 @@ func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error { if err != nil { return err } + // Validate seconds. secs := parsed.Unix() if secs < minTimestampSeconds { @@ -966,14 +971,36 @@ func setFieldByName(message proto.Message, name string, value protoreflect.Value return true } +func evalDuration(node *yaml.Node, unm *unmarshaler, parseErr error) (*durationpb.Duration, error) { + val, celErr := unm.celEval(node) + if celErr != nil { + if node.Tag == celTag { + return nil, fmt.Errorf("invalid CEL expression: %w", celErr) + } + return nil, parseErr + } + + if durVal, ok := val.Value().(time.Duration); ok { + return durationpb.New(durVal), nil + } + if val.Type() == cel.IntType { + return nil, parseErr + } + + return nil, fmt.Errorf("expected duration, got %v", val.Type()) +} + func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool { if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) { return false } duration, err := ParseDuration(node.Value) if err != nil { - unm.addError(node, err) - return true + duration, err = evalDuration(node, unm, err) + if err != nil { + unm.addError(node, err) + return true + } } if value, ok := message.(*durationpb.Duration); ok { @@ -987,6 +1014,24 @@ func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Messa setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos())) } +func (u *unmarshaler) evalTimestamp(node *yaml.Node, timestamp *timestamppb.Timestamp, parseErr error) error { + val, celErr := u.celEval(node) + if celErr != nil { + if node.Tag == celTag { + return fmt.Errorf("invalid CEL expression: %w", celErr) + } + return fmt.Errorf("invalid timestamp: %w", parseErr) + } + + if timeVal, ok := val.Value().(time.Time); ok { + timestamp.Seconds = timeVal.Unix() + timestamp.Nanos = int32(timeVal.Nanosecond()) //nolint:gosec + return nil + } + + return fmt.Errorf("expected timestamp, got %v", val.Type()) +} + func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool { if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) { return false @@ -995,10 +1040,17 @@ func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Mess if !ok { timestamp = ×tamppb.Timestamp{} } + err := parseTimestamp(node.Value, timestamp) if err != nil { - unm.addErrorf(node, "invalid timestamp: %v", err) - } else if !ok { + err = unm.evalTimestamp(node, timestamp, err) + if err != nil { + unm.addError(node, err) + return true + } + } + + if !ok { return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(timestamp.GetSeconds())) && setFieldByName(message, "nanos", protoreflect.ValueOfInt32(timestamp.GetNanos())) } diff --git a/internal/testdata/basic.proto3test.txt b/internal/testdata/basic.proto3test.txt index d3eb0d1..b056af7 100644 --- a/internal/testdata/basic.proto3test.txt +++ b/internal/testdata/basic.proto3test.txt @@ -210,7 +210,7 @@ internal/testdata/basic.proto3test.yaml:121:23 invalid timestamp: parsing time " 121 | - single_timestamp: 9999-12-31T23:59:60Z 121 | ......................^ -internal/testdata/basic.proto3test.yaml:125:23 invalid timestamp: parsing time "10" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "10" as "2006" +internal/testdata/basic.proto3test.yaml:125:23 expected timestamp, got int 125 | - single_timestamp: 10 125 | ......................^ diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index 47d5daf..227f52a 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -141,3 +141,5 @@ values: - single_int32_wrapper: 1 + 3 - single_uint32_wrapper: 1 - 2 - single_bool: "!this.values[0].single_bool" + - single_timestamp: now + - single_duration: now - timestamp("1970-01-01T00:00:00Z") From 619d1e99855c63e7c25fe840b11d842abdf1c280 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 16:13:19 -0800 Subject: [PATCH 6/8] nit --- decode.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/decode.go b/decode.go index bbb976b..260ea7b 100644 --- a/decode.go +++ b/decode.go @@ -760,8 +760,8 @@ func (u *unmarshaler) unmarshalList(node *yaml.Node, field protoreflect.FieldDes case protoreflect.MessageKind, protoreflect.GroupKind: for _, itemNode := range node.Content { msgVal := list.NewElement() - u.unmarshalMessage(itemNode, msgVal.Message().Interface(), false) list.Append(msgVal) + u.unmarshalMessage(itemNode, msgVal.Message().Interface(), false) } default: for _, itemNode := range node.Content { @@ -1103,10 +1103,10 @@ func dynSetListValue(message proto.Message, list *structpb.ListValue) bool { values := message.ProtoReflect().Mutable(valuesFld).List() for _, item := range list.GetValues() { value := values.NewElement() + values.Append(value) if !dynSetValue(value.Message().Interface(), item) { return false } - values.Append(value) } return true } From 92422b1d6ba1238ecb2a06b0fd44f2c45d0ade25 Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 22:36:06 -0800 Subject: [PATCH 7/8] moar --- decode.go | 34 ++++++++++++++++--------- internal/testdata/basic.proto3test.yaml | 4 +++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/decode.go b/decode.go index 260ea7b..cf2818f 100644 --- a/decode.go +++ b/decode.go @@ -765,11 +765,9 @@ func (u *unmarshaler) unmarshalList(node *yaml.Node, field protoreflect.FieldDes } default: for _, itemNode := range node.Content { - val, ok := u.unmarshalScalar(itemNode, field, false) - if !ok { - continue + if val, ok := u.unmarshalScalar(itemNode, field, false); ok { + list.Append(val) } - list.Append(val) } } } @@ -792,14 +790,12 @@ func (u *unmarshaler) unmarshalMap(node *yaml.Node, field protoreflect.FieldDesc switch mapValueField.Kind() { case protoreflect.MessageKind, protoreflect.GroupKind: mapValue := mapVal.NewValue() - u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false) mapVal.Set(mapKey.MapKey(), mapValue) + u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false) default: - val, ok := u.unmarshalScalar(valueNode, mapValueField, false) - if !ok { - continue + if val, ok := u.unmarshalScalar(valueNode, mapValueField, false); ok { + mapVal.Set(mapKey.MapKey(), val) } - mapVal.Set(mapKey.MapKey(), val) } } } @@ -850,12 +846,26 @@ func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, f if isNull(node) { return // Null is always allowed for messages } - if node.Kind != yaml.MappingNode { + switch node.Kind { + case yaml.MappingNode: + u.unmarshalMessageFields(node, message, forAny) + case yaml.ScalarNode: + if val, err := u.celEval(node); err == nil { + if protoVal, ok := val.Value().(proto.Message); ok { + if protoVal.ProtoReflect().Descriptor() == message.ProtoReflect().Descriptor() { + proto.Merge(message, protoVal) + return + } + } + } else if node.Tag == celTag { + u.addErrorf(node, "invalid CEL expression: %v", err) + return + } + fallthrough + default: u.addErrorf(node, "expected fields for %v, got %v", message.ProtoReflect().Descriptor().FullName(), getNodeKind(node.Kind)) - return } - u.unmarshalMessageFields(node, message, forAny) } func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Message, forAny bool) { diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index 227f52a..3171683 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -143,3 +143,7 @@ values: - single_bool: "!this.values[0].single_bool" - single_timestamp: now - single_duration: now - timestamp("1970-01-01T00:00:00Z") + - map_int64_nested_type: + 1: { payload: { single_int32: 1 } } + 2: { payload: { single_sint32: "this.values[135].map_int64_nested_type[1].payload.single_int32 + 1" } } + 3: "this.values[135].map_int64_nested_type[2]" From 42de9161def27cd51580017f64a6c26b32bd513d Mon Sep 17 00:00:00 2001 From: Alfred Fuller Date: Wed, 1 Jan 2025 22:43:55 -0800 Subject: [PATCH 8/8] more tests --- internal/testdata/basic.proto3test.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/testdata/basic.proto3test.yaml b/internal/testdata/basic.proto3test.yaml index 3171683..09490ea 100644 --- a/internal/testdata/basic.proto3test.yaml +++ b/internal/testdata/basic.proto3test.yaml @@ -144,6 +144,6 @@ values: - single_timestamp: now - single_duration: now - timestamp("1970-01-01T00:00:00Z") - map_int64_nested_type: - 1: { payload: { single_int32: 1 } } - 2: { payload: { single_sint32: "this.values[135].map_int64_nested_type[1].payload.single_int32 + 1" } } + 1: { payload: { single_int32: 1, single_fixed32: "this.values[135].map_int64_nested_type[1].payload.single_int32 + 1" } } + 2: { payload: { single_sint32: "this.values[135].map_int64_nested_type[1].payload.single_fixed32 + 1u" } } 3: "this.values[135].map_int64_nested_type[2]"