Skip to content

Commit

Permalink
Ensure overloads are searched in the order they are declared (#566)
Browse files Browse the repository at this point in the history
* Ensure overloads are searched in the order they are declared during dynamic dispatch
* Improved support for dynamic dispatch
  • Loading branch information
TristonianJones authored Jul 14, 2022
1 parent f3df06c commit efc893b
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 70 deletions.
205 changes: 172 additions & 33 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,10 @@ func TestGlobalVars(t *testing.T) {
t.Run("attrs_alt", func(t *testing.T) {
vars := map[string]interface{}{
"attrs": map[string]interface{}{"second": "yep"}}
out, _, _ := prg.Eval(vars)
out, _, err := prg.Eval(vars)
if err != nil {
t.Fatalf("prg.Eval(vars) failed: %v", err)
}
if out.Equal(types.String("yep")) != types.True {
t.Errorf("got '%v', expected 'yep'.", out.Value())
}
Expand Down Expand Up @@ -1657,7 +1660,7 @@ func TestDefaultUTCTimeZone(t *testing.T) {
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear() == 1970
&& x.getMonth() == 0
&& x.getDayOfYear() == 0
Expand Down Expand Up @@ -1687,16 +1690,10 @@ func TestDefaultUTCTimeZone(t *testing.T) {
&& x.getHours('23:15') == 1
&& x.getMinutes('23:15') == 20
&& x.getSeconds('23:15') == 6
&& x.getMilliseconds('23:15') == 1
`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()})
&& x.getMilliseconds('23:15') == 1`,
map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
})
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
Expand All @@ -1718,20 +1715,12 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) {
if err != nil {
t.Fatalf("env.Extend() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear() == 1970
&& y.getHours() == 2
&& y.getMinutes() == 120
&& y.getSeconds() == 7235
&& y.getMilliseconds() == 7235000`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(
&& y.getMilliseconds() == 7235000`,
map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
"y": time.Duration(7235) * time.Second,
Expand All @@ -1750,7 +1739,7 @@ func TestDefaultUTCTimeZoneError(t *testing.T) {
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile(`
out, err := interpret(t, env, `
x.getFullYear(':xx') == 1969
|| x.getDayOfYear('xx:') == 364
|| x.getMonth('Am/Ph') == 11
Expand All @@ -1761,30 +1750,180 @@ func TestDefaultUTCTimeZoneError(t *testing.T) {
|| x.getMinutes('Am/Ph') == 5
|| x.getSeconds('Am/Ph') == 6
|| x.getMilliseconds('Am/Ph') == 1
`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
`, map[string]interface{}{
"x": time.Unix(7506, 1000000).Local(),
},
)
if err == nil {
t.Fatalf("prg.Eval() got %v wanted error", out)
}
prg, err := env.Program(ast)
}

func TestDynamicDispatch(t *testing.T) {
env, err := NewEnv(
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
t.Fatalf("NewEnv() failed: %v", err)
}
out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()})
if err == nil {
t.Fatalf("prg.Eval() got %v wanted error", out)
out, err := interpret(t, env, `
[].first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"
&& [[], ["empty"]].first().first() == ""
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"
`, map[string]interface{}{},
)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Fatalf("prg.Eval() got %v wanted true", out)
}
}

func BenchmarkDynamicDispatch(b *testing.B) {
env, err := NewEnv(
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
if err != nil {
b.Fatalf("NewEnv() failed: %v", err)
}
prg := compile(b, env, `
[].first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"`)
prgDyn := compile(b, env, `
dyn([]).first() == 0
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"`)
b.ResetTimer()
b.Run("DirectDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prg.Eval(NoVars())
}
})
b.ResetTimer()
b.Run("DynamicDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prgDyn.Eval(NoVars())
}
})
}

func interpret(t *testing.T, env *Env, expr string, vars interface{}) (ref.Val, error) {
func compile(t testing.TB, env *Env, expr string) Program {
t.Helper()
prg, err := compileOrError(t, env, expr)
if err != nil {
t.Fatal(err)
}
return prg
}

func compileOrError(t testing.TB, env *Env, expr string) (Program, error) {
t.Helper()
ast, iss := env.Compile(expr)
if iss.Err() != nil {
return nil, fmt.Errorf("env.Compile(%s) failed: %v", expr, iss.Err())
}
prg, err := env.Program(ast)
prg, err := env.Program(ast, EvalOptions(OptOptimize))
if err != nil {
return nil, fmt.Errorf("env.Program() failed: %v", err)
}
return prg, nil
}

func interpret(t testing.TB, env *Env, expr string, vars interface{}) (ref.Val, error) {
t.Helper()
prg, err := compileOrError(t, env, expr)
if err != nil {
return nil, err
}
out, _, err := prg.Eval(vars)
if err != nil {
return nil, fmt.Errorf("prg.Eval(%v) failed: %v", vars, err)
Expand Down
Loading

0 comments on commit efc893b

Please sign in to comment.