Skip to content

Commit

Permalink
Fix for working with byte arrays (#958)
Browse files Browse the repository at this point in the history
Fix for conversion of bytes to a [n]byte array
  • Loading branch information
TristonianJones authored Jun 10, 2024
1 parent e31b401 commit b66ac6c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 15 deletions.
12 changes: 11 additions & 1 deletion common/types/bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,17 @@ func (b Bytes) Compare(other ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method.
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Array, reflect.Slice:
case reflect.Array:
if len(b) != typeDesc.Len() {
return nil, fmt.Errorf("[%d]byte not assignable to [%d]byte array", len(b), typeDesc.Len())
}
refArrPtr := reflect.New(reflect.ArrayOf(len(b), typeDesc.Elem()))
refArr := refArrPtr.Elem()
for i, byt := range b {
refArr.Index(i).Set(reflect.ValueOf(byt).Convert(typeDesc.Elem()))
}
return refArr.Interface(), nil
case reflect.Slice:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
switch typeDesc {
Expand Down
18 changes: 18 additions & 0 deletions common/types/bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package types
import (
"bytes"
"reflect"
"strings"
"testing"

"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -71,6 +72,23 @@ func TestBytesConvertToNative_ByteSlice(t *testing.T) {
}
}

func TestBytesConvertToNative_ByteArray(t *testing.T) {
val, err := Bytes("123").ConvertToNative(reflect.TypeOf([3]byte{}))
if err != nil {
t.Error("Got unexpected value, wanted []byte{49, 50, 51}", err, val)
}
if val.([3]byte) != [3]byte{49, 50, 51} {
t.Errorf("Got %v, wanted [3]byte{49, 50, 51}", val)
}
}

func TestBytesConvertToNative_ByteArrayError(t *testing.T) {
_, err := Bytes("123").ConvertToNative(reflect.TypeOf([1]byte{}))
if !strings.Contains(err.Error(), "[3]byte not assignable to [1]byte") {
t.Errorf("Got unexpected error %v, wanted not assignable error", err)
}
}

func TestBytesConvertToNative_Error(t *testing.T) {
val, err := Bytes("123").ConvertToNative(reflect.TypeOf(""))
if err == nil {
Expand Down
8 changes: 8 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,14 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
refKind := refValue.Kind()
switch refKind {
case reflect.Array, reflect.Slice:
if refValue.Type().Elem() == reflect.TypeOf(byte(0)) {
if refValue.CanAddr() {
return Bytes(refValue.Bytes()), true
}
tmp := reflect.New(refValue.Type())
tmp.Elem().Set(refValue)
return Bytes(tmp.Elem().Bytes()), true
}
return NewDynamicList(a, v), true
case reflect.Map:
return NewDynamicMap(a, v), true
Expand Down
2 changes: 2 additions & 0 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ func TestNativeToValue_Primitive(t *testing.T) {
expectNativeToValue(t, float64(-5.5), Double(-5.5))
expectNativeToValue(t, "hello", String("hello"))
expectNativeToValue(t, []byte("world"), Bytes("world"))
expectNativeToValue(t, [4]byte{1, 2, 3, 4}, Bytes([]byte{1, 2, 3, 4}))
expectNativeToValue(t, &[4]byte{1, 2, 3, 4}, Bytes([]byte{1, 2, 3, 4}))
expectNativeToValue(t, time.Duration(500), Duration{Duration: time.Duration(500)})
expectNativeToValue(t, time.Unix(12345, 0), Timestamp{Time: time.Unix(12345, 0)})
expectNativeToValue(t, dpb.New(time.Duration(500)), Duration{Duration: time.Duration(500)})
Expand Down
3 changes: 3 additions & 0 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
case []byte:
return tp.baseAdapter.NativeToValue(val)
default:
if refVal.Type().Elem() == reflect.TypeOf(byte(0)) {
return tp.baseAdapter.NativeToValue(val)
}
return types.NewDynamicList(tp, val)
}
case reflect.Map:
Expand Down
55 changes: 41 additions & 14 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ func TestNativeTypes(t *testing.T) {
{expr: `[TestAllTypes{BoolVal: true}, TestAllTypes{BoolVal: false}].exists(t, t.BoolVal == true)`},
{expr: `[TestAllTypes{CustomName: 'Alice'}, TestAllTypes{CustomName: 'Bob'}].exists(t, t.CustomName == 'Alice')`},
{expr: `[TestAllTypes{custom_name: 'Alice'}, TestAllTypes{custom_name: 'Bob'}].exists(t, t.custom_name == 'Alice')`, envOpts: []any{ParseStructTags(true)}},
{expr: `TestAllTypes{BytesArrayVal: b'1234'}.BytesArrayVal != b'123'`},
{expr: `TestAllTypes{BytesArrayVal: b'1234'}.BytesArrayVal == b'1234'`},
{
expr: `tests.all(t, t.Int32Val > 17)`,
in: map[string]any{
Expand Down Expand Up @@ -577,37 +579,61 @@ func TestNativeTypesConvertToNative(t *testing.T) {
env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{})))
adapter := env.CELTypeAdapter()
conversions := []struct {
in any
out any
err string
in any
inType *cel.Type
out any
err string
}{
{
in: &TestAllTypes{BoolVal: true},
out: &TestAllTypes{BoolVal: true},
in: &TestAllTypes{BoolVal: true},
inType: cel.ObjectType("ext.TestAllTypes"),
out: &TestAllTypes{BoolVal: true},
},
{
in: TestAllTypes{BoolVal: true},
out: &TestAllTypes{BoolVal: true},
in: TestAllTypes{BoolVal: true},
inType: cel.ObjectType("ext.TestAllTypes"),
out: &TestAllTypes{BoolVal: true},
},
{
in: &TestAllTypes{BoolVal: true},
out: TestAllTypes{BoolVal: true},
in: &TestAllTypes{BoolVal: true},
inType: cel.ObjectType("ext.TestAllTypes"),
out: TestAllTypes{BoolVal: true},
},
{
in: nil,
out: types.NullValue,
in: nil,
inType: cel.NullType,
out: types.NullValue,
},
{
in: &TestAllTypes{BoolVal: true},
out: &proto3pb.TestAllTypes{},
err: "type conversion error",
in: &TestAllTypes{BoolVal: true},
inType: cel.ObjectType("ext.TestAllTypes"),
out: &proto3pb.TestAllTypes{},
err: "type conversion error",
},
{
in: [3]int32{1, 2, 3},
inType: cel.ListType(cel.IntType),
out: []int32{1, 2, 3},
},
{
in: &[3]byte{1, 2, 3},
inType: cel.BytesType,
out: []byte{1, 2, 3},
},
{
in: [3]byte{1, 2, 3},
inType: cel.BytesType,
out: []byte{1, 2, 3},
},
}
for _, c := range conversions {
inVal := adapter.NativeToValue(c.in)
if types.IsError(inVal) {
t.Fatalf("adapter.NativeToValue(%v) failed: %v", c.in, inVal)
}
if inVal.Type().TypeName() != c.inType.TypeName() {
t.Fatalf("adapter.NativeToValue() got type %v, wanted type %v", inVal.Type(), c.inType)
}
out, err := inVal.ConvertToNative(reflect.TypeOf(c.out))
if err != nil {
if c.err != "" {
Expand Down Expand Up @@ -848,6 +874,7 @@ type TestAllTypes struct {
Uint64Val uint64
ListVal []*TestNestedType
ArrayVal [1]*TestNestedType
BytesArrayVal [4]byte
MapVal map[string]TestAllTypes
PbVal *proto3pb.TestAllTypes
CustomSliceVal []TestNestedSliceType
Expand Down

0 comments on commit b66ac6c

Please sign in to comment.