diff --git a/cel/cel_test.go b/cel/cel_test.go index f12834b4..36dbdcc4 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -202,8 +202,10 @@ func TestCustomEnvError(t *testing.T) { } func TestCustomEnv(t *testing.T) { - e, _ := NewCustomEnv(Variable("a.b.c", BoolType)) - + e, err := NewCustomEnv(Variable("a.b.c", BoolType)) + if err != nil { + t.Fatalf("NewCustomEnv(a.b.c:bool) failed: %v", err) + } t.Run("err", func(t *testing.T) { _, iss := e.Compile("a.b.c == true") if iss.Err() == nil { @@ -1704,7 +1706,11 @@ func TestDefaultUTCTimeZone(t *testing.T) { } func TestDefaultUTCTimeZoneExtension(t *testing.T) { - env, err := NewEnv(Variable("x", TimestampType), DefaultUTCTimeZone(true)) + env, err := NewEnv( + Variable("x", TimestampType), + Variable("y", DurationType), + DefaultUTCTimeZone(true), + ) if err != nil { t.Fatalf("NewEnv() failed: %v", err) } @@ -1712,7 +1718,12 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { if err != nil { t.Fatalf("env.Extend() failed: %v", err) } - ast, iss := env.Compile(`x.getFullYear() == 1970`) + ast, iss := env.Compile(` + x.getFullYear() == 1970 + && y.getHours() == 2 + && y.getMinutes() == 120 + && y.getSeconds() == 7235 + && y.getMilliseconds() == 7235000`) if iss.Err() != nil { t.Fatalf("env.Compile() failed: %v", iss.Err()) } @@ -1720,12 +1731,17 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { if err != nil { t.Fatalf("env.Program() failed: %v", err) } - out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()}) + out, _, err := prg.Eval( + map[string]interface{}{ + "x": time.Unix(7506, 1000000).Local(), + "y": time.Duration(7235) * time.Second, + }, + ) if err != nil { t.Fatalf("prg.Eval() failed: %v", err) } if out != types.True { - t.Errorf("Eval() got %v, wanted true", out) + t.Errorf("Eval() got %v, wanted true", out.Value()) } } diff --git a/cel/decls.go b/cel/decls.go index 12c8f256..55532788 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -662,8 +662,9 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { return merged, nil } -// addOverload ensures that the new overload does not collide with an existing overload signature, -// nor does it redefine an existing overload binding. +// addOverload ensures that the new overload does not collide with an existing overload signature; +// however, if the function signatures are identical, the implementation may be rewritten as its +// difficult to compare functions by object identity. func (f *functionDecl) addOverload(overload *overloadDecl) error { for id, o := range f.overloads { if id != overload.id && o.signatureOverlaps(overload) { @@ -671,11 +672,8 @@ func (f *functionDecl) addOverload(overload *overloadDecl) error { } if id == overload.id { if o.signatureEquals(overload) && o.nonStrict == overload.nonStrict { - if !o.hasBinding() && overload.hasBinding() { - f.overloads[id] = overload - } else if o.hasBinding() && overload.hasBinding() && o != overload { - return fmt.Errorf("overload binding collision in function %s: %s has multiple bindings", f.name, o.id) - } + // Allow redefinition of an overload implementation so long as the signatures match. + f.overloads[id] = overload } else { return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id) } diff --git a/cel/decls_test.go b/cel/decls_test.go index b4289679..655b0ed9 100644 --- a/cel/decls_test.go +++ b/cel/decls_test.go @@ -147,10 +147,6 @@ func TestFunctionMerge(t *testing.T) { t.Errorf("prg.Eval() got %v, wanted %v", out, want) } - _, err = NewCustomEnv(vectorExt, vectorExt) - if err == nil || !strings.Contains(err.Error(), "overload binding collision") { - t.Errorf("NewCustomEnv(vectorExt, vectorExt) did not produce expected error: %v", err) - } _, err = NewCustomEnv(size, size) if err == nil || !strings.Contains(err.Error(), "already has a binding") { t.Errorf("NewCustomEnv(size, size) did not produce the expected error: %v", err) diff --git a/cel/library.go b/cel/library.go index 2b403599..5ca52845 100644 --- a/cel/library.go +++ b/cel/library.go @@ -85,7 +85,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption { type timeUTCLibrary struct{} func (timeUTCLibrary) CompileOptions() []EnvOption { - return timestampOverloadDeclarations + return timeOverloadDeclarations } func (timeUTCLibrary) ProgramOptions() []ProgramOption { @@ -97,7 +97,31 @@ func (timeUTCLibrary) ProgramOptions() []ProgramOption { var ( utcTZ = types.String("UTC") - timestampOverloadDeclarations = []EnvOption{ + timeOverloadDeclarations = []EnvOption{ + Function(overloads.TimeGetHours, + MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType, + UnaryBinding(func(dur ref.Val) ref.Val { + d := dur.(types.Duration) + return types.Int(d.Hours()) + }))), + Function(overloads.TimeGetMinutes, + MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType, + UnaryBinding(func(dur ref.Val) ref.Val { + d := dur.(types.Duration) + return types.Int(d.Minutes()) + }))), + Function(overloads.TimeGetSeconds, + MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType, + UnaryBinding(func(dur ref.Val) ref.Val { + d := dur.(types.Duration) + return types.Int(d.Seconds()) + }))), + Function(overloads.TimeGetMilliseconds, + MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType, + UnaryBinding(func(dur ref.Val) ref.Val { + d := dur.(types.Duration) + return types.Int(d.Milliseconds()) + }))), Function(overloads.TimeGetFullYear, MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val {