diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index d3b9ae22..0c740dd8 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -30,6 +30,7 @@ go_library( "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + "@org_golang_google_protobuf//types/dynamicpb:go_default_library", "@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", ], importpath = "github.com/google/cel-go/cel", diff --git a/cel/cel_test.go b/cel/cel_test.go index 7fdaa016..b5babe07 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -35,11 +35,13 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" proto2pb "github.com/google/cel-go/test/proto2pb" proto3pb "github.com/google/cel-go/test/proto3pb" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" descpb "google.golang.org/protobuf/types/descriptorpb" + dynamicpb "google.golang.org/protobuf/types/dynamicpb" ) func Example() { @@ -562,6 +564,68 @@ func TestDynamicProto(t *testing.T) { } } +func TestDynamicProto_Input(t *testing.T) { + b, err := ioutil.ReadFile("testdata/team.fds") + if err != nil { + t.Fatalf("ioutil.ReadFile() failed: %v", err) + } + var fds descpb.FileDescriptorSet + if err = proto.Unmarshal(b, &fds); err != nil { + t.Fatalf("proto.Unmarshal() failed: %v", err) + } + files := (&fds).GetFile() + fileCopy := make([]interface{}, len(files)) + for i := 0; i < len(files); i++ { + fileCopy[i] = files[i] + } + pbFiles, err := protodesc.NewFiles(&fds) + if err != nil { + t.Fatalf("protodesc.NewFiles() failed: %v", err) + } + desc, err := pbFiles.FindDescriptorByName("cel.testdata.Mutant") + if err != nil { + t.Fatalf("pbFiles.FindDescriptorByName() could not find Mutant: %v", err) + } + msgDesc, ok := desc.(protoreflect.MessageDescriptor) + if !ok { + t.Fatalf("desc not convertible to MessageDescriptor: %T", desc) + } + wolverine := dynamicpb.NewMessage(msgDesc) + wolverine.ProtoReflect().Set(msgDesc.Fields().ByName("name"), protoreflect.ValueOfString("Wolverine")) + e, err := NewEnv( + // The following is identical to registering the FileDescriptorSet; + // however, it tests a different code path which aggregates individual + // FileDescriptorProto values together. + TypeDescs(fileCopy...), + Declarations(decls.NewVar("mutant", decls.NewObjectType("cel.testdata.Mutant"))), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + src := `has(mutant.name) && mutant.name == 'Wolverine'` + ast, iss := e.Compile(src) + if iss.Err() != nil { + t.Fatalf("env.Compile(%s) failed: %v", src, iss.Err()) + } + prg, err := e.Program(ast, EvalOptions(OptOptimize)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(map[string]interface{}{ + "mutant": wolverine, + }) + if err != nil { + t.Fatalf("program.Eval() failed: %v", err) + } + obj, ok := out.(types.Bool) + if !ok { + t.Fatalf("unable to convert output to object: %v", out) + } + if obj != types.True { + t.Errorf("got %v, wanted true", out) + } +} + func TestGlobalVars(t *testing.T) { mapStrDyn := decls.NewMapType(decls.String, decls.Dyn) e, _ := NewEnv( diff --git a/common/types/pb/type.go b/common/types/pb/type.go index 64a56c70..c5060a46 100644 --- a/common/types/pb/type.go +++ b/common/types/pb/type.go @@ -198,7 +198,16 @@ func (fd *FieldDescription) Descriptor() protoreflect.FieldDescriptor { func (fd *FieldDescription) IsSet(target interface{}) bool { switch v := target.(type) { case proto.Message: - return v.ProtoReflect().Has(fd.desc) + pbRef := v.ProtoReflect() + pbDesc := pbRef.Descriptor() + if pbDesc == fd.desc.ContainingMessage() { + // When the target protobuf shares the same message descriptor instance as the field + // descriptor, use the cached field descriptor value. + return pbRef.Has(fd.desc) + } + // Otherwise, fallback to a dynamic lookup of the field descriptor from the target + // instance as an attempt to use the cached field descriptor will result in a panic. + return pbRef.Has(pbDesc.Fields().ByName(protoreflect.Name(fd.Name()))) default: return false } @@ -215,7 +224,18 @@ func (fd *FieldDescription) GetFrom(target interface{}) (interface{}, error) { if !ok { return nil, fmt.Errorf("unsupported field selection target: (%T)%v", target, target) } - fieldVal := v.ProtoReflect().Get(fd.desc).Interface() + pbRef := v.ProtoReflect() + pbDesc := pbRef.Descriptor() + var fieldVal interface{} + if pbDesc == fd.desc.ContainingMessage() { + // When the target protobuf shares the same message descriptor instance as the field + // descriptor, use the cached field descriptor value. + fieldVal = pbRef.Get(fd.desc).Interface() + } else { + // Otherwise, fallback to a dynamic lookup of the field descriptor from the target + // instance as an attempt to use the cached field descriptor will result in a panic. + fieldVal = pbRef.Get(pbDesc.Fields().ByName(protoreflect.Name(fd.Name()))).Interface() + } switch fv := fieldVal.(type) { // Fast-path return for primitive types. case bool, []byte, float32, float64, int32, int64, string, uint32, uint64, protoreflect.List: