Skip to content

Commit

Permalink
ContextProtoVars() to simplify proto-based inputs (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Jul 18, 2023
1 parent 215c1af commit c2302e2
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 29 deletions.
1 change: 1 addition & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,6 @@ go_test(
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
],
)
57 changes: 52 additions & 5 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
wrapperspb "google.golang.org/protobuf/types/known/wrapperspb"

proto2pb "github.com/google/cel-go/test/proto2pb"
proto3pb "github.com/google/cel-go/test/proto3pb"
Expand Down Expand Up @@ -1622,17 +1625,61 @@ func TestResidualAstModified(t *testing.T) {
}
}

func TestDeclareContextProto(t *testing.T) {
func TestContextProto(t *testing.T) {
descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor()
option := DeclareContextProto(descriptor)
env := testEnv(t, option)
expression := `single_int64 == 1 && single_double == 1.0 && single_bool == true && single_string == '' && single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& single_nested_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO && single_duration == duration('5s') && single_timestamp == timestamp('1972-01-01T10:00:20.021-05:00')
&& single_any == google.protobuf.Any{} && repeated_int32 == [1,2] && map_string_string == {'': ''} && map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
_, iss := env.Compile(expression)
expression := `
single_int64 == 1
&& single_double == 1.0
&& single_bool == true
&& single_string == ''
&& single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& standalone_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
&& single_duration == duration('5s')
&& single_timestamp == timestamp(63154820)
&& single_any == null
&& single_uint32_wrapper == null
&& single_uint64_wrapper == 0u
&& repeated_int32 == [1,2]
&& map_string_string == {'': ''}
&& map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
ast, iss := env.Compile(expression)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
in := &proto3pb.TestAllTypes{
SingleInt64: 1,
SingleDouble: 1.0,
SingleBool: true,
NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
},
StandaloneEnum: proto3pb.TestAllTypes_FOO,
SingleDuration: &durationpb.Duration{Seconds: 5},
SingleTimestamp: &timestamppb.Timestamp{
Seconds: 63154820,
},
SingleUint64Wrapper: wrapperspb.UInt64(0),
RepeatedInt32: []int32{1, 2},
MapStringString: map[string]string{"": ""},
MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}},
}
vars, err := ContextProtoVars(in)
if err != nil {
t.Fatalf("ContextProtoVars(%v) failed: %v", in, err)
}
out, _, err := prg.Eval(vars)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out.Equal(types.True) != types.True {
t.Errorf("prg.Eval() got %v, wanted true", out)
}
}

func TestRegexOptimizer(t *testing.T) {
Expand Down
61 changes: 42 additions & 19 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -491,25 +490,21 @@ func CostLimit(costLimit uint64) ProgramOption {
}
}

func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) {
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
msgName := (string)(field.Message().FullName())
wellKnownType, found := pb.CheckedWellKnowns[msgName]
if found {
return wellKnownType, nil
}
return decls.NewObjectType(msgName), nil
return ObjectType(msgName), nil
}
if primitiveType, found := pb.CheckedPrimitives[field.Kind()]; found {
if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found {
return primitiveType, nil
}
if field.Kind() == protoreflect.EnumKind {
return decls.Int, nil
return IntType, nil
}
return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String())
}

func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
name := string(field.Name())
if field.IsMap() {
mapKey := field.MapKey()
Expand All @@ -522,44 +517,72 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
return Variable(name, MapType(keyType, valueType)), nil
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
return Variable(name, ListType(elemType)), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
return Variable(name, celType), nil
}

// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
// Each field of the proto defines a variable of the same name in the environment.
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment
func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return func(e *Env) (*Env, error) {
var decls []*exprpb.Decl
fields := descriptor.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
decl, err := fieldToDecl(field)
variable, err := fieldToVariable(field)
if err != nil {
return nil, err
}
e, err = variable(e)
if err != nil {
return nil, err
}
decls = append(decls, decl)
}
var err error
e, err = Declarations(decls...)(e)
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}

// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation.
//
// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using
// protocol buffers.
func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
if ctx == nil || !ctx.ProtoReflect().IsValid() {
return interpreter.EmptyActivation(), nil
}
reg, err := types.NewRegistry(ctx)
if err != nil {
return nil, err
}
pbRef := ctx.ProtoReflect()
typeName := string(pbRef.Descriptor().FullName())
fields := pbRef.Descriptor().Fields()
vars := make(map[string]any, fields.Len())
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
sft, found := reg.FindStructFieldType(typeName, field.TextName())
if !found {
return nil, fmt.Errorf("no such field: %s", field.TextName())
}
fieldVal, err := sft.GetFrom(ctx)
if err != nil {
return nil, err
}
return Types(dynamicpb.NewMessage(descriptor))(e)
vars[field.TextName()] = fieldVal
}
return interpreter.NewActivation(vars)
}

// EnableMacroCallTracking ensures that call expressions which are replaced by macros
Expand Down
8 changes: 5 additions & 3 deletions common/types/pb/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,13 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, er
unwrappedAny := &anypb.Any{}
err := Merge(unwrappedAny, msg)
if err != nil {
return nil, false, err
return nil, false, fmt.Errorf("unwrap dynamic field failed: %v", err)
}
dynMsg, err := unwrappedAny.UnmarshalNew()
if err != nil {
// Allow the error to move further up the stack as it should result in an type
// conversion error if the caller does not recover it somehow.
return nil, false, err
return nil, false, fmt.Errorf("unmarshal dynamic any failed: %v", err)
}
// Attempt to unwrap the dynamic type, otherwise return the dynamic message.
unwrapped, nested, err := unwrapDynamic(desc, dynMsg.ProtoReflect())
Expand Down Expand Up @@ -564,8 +564,10 @@ func zeroValueOf(msg proto.Message) proto.Message {
}

var (
jsonValueTypeURL = "types.googleapis.com/google.protobuf.Value"

zeroValueMap = map[string]proto.Message{
"google.protobuf.Any": &anypb.Any{},
"google.protobuf.Any": &anypb.Any{TypeUrl: jsonValueTypeURL},
"google.protobuf.Duration": &dpb.Duration{},
"google.protobuf.ListValue": &structpb.ListValue{},
"google.protobuf.Struct": &structpb.Struct{},
Expand Down
5 changes: 3 additions & 2 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func singularFieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsEnum() {
return IntType
}
return protoCELPrimitives[field.ProtoKind()]
return ProtoCELPrimitives[field.ProtoKind()]
}

// defaultTypeAdapter converts go native types to CEL values.
Expand Down Expand Up @@ -657,7 +657,8 @@ func fieldTypeConversionError(field *pb.FieldDescription, err error) error {
}

var (
protoCELPrimitives = map[protoreflect.Kind]*Type{
// ProtoCELPrimitives provides a map from the protoreflect Kind to the equivalent CEL type.
ProtoCELPrimitives = map[protoreflect.Kind]*Type{
protoreflect.BoolKind: BoolType,
protoreflect.BytesKind: BytesType,
protoreflect.DoubleKind: DoubleType,
Expand Down

0 comments on commit c2302e2

Please sign in to comment.