Skip to content

Commit

Permalink
Unify overflow checking in operators, field setting (#448)
Browse files Browse the repository at this point in the history
* Unify overflow checking in operators, field setting
* Additional overflow case for timestamp subtraction
* NaN checks and tests based on review feedback
  • Loading branch information
TristonianJones authored Sep 18, 2021
1 parent e7c178e commit 90c32ee
Show file tree
Hide file tree
Showing 18 changed files with 1,014 additions and 502 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ go_repository(
# CEL Spec deps
go_repository(
name = "com_google_cel_spec",
commit = "1a26bb4e2a611b694367e7579e74b68b17ebc536",
commit = "ad5c42c7f0a66f7ea43bd7299c0397ceef23beb5",
importpath = "github.com/google/cel-spec",
)

Expand Down
29 changes: 14 additions & 15 deletions common/types/double.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package types

import (
"fmt"
"math"
"reflect"

"github.com/google/cel-go/common/types/ref"
Expand Down Expand Up @@ -52,7 +51,7 @@ var (
func (d Double) Add(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
return d + otherDouble
}
Expand All @@ -61,7 +60,7 @@ func (d Double) Add(other ref.Val) ref.Val {
func (d Double) Compare(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
if d < otherDouble {
return IntNegOne
Expand Down Expand Up @@ -127,17 +126,17 @@ func (d Double) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case IntType:
i := math.Round(float64(d))
if i > math.MaxInt64 || i < math.MinInt64 {
return NewErr("range error converting %g to int", float64(d))
i, err := doubleToInt64Checked(float64(d))
if err != nil {
return wrapErr(err)
}
return Int(float64(i))
return Int(i)
case UintType:
i := math.Round(float64(d))
if i > math.MaxUint64 || i < 0 {
return NewErr("range error converting %g to int", float64(d))
i, err := doubleToUint64Checked(float64(d))
if err != nil {
return wrapErr(err)
}
return Uint(float64(i))
return Uint(i)
case DoubleType:
return d
case StringType:
Expand All @@ -152,7 +151,7 @@ func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
func (d Double) Divide(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
return d / otherDouble
}
Expand All @@ -161,7 +160,7 @@ func (d Double) Divide(other ref.Val) ref.Val {
func (d Double) Equal(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
// TODO: Handle NaNs properly.
return Bool(d == otherDouble)
Expand All @@ -171,7 +170,7 @@ func (d Double) Equal(other ref.Val) ref.Val {
func (d Double) Multiply(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
return d * otherDouble
}
Expand All @@ -185,7 +184,7 @@ func (d Double) Negate() ref.Val {
func (d Double) Subtract(subtrahend ref.Val) ref.Val {
subtraDouble, ok := subtrahend.(Double)
if !ok {
return ValOrErr(subtrahend, "no such overload")
return MaybeNoSuchOverloadErr(subtrahend)
}
return d - subtraDouble
}
Expand Down
124 changes: 107 additions & 17 deletions common/types/double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package types

import (
"errors"
"math"
"reflect"
"strings"
"testing"

"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"

anypb "google.golang.org/protobuf/types/known/anypb"
Expand Down Expand Up @@ -167,23 +170,110 @@ func TestDoubleConvertToNative_Wrapper(t *testing.T) {
}

func TestDoubleConvertToType(t *testing.T) {
if !Double(-4.5).ConvertToType(IntType).Equal(Int(-5)).(Bool) {
t.Error("Unsuccessful type conversion to int")
}
if !IsError(Double(-4.5).ConvertToType(UintType)) {
t.Error("Got uint, expected error")
}
if !Double(-4.5).ConvertToType(DoubleType).Equal(Double(-4.5)).(Bool) {
t.Error("Unsuccessful type conversion to double")
}
if !Double(-4.5).ConvertToType(StringType).Equal(String("-4.5")).(Bool) {
t.Error("Unsuccessful type conversion to string")
}
if !Double(-4.5).ConvertToType(TypeType).Equal(DoubleType).(Bool) {
t.Error("Unsuccessful type conversion to type")
}
if !IsError(Double(-4.5).ConvertToType(TimestampType)) {
t.Error("Got value, expected error")
tests := []struct {
name string
in float64
toType ref.Type
out interface{}
}{
{
name: "DoubleToDouble",
in: float64(-4.2),
toType: DoubleType,
out: float64(-4.2),
},
{
name: "DoubleToType",
in: float64(-4.2),
toType: TypeType,
out: DoubleType.TypeName(),
},
{
name: "DoubleToInt",
in: float64(4.2),
toType: IntType,
out: int64(4),
},
{
name: "DoubleToIntNaN",
in: math.NaN(),
toType: IntType,
out: errIntOverflow,
},
{
name: "DoubleToIntPosInf",
in: math.Inf(1),
toType: IntType,
out: errIntOverflow,
},
{
name: "DoubleToIntPosOverflow",
in: float64(math.MaxInt64),
toType: IntType,
out: errIntOverflow,
},
{
name: "DoubleToIntNegOverflow",
in: float64(math.MinInt64),
toType: IntType,
out: errIntOverflow,
},
{
name: "DoubleToUint",
in: float64(4.7),
toType: UintType,
out: uint64(4),
},
{
name: "DoubleToUintNaN",
in: math.NaN(),
toType: UintType,
out: errUintOverflow,
},
{
name: "DoubleToUintPosInf",
in: math.Inf(1),
toType: UintType,
out: errUintOverflow,
},
{
name: "DoubleToUintPosOverflow",
in: float64(math.MaxUint64),
toType: UintType,
out: errUintOverflow,
},
{
name: "DoubleToUintNegOverflow",
in: float64(-0.1),
toType: UintType,
out: errUintOverflow,
},
{
name: "DoubleToString",
in: float64(4.5),
toType: StringType,
out: "4.5",
},
{
name: "DoubleToUnsupportedType",
in: float64(4),
toType: MapType,
out: errors.New("type conversion error"),
},
}
for _, tst := range tests {
got := Double(tst.in).ConvertToType(tst.toType).Value()
var eq bool
switch gotVal := got.(type) {
case error:
eq = strings.Contains(gotVal.Error(), tst.out.(error).Error())
default:
eq = reflect.DeepEqual(gotVal, tst.out)
}
if !eq {
t.Errorf("Double(%v).ConvertToType(%v) failed, got: %v, wanted: %v",
tst.in, tst.toType, got, tst.out)
}
}
}

Expand Down
30 changes: 17 additions & 13 deletions common/types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,27 @@ func (d Duration) Add(other ref.Val) ref.Val {
switch other.Type() {
case DurationType:
dur2 := other.(Duration)
if val, ok := addDurationChecked(d.Duration, dur2.Duration); ok {
if val, err := addDurationChecked(d.Duration, dur2.Duration); err != nil {
return wrapErr(err)
} else {
return durationOf(val)
}
return errDurationOverflow
case TimestampType:
ts := other.(Timestamp).Time
if val, ok := addTimeDurationChecked(ts, d.Duration); ok {
if val, err := addTimeDurationChecked(ts, d.Duration); err != nil {
return wrapErr(err)
} else {
return timestampOf(val)
}
return errTimestampOverflow
}
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}

// Compare implements traits.Comparer.Compare.
func (d Duration) Compare(other ref.Val) ref.Val {
otherDur, ok := other.(Duration)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
d1 := d.Duration
d2 := otherDur.Duration
Expand Down Expand Up @@ -134,17 +136,18 @@ func (d Duration) ConvertToType(typeVal ref.Type) ref.Val {
func (d Duration) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(Duration)
if !ok {
return ValOrErr(other, "no such overload")
return MaybeNoSuchOverloadErr(other)
}
return Bool(d.Duration == otherDur.Duration)
}

// Negate implements traits.Negater.Negate.
func (d Duration) Negate() ref.Val {
if val, ok := negateDurationChecked(d.Duration); ok {
if val, err := negateDurationChecked(d.Duration); err != nil {
return wrapErr(err)
} else {
return durationOf(val)
}
return errDurationOverflow
}

// Receive implements traits.Receiver.Receive.
Expand All @@ -154,19 +157,20 @@ func (d Duration) Receive(function string, overload string, args []ref.Val) ref.
return f(d.Duration)
}
}
return NewErr("no such overload")
return NoSuchOverloadErr()
}

// Subtract implements traits.Subtractor.Subtract.
func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
subtraDur, ok := subtrahend.(Duration)
if !ok {
return ValOrErr(subtrahend, "no such overload")
return MaybeNoSuchOverloadErr(subtrahend)
}
if val, ok := subtractDurationChecked(d.Duration, subtraDur.Duration); ok {
if val, err := subtractDurationChecked(d.Duration, subtraDur.Duration); err != nil {
return wrapErr(err)
} else {
return durationOf(val)
}
return errDurationOverflow
}

// Type implements ref.Val.Type.
Expand Down
Loading

0 comments on commit 90c32ee

Please sign in to comment.